diff --git a/flax/nnx/extract.py b/flax/nnx/extract.py index cc816075b..c20f859a7 100644 --- a/flax/nnx/extract.py +++ b/flax/nnx/extract.py @@ -17,6 +17,7 @@ import dataclasses import functools import inspect +import types import typing as tp from flax import struct @@ -676,11 +677,25 @@ def copy_var_structure(tree: A) -> A: ) def check_no_aliases( - fn_name: str, /, *, check: tp.Iterable[str] = (), **kwargs + fn_name: str, + /, + *, + check: tp.Iterable[str] = (), + capture_ids: frozenset[int] = frozenset(), + **kwargs, ) -> dict[jax.tree_util.KeyPath, variablelib.Variable]: + """Check that no Variable appears more than once in the given collections. + + When ``capture_ids`` is provided (ids of Variables belonging to closure + captures), a Variable that was first seen as a capture is allowed to appear + exactly once more as an explicit argument (deduplication case). A Variable + seen twice in non-capture positions, or three or more times total, always + raises. + """ container = labeled(**kwargs) is_leaf = lambda x: isinstance(x, variablelib.Variable) seen: dict[int, jax.tree_util.KeyPath] = {} + dedup_used: set[int] = set() # capture_ids that have matched one explicit arg all_variables: dict[jax.tree_util.KeyPath, variablelib.Variable] = {} for path, leaf in jax.tree.leaves_with_path(container, is_leaf=is_leaf): if not isinstance(leaf, variablelib.Variable): @@ -700,6 +715,12 @@ def check_no_aliases( var_id = id(leaf) if var_id in seen: + if var_id in capture_ids and var_id not in dedup_used: + # First duplicate: capture prepended + same Variable passed explicitly. + # Allow this once (deduplication); record so a third copy still errors. 
@dataclasses.dataclass(frozen=True)
class CapturedInfo:
  """Closure-captured graph nodes / Variables of a function.

  Attributes:
    nodes: the captured objects, in closure order, without duplicates.
    freevar_indices: for each entry of ``nodes``, the index of the cell in
      the function's ``__closure__`` tuple that holds it.
  """
  nodes: tuple[object, ...]
  freevar_indices: tuple[int, ...]

  def __len__(self):
    return len(self.nodes)

  def __bool__(self):
    return len(self.nodes) > 0

  @property
  def variable_ids(self) -> frozenset[int]:
    """Ids of every Variable reached by ``graphlib.iter_graph`` from ``nodes``.

    Recomputed on each access, so Variables added to a captured node after
    this object was created are included.
    """
    return frozenset(
      id(value)
      for _, value in graphlib.iter_graph(self.nodes, graph=True)
      if isinstance(value, variablelib.Variable)
    )

  def exclude_from_args(
    self,
    args: tuple,
    kwargs: tp.Mapping[str, tp.Any],
  ) -> 'CapturedInfo':
    """Remove captured nodes that also appear in explicit args (by id).

    Args:
      args: positional arguments that will be passed explicitly.
      kwargs: keyword arguments that will be passed explicitly.

    Returns:
      ``self`` when nothing has to be removed, otherwise a new
      ``CapturedInfo`` holding only the captures absent from the arguments.
    """
    if not self.nodes:
      return self
    # Captures may be graph nodes *or* Variables (see find_captured_nodes), so
    # ids of both kinds are collected; collecting only Variable ids would
    # never exclude a captured graph node that is also passed explicitly.
    # NOTE(review): assumes iter_graph yields graph nodes as well as leaves.
    arg_ids = {
      id(value)
      for _, value in graphlib.iter_graph((args, kwargs), graph=True)
      if isinstance(value, variablelib.Variable)
      or graphlib.is_graph_node(value)
    }
    keep = [
      (node, idx)
      for node, idx in zip(self.nodes, self.freevar_indices)
      if id(node) not in arg_ids
    ]
    if len(keep) == len(self.nodes):
      return self
    if not keep:
      return CapturedInfo((), ())
    nodes, indices = zip(*keep)
    return CapturedInfo(tuple(nodes), tuple(indices))


_EMPTY_CAPTURED = CapturedInfo((), ())


