Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion frontend/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@
var scrollbarHeight = 20; // Max between windows, mac, and linux

function setHeight() {
// Guard against race condition where iframe isn't ready
if (!obj.contentWindow?.document?.documentElement) {
return;
}
var element = obj.contentWindow.document.documentElement;
// If there is no vertical scrollbar, we don't need to resize the iframe
if (element.scrollHeight === element.clientHeight) {
Expand All @@ -49,7 +53,10 @@
const resizeObserver = new ResizeObserver((entries) => {
setHeight();
});
resizeObserver.observe(obj.contentWindow.document.body);
// Only observe if iframe content is ready
if (obj.contentWindow?.document?.body) {
resizeObserver.observe(obj.contentWindow.document.body);
}
}
</script>
<marimo-filename hidden>{{ filename }}</marimo-filename>
Expand Down
151 changes: 116 additions & 35 deletions marimo/_plugins/stateless/mpl/_mpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import io
import mimetypes
import os
import signal
import threading
import time
from pathlib import Path
Expand All @@ -32,7 +31,6 @@
from marimo._runtime.runtime import app_meta
from marimo._server.utils import find_free_port
from marimo._utils.platform import is_pyodide
from marimo._utils.signals import get_signals

LOGGER = _loggers.marimo_logger()

Expand All @@ -58,12 +56,100 @@ def get(self, figure_id: str) -> FigureManagerWebAgg:
return self.figure_managers[str(figure_id)]

def remove(self, manager: FigureManagerWebAgg) -> None:
del self.figure_managers[str(manager.num)]
try:
del self.figure_managers[str(manager.num)]
except KeyError:
# Figure already removed, this can happen during server restart
LOGGER.debug(f"Figure {manager.num} already removed from manager")


figure_managers = FigureManagers()


class MplServerManager:
"""Manages the matplotlib server lifecycle with lazy recovery."""

def __init__(self) -> None:
self.process: Optional[threading.Thread] = None
self._restart_lock = threading.Lock()

def is_running(self) -> bool:
"""Check if the server thread is still running."""
if self.process is None:
return False
# Check if the thread is still alive
return self.process.is_alive()

def start(
self,
app_host: Optional[str] = None,
free_port: Optional[int] = None,
secure_host: Optional[bool] = None,
) -> Starlette:
"""Start the matplotlib server and return the Starlette app."""
import uvicorn

host = app_host if app_host is not None else _get_host()
secure = secure_host if secure_host is not None else _get_secure()

# Find a free port, with some randomization to avoid conflicts
import random

base_port = 10_000 + random.randint(0, 1000) # Add some randomization
port = (
free_port if free_port is not None else find_free_port(base_port)
)
app = create_application()
app.state.host = host
app.state.port = port
app.state.secure = secure

def start_server() -> None:
# Don't try to set signal handlers in background thread
# The original signal handlers will remain in place
server = uvicorn.Server(
uvicorn.Config(
app=app,
port=port,
host=host,
log_level="critical",
)
)
try:
server.run()
except Exception as e:
LOGGER.error(f"Matplotlib server failed: {e}")
# Thread will exit, making is_running() return False
# This allows for automatic restart on next use

# Start server in background thread
thread = threading.Thread(target=start_server, daemon=True)
thread.start()

# Store thread reference to track server
self.process = thread

# TODO: Consider if we need this sleep from original code
# Original comment: "arbitrary wait 200ms for the server to start"
# With lazy recovery, this may no longer be necessary
time.sleep(0.02)

LOGGER.info(f"Started matplotlib server at {host}:{port}")
return app

def stop(self) -> None:
"""Stop the server process."""
if self.process is not None:
# Note: We can't easily terminate uvicorn server from here,
# but marking process as None will cause is_running() to return False
# and trigger a restart on next use
self.process = None
LOGGER.debug("Marked matplotlib server for restart")


_server_manager = MplServerManager()


def _get_host() -> str:
"""
Get the host from environment variable or fall back to localhost.
Expand Down Expand Up @@ -205,7 +291,7 @@ def send_binary(self, blob: Any) -> None:
await websocket.send_json(
{
"type": "error",
"message": f"Figure with id '{figure_id}' not found",
"message": f"Figure with id '{figure_id}' not found. The matplotlib server may have restarted. Please re-run the cell containing this plot.",
}
)
await websocket.close()
Expand All @@ -229,7 +315,10 @@ async def receive() -> None:
except Exception as e:
if websocket.application_state != WebSocketState.DISCONNECTED:
await websocket.send_json(
{"type": "error", "message": str(e)}
{
"type": "error",
"message": f"WebSocket receive error: {str(e)}. The matplotlib server may have restarted. Please refresh this plot.",
}
)
finally:
if websocket.application_state != WebSocketState.DISCONNECTED:
Expand All @@ -249,7 +338,10 @@ async def send() -> None:
except Exception as e:
if websocket.application_state != WebSocketState.DISCONNECTED:
await websocket.send_json(
{"type": "error", "message": str(e)}
{
"type": "error",
"message": f"WebSocket send error: {str(e)}. The matplotlib server may have restarted. Please refresh this plot.",
}
)
finally:
if websocket.application_state != WebSocketState.DISCONNECTED:
Expand All @@ -259,7 +351,12 @@ async def send() -> None:
await asyncio.gather(receive(), send())
except Exception as e:
if websocket.application_state != WebSocketState.DISCONNECTED:
await websocket.send_json({"type": "error", "message": str(e)})
await websocket.send_json(
{
"type": "error",
"message": f"WebSocket connection error: {str(e)}. The matplotlib server may have restarted. Please refresh this plot.",
}
)
await websocket.close()

return Starlette(
Expand Down Expand Up @@ -295,36 +392,20 @@ def get_or_create_application(
) -> Starlette:
global _app

import uvicorn

if _app is None:
host = app_host if app_host is not None else _get_host()
port = free_port if free_port is not None else find_free_port(10_000)
secure = secure_host if secure_host is not None else _get_secure()
app = create_application()
app.state.host = host
app.state.port = port
app.state.secure = secure
_app = app

def start_server() -> None:
signal_handlers = get_signals()
uvicorn.Server(
uvicorn.Config(
app=app,
port=port,
host=host,
log_level="critical",
# Thread-safe lazy restart logic
with _server_manager._restart_lock:
if _app is None or not _server_manager.is_running():
if _app is not None:
LOGGER.info(
"Matplotlib server appears to have died, restarting..."
)
).run()
for signo, handler in signal_handlers.items():
signal.signal(signo, handler)

threading.Thread(target=start_server).start()
_server_manager.stop()
# Clear existing figure managers to prevent stale state
figure_managers.figure_managers.clear()
_app = None

# arbitrary wait 200ms for the server to start
# this only happens once per session
time.sleep(0.02)
# Start new server
_app = _server_manager.start(app_host, free_port, secure_host)

return _app

Expand Down
31 changes: 24 additions & 7 deletions marimo/_server/api/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,13 +408,30 @@ async def __call__(
content=request.stream(),
)

rp_resp = await client.send(rp_req, stream=True)
response = StreamingResponse(
rp_resp.aiter_raw(),
status_code=rp_resp.status_code,
headers=rp_resp.headers,
background=BackgroundTask(rp_resp.aclose),
)
try:
rp_resp = await client.send(rp_req, stream=True)
response = StreamingResponse(
rp_resp.aiter_raw(),
status_code=rp_resp.status_code,
headers=rp_resp.headers,
background=BackgroundTask(rp_resp.aclose),
)
except ConnectionRefusedError:
# Check if this is a matplotlib server request (contains /mpl/ in path)
if "/mpl/" in request.url.path:
# Log at debug level and return a helpful error response
# instead of letting the exception bubble up
LOGGER.debug(
f"Matplotlib server connection refused for {request.url.path}"
)
return Response(
content="Matplotlib server is not available. Please re-run the cell containing this plot.",
status_code=503,
media_type="text/plain",
)
else:
# For non-matplotlib requests, re-raise the exception
raise
await response(scope, receive, send)

async def _proxy_websocket(
Expand Down
79 changes: 79 additions & 0 deletions tests/_plugins/stateless/test_mpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,3 +214,82 @@ def test_template_contains_html_structure() -> None:
assert '<div id="figure"></div>' in result
assert "12345" in result
assert "9000" in result


def test_mpl_server_manager() -> None:
"""Test MplServerManager basic functionality"""
from marimo._plugins.stateless.mpl._mpl import MplServerManager

manager = MplServerManager()

# Initially should not be running
assert not manager.is_running()

# Mock threading.Thread to avoid actually starting a server
with patch("threading.Thread") as mock_thread_class:
mock_thread = MagicMock()
mock_thread.is_alive.return_value = True
mock_thread_class.return_value = mock_thread

# Start should create and return an app
app = manager.start(app_host="localhost", free_port=12345)

# Should now be running
assert manager.is_running()

# Verify app state
assert app.state.host == "localhost"
assert app.state.port == 12345

# Thread should have been started
mock_thread.start.assert_called_once()

# Stop should mark as not running
manager.stop()
assert not manager.is_running()


def test_get_or_create_application_with_restart() -> None:
"""Test get_or_create_application handles server restart"""
from marimo._plugins.stateless.mpl._mpl import (
_server_manager,
figure_managers,
get_or_create_application,
)

# Clear any existing state
globals()["_app"] = None
figure_managers.figure_managers.clear()

with patch("threading.Thread") as mock_thread_class:
# First thread: running
mock_thread1 = MagicMock()
mock_thread1.is_alive.return_value = True

# Second thread: also running (for restart)
mock_thread2 = MagicMock()
mock_thread2.is_alive.return_value = True

mock_thread_class.side_effect = [mock_thread1, mock_thread2]

# First call should create app
app1 = get_or_create_application()
assert app1 is not None
assert _server_manager.is_running()

# Simulate server death
mock_thread1.is_alive.return_value = False

# Add a figure to test cleanup
figure_managers.figure_managers["test"] = MagicMock()

# Next call should restart server and clear figures
app2 = get_or_create_application()
assert app2 is not None
assert app2 is not app1 # Should be a new app instance
assert (
len(figure_managers.figure_managers) == 0
) # Figures should be cleared

# Should have started two threads (original + restart)
assert mock_thread_class.call_count == 2
Loading