Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions marimo/_runtime/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,14 @@ def is_pure_function(
if not inspect.isfunction(value):
return False

# If this object wraps another, call the check on the wrapped object.
if hasattr(value, "__wrapped__"):
wrapped = getattr(value, "__wrapped__", None)
if wrapped is not None:
# This still catches impure decorated functions since the impure
# reference will be captured by required refs.
return is_pure_function(ref, wrapped, defs, cache, graph)

# We assume all external module function references to be pure. Cache can
# still be be invalidated by pin_modules attribute. Note this also captures
# cases like functors from an external module.
Expand Down
45 changes: 44 additions & 1 deletion marimo/_save/hash.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
from marimo._runtime.context import ContextNotInitializedError, get_context
from marimo._runtime.dataflow import induced_subgraph
from marimo._runtime.primitives import (
CLONE_PRIMITIVES,
FN_CACHE_TYPE,
build_ref_predicate_for_primitives,
is_data_primitive,
is_data_primitive_container,
is_primitive,
Expand Down Expand Up @@ -90,6 +92,24 @@ def process(code_obj: CodeType) -> None:
return hash_alg.digest()


def hash_wrapped_functions(
wrapped: Callable[..., Any], hash_type: str = DEFAULT_HASH
) -> bytes:
seen = set()

# there is a chance for a circular reference
# likely manually created, but easy to guard against.
def process_function(fn: Callable[..., Any]) -> bytes:
fn_hash = hash_module(fn.__code__, hash_type)
if fn_hash not in seen and hasattr(fn, "__wrapped__"):
child_hash = hash_wrapped_functions(fn.__wrapped__, hash_type)
return child_hash + fn_hash
seen.add(fn_hash)
return fn_hash

return process_function(wrapped)


def hash_raw_module(
module: ast.Module, hash_type: str = DEFAULT_HASH
) -> bytes:
Expand Down Expand Up @@ -564,6 +584,27 @@ def collect_for_content_hash(
exceptions = []
# By rights, could just fail here - but this final attempt should
# provide better user experience.
#
# Get a transitive closure over the object, and attempt to pickle
# each dependent object.
#
# TODO: Maybe just try dill?
closure = self.graph.get_transitive_references(
unhashable,
predicate=build_ref_predicate_for_primitives(
scope, CLONE_PRIMITIVES
),
)
closure -= set(content_serialization.keys()) | self.execution_refs
unhashable_closure, relevant_serialization, _ = (
self.serialize_and_dequeue_content_refs(
closure - unhashable, scope
)
)
unhashable |= unhashable_closure
content_serialization.update(relevant_serialization)
refs |= unhashable_closure

for ref in unhashable:
try:
_hashed = pickle.dumps(scope[ref])
Expand Down Expand Up @@ -774,7 +815,9 @@ def serialize_and_dequeue_content_refs(
elif is_pure_function(
local_ref, value, scope, self.fn_cache, self.graph
):
serial_value = hash_module(value.__code__, self.hash_alg.name)
serial_value = hash_wrapped_functions(
value, self.hash_alg.name
)
# An external module variable is assumed to be pure, with module
# pinning being the mechanism for invalidation.
elif getattr(value, "__module__", "__main__") == "__main__":
Expand Down
132 changes: 132 additions & 0 deletions tests/_runtime/test_primitives.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
"""Tests for marimo._runtime.primitives module."""

import functools
from typing import Any

from marimo._runtime.primitives import is_pure_function


class TestWrappedFunctionHandling:
"""Test handling of wrapped functions (decorators) in is_pure_function."""

def test_wrapped_function_follows_wrapped_object(self):
"""Test that is_pure_function follows __wrapped__ attribute to check the underlying function."""

def external_function():
"""A function from external module."""
return 42

external_function.__module__ = "external_module"

# Create a decorator that wraps the function
def decorator(func):
@functools.wraps(func)
def wrapper():
return func()

return wrapper

decorated_function = decorator(external_function)

# Mock globals dict
defs = {"decorated_function": decorated_function}
cache = {}

# Should follow the wrapped function and determine purity based on that
result = is_pure_function(
"decorated_function", decorated_function, defs, cache
)

# Should be True since the wrapped function is external
assert result is True

def test_nested_wrapped_functions(self):
"""Test handling of functions with multiple layers of wrapping."""

def original_function():
return "original"

original_function.__module__ = "external_module"

def decorator1(func: Any) -> Any:
@functools.wraps(func)
def wrapper1(*args: Any, **kwargs: Any):
return func(*args, **kwargs)

return wrapper1

def decorator2(func: Any) -> Any:
@functools.wraps(func)
def wrapper2(*args: Any, **kwargs: Any):
return func(*args, **kwargs)

return wrapper2

# Apply multiple decorators
@decorator2
@decorator1
def nested_decorated():
return original_function()

defs = {"nested_decorated": nested_decorated}
cache = {}

# Should handle nested wrapping correctly
result = is_pure_function(
"nested_decorated", nested_decorated, defs, cache
)
assert isinstance(result, bool)

def test_wrapped_attribute_is_none(self):
"""Test handling when __wrapped__ exists but is None."""

def function_with_none_wrapped():
return 42

function_with_none_wrapped.__module__ = "external_module"

# Set __wrapped__ to None
function_with_none_wrapped.__wrapped__ = None

defs = {"function_with_none_wrapped": function_with_none_wrapped}
cache = {}

# Should handle None __wrapped__ gracefully
result = is_pure_function(
"function_with_none_wrapped",
function_with_none_wrapped,
defs,
cache,
)
assert result is True # External function should be pure

def test_main_module_wrapped_function(self):
"""Test wrapped function from __main__ module."""

def internal_function():
return 42

internal_function.__module__ = "__main__"

def decorator(func):
@functools.wraps(func)
def wrapper():
return func()

return wrapper

decorated_function = decorator(internal_function)

defs = {
"decorated_function": decorated_function,
"internal_function": internal_function,
}
cache = {}

# Should follow wrapped function and check if it's pure
result = is_pure_function(
"decorated_function", decorated_function, defs, cache
)

# Should be True since the wrapped function is also pure (no external refs)
assert result is True
Loading
Loading