def find_captured_nodes(f) -> CapturedInfo:
  """Find top-level nnx Variables and graph nodes in f's closure.

  Only the objects stored directly in closure cells are inspected; containers
  holding nodes are not searched. Each distinct object (by id) is reported
  once, with the index of the first cell that holds it.

  Args:
    f: any callable; objects without a ``__closure__`` yield an empty result.

  Returns:
    A ``CapturedInfo`` describing the captured objects and their cell indices.
  """
  closure = getattr(f, '__closure__', None)
  if closure is None:
    return _EMPTY_CAPTURED
  seen: set[int] = set()
  captured_nodes: list[object] = []
  captured_indices: list[int] = []
  for i, cell in enumerate(closure):
    try:
      obj = cell.cell_contents
    except ValueError:
      # Empty cell: the free variable has not been bound yet.
      continue
    if not (
      graphlib.is_graph_node(obj) or isinstance(obj, variablelib.Variable)
    ):
      continue
    obj_id = id(obj)
    if obj_id in seen:
      continue
    seen.add(obj_id)
    captured_nodes.append(obj)
    captured_indices.append(i)
  return CapturedInfo(tuple(captured_nodes), tuple(captured_indices))


def _make_cell(val):
  """Create a closure cell containing val."""
  return types.CellType(val)


def replace_closure_cells(
  f: tp.Callable,
  captured_info: CapturedInfo,
  replacements: tuple,
) -> tp.Callable:
  """Return a copy of f with captured closure cells replaced by replacements.

  Args:
    f: function whose closure holds the captured objects.
    captured_info: result of ``find_captured_nodes(f)``.
    replacements: one new value per entry of ``captured_info``, same order.

  Returns:
    ``f`` itself when there is nothing to replace, otherwise a new function
    sharing f's code, globals, defaults and metadata but with fresh cells at
    ``captured_info.freevar_indices``. ``f`` is never mutated.

  Raises:
    ValueError: if ``replacements`` and ``captured_info`` differ in length.
  """
  closure = getattr(f, '__closure__', None)
  if not captured_info or not closure:
    return f
  if len(replacements) != len(captured_info):
    # A silent zip truncation would leave stale captured objects in place.
    raise ValueError(
      f'Expected {len(captured_info)} closure replacements, '
      f'got {len(replacements)}.'
    )
  cells = list(closure)
  for idx, new_val in zip(captured_info.freevar_indices, replacements):
    cells[idx] = _make_cell(new_val)
  new_f = types.FunctionType(
    f.__code__, f.__globals__, f.__name__,
    f.__defaults__, tuple(cells),
  )
  # FunctionType() does not carry these over; without __kwdefaults__ a call
  # relying on a keyword-only default fails with TypeError.
  kwdefaults = getattr(f, '__kwdefaults__', None)
  new_f.__kwdefaults__ = dict(kwdefaults) if kwdefaults else None
  new_f.__qualname__ = getattr(f, '__qualname__', new_f.__qualname__)
  new_f.__doc__ = getattr(f, '__doc__', None)
  new_f.__module__ = getattr(f, '__module__', new_f.__module__)
  new_f.__dict__.update(getattr(f, '__dict__', {}))
  return new_f
