Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion frontend/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@
var scrollbarHeight = 20; // Max between windows, mac, and linux

function setHeight() {
// Guard against race condition where iframe isn't ready
if (!obj.contentWindow?.document?.documentElement) {
return;
}
var element = obj.contentWindow.document.documentElement;
// If there is no vertical scrollbar, we don't need to resize the iframe
if (element.scrollHeight === element.clientHeight) {
Expand All @@ -49,7 +53,10 @@
const resizeObserver = new ResizeObserver((entries) => {
setHeight();
});
resizeObserver.observe(obj.contentWindow.document.body);
// Only observe if iframe content is ready
if (obj.contentWindow?.document?.body) {
resizeObserver.observe(obj.contentWindow.document.body);
}
}
</script>
<marimo-filename hidden>{{ filename }}</marimo-filename>
Expand Down
151 changes: 116 additions & 35 deletions marimo/_plugins/stateless/mpl/_mpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import io
import mimetypes
import os
import signal
import threading
import time
from pathlib import Path
Expand All @@ -32,7 +31,6 @@
from marimo._runtime.runtime import app_meta
from marimo._server.utils import find_free_port
from marimo._utils.platform import is_pyodide
from marimo._utils.signals import get_signals

LOGGER = _loggers.marimo_logger()

Expand All @@ -58,12 +56,100 @@ def get(self, figure_id: str) -> FigureManagerWebAgg:
return self.figure_managers[str(figure_id)]

def remove(self, manager: FigureManagerWebAgg) -> None:
    """Drop ``manager`` from the registry, tolerating repeated removal.

    The registry is keyed by ``str(manager.num)``. Removing a figure that
    is no longer registered is logged at debug level instead of raising.

    Args:
        manager: The figure manager whose ``num`` identifies the entry.
    """
    # A single guarded delete: an unguarded ``del`` ahead of this would
    # raise on a missing key and make the guarded one always miss.
    try:
        del self.figure_managers[str(manager.num)]
    except KeyError:
        # Figure already removed, this can happen during server restart
        LOGGER.debug(f"Figure {manager.num} already removed from manager")


figure_managers = FigureManagers()


class MplServerManager:
    """Manages the matplotlib server lifecycle with lazy recovery.

    The server is a uvicorn instance running on a daemon thread. Callers
    check ``is_running()`` and call ``start()`` again once the thread has
    exited; ``stop()`` asks the current server to shut down.
    """

    def __init__(self) -> None:
        # Thread running the uvicorn server (attribute name kept as
        # ``process`` for existing callers).
        self.process: Optional[threading.Thread] = None
        # The uvicorn.Server driven by ``process``; kept so stop() can
        # signal it instead of merely forgetting the thread.
        self._server: Optional[Any] = None
        # Taken by callers around the "is it dead? restart it" decision.
        self._restart_lock = threading.Lock()

    def is_running(self) -> bool:
        """Return True when a server thread exists and is still alive."""
        thread = self.process
        return thread is not None and thread.is_alive()

    def start(
        self,
        app_host: Optional[str] = None,
        free_port: Optional[int] = None,
        secure_host: Optional[bool] = None,
    ) -> Starlette:
        """Start the matplotlib server and return the Starlette app.

        Args:
            app_host: Host to bind; falls back to ``_get_host()`` when None.
            free_port: Port to bind; when None a free port is searched for
                starting at a randomized base in ``[10_000, 11_000]``.
            secure_host: Secure flag stored on the app state; falls back to
                ``_get_secure()`` when None.

        Returns:
            The newly created application, with ``state.host``,
            ``state.port`` and ``state.secure`` populated.
        """
        import random

        import uvicorn

        host = app_host if app_host is not None else _get_host()
        secure = secure_host if secure_host is not None else _get_secure()

        if free_port is not None:
            port = free_port
        else:
            # Randomize the search base to reduce port conflicts.
            port = find_free_port(10_000 + random.randint(0, 1000))

        app = create_application()
        app.state.host = host
        app.state.port = port
        app.state.secure = secure

        # Built here (not inside the thread) so stop() can reach it.
        server = uvicorn.Server(
            uvicorn.Config(
                app=app,
                port=port,
                host=host,
                log_level="critical",
            )
        )

        def start_server() -> None:
            # No signal handlers are set from this background thread; the
            # process's existing handlers are left in place.
            try:
                server.run()
            except Exception as e:
                # The thread exits after logging, so is_running() turns
                # False and the next use can restart the server.
                LOGGER.error(f"Matplotlib server failed: {e}")

        thread = threading.Thread(target=start_server, daemon=True)
        thread.start()

        self.process = thread
        self._server = server

        # Brief grace period for the server to come up. The historical
        # comment said "200ms" but the value has always been 0.02 s (20 ms).
        # TODO: Consider whether lazy recovery makes this unnecessary.
        time.sleep(0.02)

        LOGGER.info(f"Started matplotlib server at {host}:{port}")
        return app

    def stop(self) -> None:
        """Ask the current server to shut down and forget its thread.

        After this call ``is_running()`` returns False, which triggers a
        restart on next use. The call does not wait for the thread to exit.
        """
        if self._server is not None:
            # uvicorn's serve loop polls this flag and shuts down when set,
            # so the old server stops instead of lingering on its port.
            self._server.should_exit = True
            self._server = None
        if self.process is not None:
            self.process = None
            LOGGER.debug("Marked matplotlib server for restart")


_server_manager = MplServerManager()


def _get_host() -> str:
"""
Get the host from environment variable or fall back to localhost.
Expand Down Expand Up @@ -205,7 +291,7 @@ def send_binary(self, blob: Any) -> None:
await websocket.send_json(
{
"type": "error",
"message": f"Figure with id '{figure_id}' not found",
"message": f"Figure with id '{figure_id}' not found. The matplotlib server may have restarted. Please re-run the cell containing this plot.",
}
)
await websocket.close()
Expand All @@ -229,7 +315,10 @@ async def receive() -> None:
except Exception as e:
if websocket.application_state != WebSocketState.DISCONNECTED:
await websocket.send_json(
{"type": "error", "message": str(e)}
{
"type": "error",
"message": f"WebSocket receive error: {str(e)}. The matplotlib server may have restarted. Please refresh this plot.",
}
)
finally:
if websocket.application_state != WebSocketState.DISCONNECTED:
Expand All @@ -249,7 +338,10 @@ async def send() -> None:
except Exception as e:
if websocket.application_state != WebSocketState.DISCONNECTED:
await websocket.send_json(
{"type": "error", "message": str(e)}
{
"type": "error",
"message": f"WebSocket send error: {str(e)}. The matplotlib server may have restarted. Please refresh this plot.",
}
)
finally:
if websocket.application_state != WebSocketState.DISCONNECTED:
Expand All @@ -259,7 +351,12 @@ async def send() -> None:
await asyncio.gather(receive(), send())
except Exception as e:
if websocket.application_state != WebSocketState.DISCONNECTED:
await websocket.send_json({"type": "error", "message": str(e)})
await websocket.send_json(
{
"type": "error",
"message": f"WebSocket connection error: {str(e)}. The matplotlib server may have restarted. Please refresh this plot.",
}
)
await websocket.close()

return Starlette(
Expand Down Expand Up @@ -295,36 +392,20 @@ def get_or_create_application(
) -> Starlette:
global _app

import uvicorn

if _app is None:
host = app_host if app_host is not None else _get_host()
port = free_port if free_port is not None else find_free_port(10_000)
secure = secure_host if secure_host is not None else _get_secure()
app = create_application()
app.state.host = host
app.state.port = port
app.state.secure = secure
_app = app

def start_server() -> None:
signal_handlers = get_signals()
uvicorn.Server(
uvicorn.Config(
app=app,
port=port,
host=host,
log_level="critical",
# Thread-safe lazy restart logic
with _server_manager._restart_lock:
if _app is None or not _server_manager.is_running():
if _app is not None:
LOGGER.info(
"Matplotlib server appears to have died, restarting..."
)
).run()
for signo, handler in signal_handlers.items():
signal.signal(signo, handler)

threading.Thread(target=start_server).start()
_server_manager.stop()
# Clear existing figure managers to prevent stale state
figure_managers.figure_managers.clear()
_app = None

# arbitrary wait 200ms for the server to start
# this only happens once per session
time.sleep(0.02)
# Start new server
_app = _server_manager.start(app_host, free_port, secure_host)

return _app

Expand Down
59 changes: 52 additions & 7 deletions marimo/_server/api/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,35 @@
LOGGER = _loggers.marimo_logger()


def _handle_proxy_connection_error(
    _error: ConnectionRefusedError,
    path: str,
    custom_message: str | None = None,
) -> Response:
    """Build the 503 reply sent when a proxied backend refuses a connection.

    Args:
        _error: The connection error; unused here.
        path: Request path, included in the debug log line.
        custom_message: Body text to send; a generic message is used when
            this is None or empty.

    Returns:
        A plain-text response with status code 503.
    """
    LOGGER.debug(f"Connection refused for {path}")
    if custom_message:
        body = custom_message
    else:
        body = "Service is not available. Please try again or restart the service."
    return Response(
        content=body,
        status_code=503,
        media_type="text/plain",
    )


def create_proxy_error_handler(
    custom_message: str,
) -> Callable[[ConnectionRefusedError, str], Response]:
    """Return a connection-error handler that replies with ``custom_message``.

    The returned callable takes ``(error, path)`` and delegates to
    ``_handle_proxy_connection_error`` with the message bound at creation.
    """

    def handler(error: ConnectionRefusedError, path: str) -> Response:
        # Delegate to the default handler, supplying the bound message.
        return _handle_proxy_connection_error(
            error, path, custom_message=custom_message
        )

    return handler


class AuthBackend(AuthenticationBackend):
def __init__(self, should_authenticate: bool = True) -> None:
self.should_authenticate = should_authenticate
Expand Down Expand Up @@ -340,11 +369,20 @@ def __init__(
proxy_path: str,
target_url: Union[str, Callable[[str], str]],
path_rewrite: Callable[[str], str] | None = None,
connection_error_handler: Callable[
[ConnectionRefusedError, str], Response
]
| None = None,
) -> None:
self.app = app
self.path = proxy_path.rstrip("/")
self.target_url = target_url
self.path_rewrite = path_rewrite
self.connection_error_handler = (
connection_error_handler
if connection_error_handler
else _handle_proxy_connection_error
)

def _get_target_url(self, path: str) -> str:
"""Get target URL either from rewrite function or default MPL logic."""
Expand Down Expand Up @@ -408,13 +446,20 @@ async def __call__(
content=request.stream(),
)

rp_resp = await client.send(rp_req, stream=True)
response = StreamingResponse(
rp_resp.aiter_raw(),
status_code=rp_resp.status_code,
headers=rp_resp.headers,
background=BackgroundTask(rp_resp.aclose),
)
try:
rp_resp = await client.send(rp_req, stream=True)
response = StreamingResponse(
rp_resp.aiter_raw(),
status_code=rp_resp.status_code,
headers=rp_resp.headers,
background=BackgroundTask(rp_resp.aclose),
)
except ConnectionRefusedError as e:
if self.connection_error_handler:
response = self.connection_error_handler(e, request.url.path)
else:
raise

await response(scope, receive, send)

async def _proxy_websocket(
Expand Down
6 changes: 6 additions & 0 deletions marimo/_server/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
ProxyMiddleware,
SkewProtectionMiddleware,
TimeoutMiddleware,
create_proxy_error_handler,
)
from marimo._server.api.router import build_routes
from marimo._server.api.status import (
Expand Down Expand Up @@ -172,11 +173,16 @@ def mpl_path_rewrite(path: str) -> str:
rest_parts = parts[2:]
return "/" + "/".join(rest_parts) if rest_parts else "/"

mpl_error_handler = create_proxy_error_handler(
"Matplotlib server is not available. Please rerun this cell or restart the service."
)

return Middleware(
ProxyMiddleware,
proxy_path=proxy_path,
target_url=mpl_target_url,
path_rewrite=mpl_path_rewrite,
connection_error_handler=mpl_error_handler,
)


Expand Down
Loading
Loading