Skip to content

Commit 0af3e32

Browse files
authored
fix: matplotlib interactive stability fixes (#6550)
## 📝 Summary fixes #5577 #5847 QOL fixes for mo.mpl.interactive 1. Safari has a race condition that caused the iframe resizing to break, which subsequently caused further issues. 2. If the background service failed, there was no way to reconnect. Moreover, port exhaustion seemed a bit too easy to trigger. As such, this introduces a global manager that seems to add a bit more stability.
1 parent 7d6efb9 commit 0af3e32

File tree

5 files changed

+262
-43
lines changed

5 files changed

+262
-43
lines changed

frontend/index.html

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@
2626
var scrollbarHeight = 20; // Max between windows, mac, and linux
2727

2828
function setHeight() {
29+
// Guard against race condition where iframe isn't ready
30+
if (!obj.contentWindow?.document?.documentElement) {
31+
return;
32+
}
2933
var element = obj.contentWindow.document.documentElement;
3034
// If there is no vertical scrollbar, we don't need to resize the iframe
3135
if (element.scrollHeight === element.clientHeight) {
@@ -49,7 +53,10 @@
4953
const resizeObserver = new ResizeObserver((entries) => {
5054
setHeight();
5155
});
52-
resizeObserver.observe(obj.contentWindow.document.body);
56+
// Only observe if iframe content is ready
57+
if (obj.contentWindow?.document?.body) {
58+
resizeObserver.observe(obj.contentWindow.document.body);
59+
}
5360
}
5461
</script>
5562
<marimo-filename hidden>{{ filename }}</marimo-filename>

marimo/_plugins/stateless/mpl/_mpl.py

Lines changed: 116 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import io
1313
import mimetypes
1414
import os
15-
import signal
1615
import threading
1716
import time
1817
from pathlib import Path
@@ -32,7 +31,6 @@
3231
from marimo._runtime.runtime import app_meta
3332
from marimo._server.utils import find_free_port
3433
from marimo._utils.platform import is_pyodide
35-
from marimo._utils.signals import get_signals
3634

3735
LOGGER = _loggers.marimo_logger()
3836

@@ -58,12 +56,100 @@ def get(self, figure_id: str) -> FigureManagerWebAgg:
5856
return self.figure_managers[str(figure_id)]
5957

6058
def remove(self, manager: FigureManagerWebAgg) -> None:
61-
del self.figure_managers[str(manager.num)]
59+
try:
60+
del self.figure_managers[str(manager.num)]
61+
except KeyError:
62+
# Figure already removed, this can happen during server restart
63+
LOGGER.debug(f"Figure {manager.num} already removed from manager")
6264

6365

6466
figure_managers = FigureManagers()
6567

6668

69+
class MplServerManager:
70+
"""Manages the matplotlib server lifecycle with lazy recovery."""
71+
72+
def __init__(self) -> None:
73+
self.process: Optional[threading.Thread] = None
74+
self._restart_lock = threading.Lock()
75+
76+
def is_running(self) -> bool:
77+
"""Check if the server thread is still running."""
78+
if self.process is None:
79+
return False
80+
# Check if the thread is still alive
81+
return self.process.is_alive()
82+
83+
def start(
84+
self,
85+
app_host: Optional[str] = None,
86+
free_port: Optional[int] = None,
87+
secure_host: Optional[bool] = None,
88+
) -> Starlette:
89+
"""Start the matplotlib server and return the Starlette app."""
90+
import uvicorn
91+
92+
host = app_host if app_host is not None else _get_host()
93+
secure = secure_host if secure_host is not None else _get_secure()
94+
95+
# Find a free port, with some randomization to avoid conflicts
96+
import random
97+
98+
base_port = 10_000 + random.randint(0, 1000) # Add some randomization
99+
port = (
100+
free_port if free_port is not None else find_free_port(base_port)
101+
)
102+
app = create_application()
103+
app.state.host = host
104+
app.state.port = port
105+
app.state.secure = secure
106+
107+
def start_server() -> None:
108+
# Don't try to set signal handlers in background thread
109+
# The original signal handlers will remain in place
110+
server = uvicorn.Server(
111+
uvicorn.Config(
112+
app=app,
113+
port=port,
114+
host=host,
115+
log_level="critical",
116+
)
117+
)
118+
try:
119+
server.run()
120+
except Exception as e:
121+
LOGGER.error(f"Matplotlib server failed: {e}")
122+
# Thread will exit, making is_running() return False
123+
# This allows for automatic restart on next use
124+
125+
# Start server in background thread
126+
thread = threading.Thread(target=start_server, daemon=True)
127+
thread.start()
128+
129+
# Store thread reference to track server
130+
self.process = thread
131+
132+
# TODO: Consider if we need this sleep from original code
133+
# Original comment: "arbitrary wait 200ms for the server to start"
134+
# With lazy recovery, this may no longer be necessary
135+
time.sleep(0.02)
136+
137+
LOGGER.info(f"Started matplotlib server at {host}:{port}")
138+
return app
139+
140+
def stop(self) -> None:
141+
"""Stop the server process."""
142+
if self.process is not None:
143+
# Note: We can't easily terminate uvicorn server from here,
144+
# but marking process as None will cause is_running() to return False
145+
# and trigger a restart on next use
146+
self.process = None
147+
LOGGER.debug("Marked matplotlib server for restart")
148+
149+
150+
_server_manager = MplServerManager()
151+
152+
67153
def _get_host() -> str:
68154
"""
69155
Get the host from environment variable or fall back to localhost.
@@ -205,7 +291,7 @@ def send_binary(self, blob: Any) -> None:
205291
await websocket.send_json(
206292
{
207293
"type": "error",
208-
"message": f"Figure with id '{figure_id}' not found",
294+
"message": f"Figure with id '{figure_id}' not found. The matplotlib server may have restarted. Please re-run the cell containing this plot.",
209295
}
210296
)
211297
await websocket.close()
@@ -229,7 +315,10 @@ async def receive() -> None:
229315
except Exception as e:
230316
if websocket.application_state != WebSocketState.DISCONNECTED:
231317
await websocket.send_json(
232-
{"type": "error", "message": str(e)}
318+
{
319+
"type": "error",
320+
"message": f"WebSocket receive error: {str(e)}. The matplotlib server may have restarted. Please refresh this plot.",
321+
}
233322
)
234323
finally:
235324
if websocket.application_state != WebSocketState.DISCONNECTED:
@@ -249,7 +338,10 @@ async def send() -> None:
249338
except Exception as e:
250339
if websocket.application_state != WebSocketState.DISCONNECTED:
251340
await websocket.send_json(
252-
{"type": "error", "message": str(e)}
341+
{
342+
"type": "error",
343+
"message": f"WebSocket send error: {str(e)}. The matplotlib server may have restarted. Please refresh this plot.",
344+
}
253345
)
254346
finally:
255347
if websocket.application_state != WebSocketState.DISCONNECTED:
@@ -259,7 +351,12 @@ async def send() -> None:
259351
await asyncio.gather(receive(), send())
260352
except Exception as e:
261353
if websocket.application_state != WebSocketState.DISCONNECTED:
262-
await websocket.send_json({"type": "error", "message": str(e)})
354+
await websocket.send_json(
355+
{
356+
"type": "error",
357+
"message": f"WebSocket connection error: {str(e)}. The matplotlib server may have restarted. Please refresh this plot.",
358+
}
359+
)
263360
await websocket.close()
264361

265362
return Starlette(
@@ -295,36 +392,20 @@ def get_or_create_application(
295392
) -> Starlette:
296393
global _app
297394

298-
import uvicorn
299-
300-
if _app is None:
301-
host = app_host if app_host is not None else _get_host()
302-
port = free_port if free_port is not None else find_free_port(10_000)
303-
secure = secure_host if secure_host is not None else _get_secure()
304-
app = create_application()
305-
app.state.host = host
306-
app.state.port = port
307-
app.state.secure = secure
308-
_app = app
309-
310-
def start_server() -> None:
311-
signal_handlers = get_signals()
312-
uvicorn.Server(
313-
uvicorn.Config(
314-
app=app,
315-
port=port,
316-
host=host,
317-
log_level="critical",
395+
# Thread-safe lazy restart logic
396+
with _server_manager._restart_lock:
397+
if _app is None or not _server_manager.is_running():
398+
if _app is not None:
399+
LOGGER.info(
400+
"Matplotlib server appears to have died, restarting..."
318401
)
319-
).run()
320-
for signo, handler in signal_handlers.items():
321-
signal.signal(signo, handler)
322-
323-
threading.Thread(target=start_server).start()
402+
_server_manager.stop()
403+
# Clear existing figure managers to prevent stale state
404+
figure_managers.figure_managers.clear()
405+
_app = None
324406

325-
# arbitrary wait 200ms for the server to start
326-
# this only happens once per session
327-
time.sleep(0.02)
407+
# Start new server
408+
_app = _server_manager.start(app_host, free_port, secure_host)
328409

329410
return _app
330411

marimo/_server/api/middleware.py

Lines changed: 53 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,35 @@
5353
LOGGER = _loggers.marimo_logger()
5454

5555

56+
def _handle_proxy_connection_error(
57+
_error: ConnectionRefusedError,
58+
path: str,
59+
custom_message: str | None = None,
60+
) -> Response:
61+
"""Handle connection errors for proxy requests to backend services."""
62+
LOGGER.debug(f"Connection refused for {path}")
63+
content = (
64+
custom_message
65+
or "Service is not available. Please try again or restart the service."
66+
)
67+
return Response(
68+
content=content,
69+
status_code=503,
70+
media_type="text/plain",
71+
)
72+
73+
74+
def create_proxy_error_handler(
75+
custom_message: str,
76+
) -> Callable[[ConnectionRefusedError, str], Response]:
77+
"""Create a custom error handler that wraps the default with a custom message."""
78+
79+
def handler(error: ConnectionRefusedError, path: str) -> Response:
80+
return _handle_proxy_connection_error(error, path, custom_message)
81+
82+
return handler
83+
84+
5685
class AuthBackend(AuthenticationBackend):
5786
def __init__(self, should_authenticate: bool = True) -> None:
5887
self.should_authenticate = should_authenticate
@@ -340,11 +369,20 @@ def __init__(
340369
proxy_path: str,
341370
target_url: Union[str, Callable[[str], str]],
342371
path_rewrite: Callable[[str], str] | None = None,
372+
connection_error_handler: Callable[
373+
[ConnectionRefusedError, str], Response
374+
]
375+
| None = None,
343376
) -> None:
344377
self.app = app
345378
self.path = proxy_path.rstrip("/")
346379
self.target_url = target_url
347380
self.path_rewrite = path_rewrite
381+
self.connection_error_handler = (
382+
connection_error_handler
383+
if connection_error_handler
384+
else _handle_proxy_connection_error
385+
)
348386

349387
def _get_target_url(self, path: str) -> str:
350388
"""Get target URL either from rewrite function or default MPL logic."""
@@ -408,13 +446,21 @@ async def __call__(
408446
content=request.stream(),
409447
)
410448

411-
rp_resp = await client.send(rp_req, stream=True)
412-
response = StreamingResponse(
413-
rp_resp.aiter_raw(),
414-
status_code=rp_resp.status_code,
415-
headers=rp_resp.headers,
416-
background=BackgroundTask(rp_resp.aclose),
417-
)
449+
response: Union[StreamingResponse, Response]
450+
try:
451+
rp_resp = await client.send(rp_req, stream=True)
452+
response = StreamingResponse(
453+
rp_resp.aiter_raw(),
454+
status_code=rp_resp.status_code,
455+
headers=rp_resp.headers,
456+
background=BackgroundTask(rp_resp.aclose),
457+
)
458+
except ConnectionRefusedError as e:
459+
if self.connection_error_handler is not None:
460+
response = self.connection_error_handler(e, request.url.path)
461+
else:
462+
raise
463+
418464
await response(scope, receive, send)
419465

420466
async def _proxy_websocket(

marimo/_server/main.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
ProxyMiddleware,
2323
SkewProtectionMiddleware,
2424
TimeoutMiddleware,
25+
create_proxy_error_handler,
2526
)
2627
from marimo._server.api.router import build_routes
2728
from marimo._server.api.status import (
@@ -172,11 +173,16 @@ def mpl_path_rewrite(path: str) -> str:
172173
rest_parts = parts[2:]
173174
return "/" + "/".join(rest_parts) if rest_parts else "/"
174175

176+
mpl_error_handler = create_proxy_error_handler(
177+
"Matplotlib server is not available. Please rerun this cell or restart the service."
178+
)
179+
175180
return Middleware(
176181
ProxyMiddleware,
177182
proxy_path=proxy_path,
178183
target_url=mpl_target_url,
179184
path_rewrite=mpl_path_rewrite,
185+
connection_error_handler=mpl_error_handler,
180186
)
181187

182188

0 commit comments

Comments
 (0)