Skip to content

Commit 1a1d0d5

Browse files
committed
revert: unrevert #5758
1 parent 557c5c4 commit 1a1d0d5

File tree

4 files changed

+647
-1
lines changed

4 files changed

+647
-1
lines changed

marimo/_runtime/primitives.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,14 @@ def is_pure_function(
197197
if not inspect.isfunction(value):
198198
return False
199199

200+
# If this object wraps another, call the check on the wrapped object.
201+
if hasattr(value, "__wrapped__"):
202+
wrapped = getattr(value, "__wrapped__", None)
203+
if wrapped is not None:
204+
# This still catches impure decorated functions since the impure
205+
# reference will be captured by required refs.
206+
return is_pure_function(ref, wrapped, defs, cache, graph)
207+
200208
# We assume all external module function references to be pure. Cache can
201209
# still be be invalidated by pin_modules attribute. Note this also captures
202210
# cases like functors from an external module.

marimo/_save/hash.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@
2323
from marimo._runtime.context import ContextNotInitializedError, get_context
2424
from marimo._runtime.dataflow import induced_subgraph
2525
from marimo._runtime.primitives import (
26+
CLONE_PRIMITIVES,
2627
FN_CACHE_TYPE,
28+
build_ref_predicate_for_primitives,
2729
is_data_primitive,
2830
is_data_primitive_container,
2931
is_primitive,
@@ -90,6 +92,24 @@ def process(code_obj: CodeType) -> None:
9092
return hash_alg.digest()
9193

9294

95+
def hash_wrapped_functions(
96+
wrapped: Callable[..., Any], hash_type: str = DEFAULT_HASH
97+
) -> bytes:
98+
seen = set()
99+
100+
# there is a chance for a circular reference
101+
# likely manually created, but easy to guard against.
102+
def process_function(fn: Callable[..., Any]) -> bytes:
103+
fn_hash = hash_module(fn.__code__, hash_type)
104+
if fn_hash not in seen and hasattr(fn, "__wrapped__"):
105+
child_hash = hash_wrapped_functions(fn.__wrapped__, hash_type)
106+
return child_hash + fn_hash
107+
seen.add(fn_hash)
108+
return fn_hash
109+
110+
return process_function(wrapped)
111+
112+
93113
def hash_raw_module(
94114
module: ast.Module, hash_type: str = DEFAULT_HASH
95115
) -> bytes:
@@ -564,6 +584,27 @@ def collect_for_content_hash(
564584
exceptions = []
565585
# By rights, could just fail here - but this final attempt should
566586
# provide better user experience.
587+
#
588+
# Get a transitive closure over the object, and attempt to pickle
589+
# each dependent object.
590+
#
591+
# TODO: Maybe just try dill?
592+
closure = self.graph.get_transitive_references(
593+
unhashable,
594+
predicate=build_ref_predicate_for_primitives(
595+
scope, CLONE_PRIMITIVES
596+
),
597+
)
598+
closure -= set(content_serialization.keys()) | self.execution_refs
599+
unhashable_closure, relevant_serialization, _ = (
600+
self.serialize_and_dequeue_content_refs(
601+
closure - unhashable, scope
602+
)
603+
)
604+
unhashable |= unhashable_closure
605+
content_serialization.update(relevant_serialization)
606+
refs |= unhashable_closure
607+
567608
for ref in unhashable:
568609
try:
569610
_hashed = pickle.dumps(scope[ref])
@@ -778,7 +819,9 @@ def serialize_and_dequeue_content_refs(
778819
elif is_pure_function(
779820
local_ref, value, scope, self.fn_cache, self.graph
780821
):
781-
serial_value = hash_module(value.__code__, self.hash_alg.name)
822+
serial_value = hash_wrapped_functions(
823+
value, self.hash_alg.name
824+
)
782825
# An external module variable is assumed to be pure, with module
783826
# pinning being the mechanism for invalidation.
784827
elif getattr(value, "__module__", "__main__") == "__main__":

tests/_runtime/test_primitives.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
"""Tests for marimo._runtime.primitives module."""
2+
3+
import functools
4+
from typing import Any
5+
6+
from marimo._runtime.primitives import is_pure_function
7+
8+
9+
class TestWrappedFunctionHandling:
10+
"""Test handling of wrapped functions (decorators) in is_pure_function."""
11+
12+
def test_wrapped_function_follows_wrapped_object(self):
13+
"""Test that is_pure_function follows __wrapped__ attribute to check the underlying function."""
14+
15+
def external_function():
16+
"""A function from external module."""
17+
return 42
18+
19+
external_function.__module__ = "external_module"
20+
21+
# Create a decorator that wraps the function
22+
def decorator(func):
23+
@functools.wraps(func)
24+
def wrapper():
25+
return func()
26+
27+
return wrapper
28+
29+
decorated_function = decorator(external_function)
30+
31+
# Mock globals dict
32+
defs = {"decorated_function": decorated_function}
33+
cache = {}
34+
35+
# Should follow the wrapped function and determine purity based on that
36+
result = is_pure_function(
37+
"decorated_function", decorated_function, defs, cache
38+
)
39+
40+
# Should be True since the wrapped function is external
41+
assert result is True
42+
43+
def test_nested_wrapped_functions(self):
44+
"""Test handling of functions with multiple layers of wrapping."""
45+
46+
def original_function():
47+
return "original"
48+
49+
original_function.__module__ = "external_module"
50+
51+
def decorator1(func: Any) -> Any:
52+
@functools.wraps(func)
53+
def wrapper1(*args: Any, **kwargs: Any):
54+
return func(*args, **kwargs)
55+
56+
return wrapper1
57+
58+
def decorator2(func: Any) -> Any:
59+
@functools.wraps(func)
60+
def wrapper2(*args: Any, **kwargs: Any):
61+
return func(*args, **kwargs)
62+
63+
return wrapper2
64+
65+
# Apply multiple decorators
66+
@decorator2
67+
@decorator1
68+
def nested_decorated():
69+
return original_function()
70+
71+
defs = {"nested_decorated": nested_decorated}
72+
cache = {}
73+
74+
# Should handle nested wrapping correctly
75+
result = is_pure_function(
76+
"nested_decorated", nested_decorated, defs, cache
77+
)
78+
assert isinstance(result, bool)
79+
80+
def test_wrapped_attribute_is_none(self):
81+
"""Test handling when __wrapped__ exists but is None."""
82+
83+
def function_with_none_wrapped():
84+
return 42
85+
86+
function_with_none_wrapped.__module__ = "external_module"
87+
88+
# Set __wrapped__ to None
89+
function_with_none_wrapped.__wrapped__ = None
90+
91+
defs = {"function_with_none_wrapped": function_with_none_wrapped}
92+
cache = {}
93+
94+
# Should handle None __wrapped__ gracefully
95+
result = is_pure_function(
96+
"function_with_none_wrapped",
97+
function_with_none_wrapped,
98+
defs,
99+
cache,
100+
)
101+
assert result is True # External function should be pure
102+
103+
def test_main_module_wrapped_function(self):
104+
"""Test wrapped function from __main__ module."""
105+
106+
def internal_function():
107+
return 42
108+
109+
internal_function.__module__ = "__main__"
110+
111+
def decorator(func):
112+
@functools.wraps(func)
113+
def wrapper():
114+
return func()
115+
116+
return wrapper
117+
118+
decorated_function = decorator(internal_function)
119+
120+
defs = {
121+
"decorated_function": decorated_function,
122+
"internal_function": internal_function,
123+
}
124+
cache = {}
125+
126+
# Should follow wrapped function and check if it's pure
127+
result = is_pure_function(
128+
"decorated_function", decorated_function, defs, cache
129+
)
130+
131+
# Should be True since the wrapped function is also pure (no external refs)
132+
assert result is True

0 commit comments

Comments
 (0)