+ captured_info: extract.CapturedInfo = extract._EMPTY_CAPTURED def __post_init__(self): functools.update_wrapper(self, self.f, updated=()) @@ -464,7 +465,14 @@ def __call__(self, *args, **kwargs): ) if self.graph: args, kwargs = extract.from_tree2((args, kwargs)) - out = self.f(*args, **kwargs) + n_captured = len(self.captured_info) + user_args = args[n_captured:] + if self.captured_info: + f = extract.replace_closure_cells( + self.f, self.captured_info, args[:n_captured]) + else: + f = self.f + out = f(*user_args, **kwargs) if self.graph: out = extract.to_tree2(out, prefix=self.out_shardings) extract.check_no_aliases('jit', **current, out=out, check=['out']) @@ -508,6 +516,9 @@ def __init__( self.partial_args = partial_args self.graph = graph + # Capture closure nodes once at construction time + self._captured_info = extract.find_captured_nodes(fun) + resolved = _resolve_argnums(fun, static_argnums, static_argnames) if isinstance(in_shardings, (tuple, list)) and resolved: expanded = list(in_shardings) @@ -517,6 +528,13 @@ def __init__( else: self.in_shardings = in_shardings + # Prepend None shardings for captured nodes + n_captured = len(self._captured_info) + if n_captured > 0 and isinstance(in_shardings, (tuple, list)): + jit_in_shardings = (None,) * n_captured + tuple(in_shardings) + else: + jit_in_shardings = in_shardings + donate_argnums_set = frozenset( (donate_argnums,) if isinstance(donate_argnums, int) else donate_argnums or () @@ -528,14 +546,15 @@ def __init__( self.jitted_fn = jax.jit( SimpleJitFn( fun, - self.in_shardings, + jit_in_shardings if n_captured > 0 else self.in_shardings, out_shardings, donate_argnums_set, donate_argnames_set, graph, tuple(update_shardings), + captured_info=self._captured_info, ), - in_shardings=in_shardings, + in_shardings=jit_in_shardings, out_shardings=(out_shardings, update_shardings), static_argnums=static_argnums, static_argnames=static_argnames, @@ -565,9 +584,13 @@ def _maybe_from_tree(self, out): def 
__call__(self, *args: P.args, **kwargs: P.kwargs) -> R: args = (*self.partial_args, *args) # type: ignore[assignment] - args, kwargs = self._maybe_to_tree(args, kwargs) - variables = extract.check_no_aliases('jit', args=args, kwargs=kwargs) - out, updates = self.jitted_fn(*args, **kwargs) + all_args = (*self._captured_info.nodes, *args) + all_args, kwargs = self._maybe_to_tree(all_args, kwargs) + variables = extract.check_no_aliases( + 'jit', args=all_args, kwargs=kwargs, + capture_ids=self._captured_info.variable_ids, + ) + out, updates = self.jitted_fn(*all_args, **kwargs) extract.apply_updates(variables, updates) return self._maybe_from_tree(out) @@ -578,26 +601,38 @@ def __get__(self, obj, objtype=None): def eval_shape(self, *args, **kwargs): args = (*self.partial_args, *args) - args, kwargs = self._maybe_to_tree(args, kwargs) + all_args = (*self._captured_info.nodes, *args) + all_args, kwargs = self._maybe_to_tree(all_args, kwargs) if not self.graph: - extract.check_no_aliases('jit', args=args, kwargs=kwargs) - out, updates = self.jitted_fn.eval_shape(*args, **kwargs) + extract.check_no_aliases( + 'jit', args=all_args, kwargs=kwargs, + capture_ids=self._captured_info.variable_ids, + ) + out, updates = self.jitted_fn.eval_shape(*all_args, **kwargs) return self._maybe_from_tree(out) def trace(self, *args, **kwargs): args = (*self.partial_args, *args) - args, kwargs = self._maybe_to_tree(args, kwargs) + all_args = (*self._captured_info.nodes, *args) + all_args, kwargs = self._maybe_to_tree(all_args, kwargs) if not self.graph: - extract.check_no_aliases('jit', args=args, kwargs=kwargs) - traced = self.jitted_fn.trace(*args, **kwargs) + extract.check_no_aliases( + 'jit', args=all_args, kwargs=kwargs, + capture_ids=self._captured_info.variable_ids, + ) + traced = self.jitted_fn.trace(*all_args, **kwargs) return SimpleTraced(traced, self) def lower(self, *args, **kwargs): args = (*self.partial_args, *args) - args, kwargs = self._maybe_to_tree(args, kwargs) + all_args 
= (*self._captured_info.nodes, *args) + all_args, kwargs = self._maybe_to_tree(all_args, kwargs) if not self.graph: - extract.check_no_aliases('jit', args=args, kwargs=kwargs) - lowered = self.jitted_fn.lower(*args, **kwargs) + extract.check_no_aliases( + 'jit', args=all_args, kwargs=kwargs, + capture_ids=self._captured_info.variable_ids, + ) + lowered = self.jitted_fn.lower(*all_args, **kwargs) return SimpleLowered(lowered, self) def jit_partial( fun: tp.Callable[..., R],