diff --git a/flax/nnx/extract.py b/flax/nnx/extract.py index cc816075b..c20f859a7 100644 --- a/flax/nnx/extract.py +++ b/flax/nnx/extract.py @@ -17,6 +17,7 @@ import dataclasses import functools import inspect +import types import typing as tp from flax import struct @@ -676,11 +677,25 @@ def copy_var_structure(tree: A) -> A: ) def check_no_aliases( - fn_name: str, /, *, check: tp.Iterable[str] = (), **kwargs + fn_name: str, + /, + *, + check: tp.Iterable[str] = (), + capture_ids: frozenset[int] = frozenset(), + **kwargs, ) -> dict[jax.tree_util.KeyPath, variablelib.Variable]: + """Check that no Variable appears more than once in the given collections. + + When ``capture_ids`` is provided (ids of Variables belonging to closure + captures), a Variable that was first seen as a capture is allowed to appear + exactly once more as an explicit argument (deduplication case). A Variable + seen twice in non-capture positions, or three or more times total, always + raises. + """ container = labeled(**kwargs) is_leaf = lambda x: isinstance(x, variablelib.Variable) seen: dict[int, jax.tree_util.KeyPath] = {} + dedup_used: set[int] = set() # capture_ids that have matched one explicit arg all_variables: dict[jax.tree_util.KeyPath, variablelib.Variable] = {} for path, leaf in jax.tree.leaves_with_path(container, is_leaf=is_leaf): if not isinstance(leaf, variablelib.Variable): @@ -700,6 +715,12 @@ def check_no_aliases( var_id = id(leaf) if var_id in seen: + if var_id in capture_ids and var_id not in dedup_used: + # First duplicate: capture prepended + same Variable passed explicitly. + # Allow this once (deduplication); record so a third copy still errors. 
+        dedup_used.add(var_id)
+        all_variables[path] = leaf
+        continue
       path_str = jax.tree_util.keystr(path)
       seen_path_str = jax.tree_util.keystr(seen[var_id])
       raise ValueError(
@@ -1119,6 +1140,151 @@ def to_masked(tree, all_updates: OrderedDict[tp.Any, Updates]):
     is_leaf=lambda x: x is None
   )
 
+@dataclasses.dataclass(frozen=True)
+class CapturedInfo:
+  """Closure-captured graph nodes/Variables of a transformed function.
+
+  ``nodes`` holds the captured objects and ``freevar_indices`` the index of
+  the ``__closure__`` cell each one was read from (parallel tuples).
+  """
+  nodes: tuple[object, ...]
+  freevar_indices: tuple[int, ...]
+
+  def __len__(self):
+    return len(self.nodes)
+
+  def __bool__(self):
+    return len(self.nodes) > 0
+
+  @property
+  def variable_ids(self) -> frozenset[int]:
+    """Ids of all Variables contained within the captured nodes.
+
+    Walks the captured nodes on every access; nothing is cached.
+    """
+    ids: set[int] = set()
+    for _, value in graphlib.iter_graph(self.nodes, graph=True):
+      if isinstance(value, variablelib.Variable):
+        ids.add(id(value))
+    return frozenset(ids)
+
+  def exclude_from_args(
+    self,
+    args: tuple,
+    kwargs: dict | tp.Mapping,
+  ) -> 'CapturedInfo':
+    """Remove captured nodes that also appear in explicit args (by id).
+
+    Captured entries can be graph nodes as well as bare Variables, so ids are
+    collected with the same predicate ``find_captured_nodes`` uses. Matching
+    Variables only would never exclude a captured graph node that is also
+    passed explicitly.
+
+    Returns ``self`` when nothing is excluded.
+    """
+    if not self.nodes:
+      return self
+    arg_ids: set[int] = set()
+    for _, value in graphlib.iter_graph((args, kwargs), graph=True):
+      is_node = graphlib.is_graph_node(value)
+      if is_node or isinstance(value, variablelib.Variable):
+        arg_ids.add(id(value))
+    keep = [
+      (node, idx)
+      for node, idx in zip(self.nodes, self.freevar_indices)
+      if id(node) not in arg_ids
+    ]
+    if len(keep) == len(self.nodes):
+      return self
+    if not keep:
+      return CapturedInfo((), ())
+    nodes, indices = zip(*keep)
+    return CapturedInfo(tuple(nodes), tuple(indices))
+
+
+_EMPTY_CAPTURED = CapturedInfo((), ())
+
+
+def find_captured_nodes(f) -> CapturedInfo:
+  """Find top-level nnx Variables and graph nodes in f's closure.
+
+  Only objects stored directly in closure cells are inspected. Each distinct
+  object (by id) is recorded once, with the index of the first cell holding
+  it. A callable without closure cells yields an empty ``CapturedInfo``.
+  """
+  closure = getattr(f, '__closure__', None)
+  if not closure:
+    return _EMPTY_CAPTURED
+  seen: set[int] = set()
+  captured_nodes: list[object] = []
+  captured_indices: list[int] = []
+  for i, cell in enumerate(closure):
+    try:
+      obj = cell.cell_contents
+    except ValueError:
+      continue  # empty cell: reading cell_contents raised ValueError
+    if graphlib.is_graph_node(obj) or isinstance(obj, variablelib.Variable):
+      obj_id = id(obj)
+      if obj_id not in seen:
+        seen.add(obj_id)
+        captured_nodes.append(obj)
+        captured_indices.append(i)
+  return CapturedInfo(tuple(captured_nodes), tuple(captured_indices))
+
+
+def _make_cell(val):
+  """Create a closure cell containing val."""
+  return types.CellType(val)
+
+
+def replace_closure_cells(
+  f: tp.Callable,
+  captured_info: CapturedInfo,
+  replacements: tuple,
+) -> tp.Callable:
+  """Return a copy of f with captured closure cells replaced by replacements.
+
+  ``replacements[k]`` stands in for ``captured_info.nodes[k]``. Besides the
+  cell at ``captured_info.freevar_indices[k]``, every other cell bound to
+  that same object is redirected too, because ``find_captured_nodes``
+  records an object only once even when several free variables refer to it.
+  Keyword-only defaults, ``__qualname__`` and function attributes of ``f``
+  are carried over. ``f`` itself is returned when nothing is captured.
+  """
+  closure = getattr(f, '__closure__', None)
+  if not captured_info or not closure:
+    return f
+  cells = list(closure)
+  for idx, new_val in zip(captured_info.freevar_indices, replacements):
+    cells[idx] = _make_cell(new_val)
+  # Redirect aliases: cells other than the recorded ones that hold a
+  # captured object would otherwise keep pointing at the original.
+  recorded = set(captured_info.freevar_indices)
+  by_id = {
+    id(node): new_val
+    for node, new_val in zip(captured_info.nodes, replacements)
+  }
+  for i, cell in enumerate(closure):
+    if i in recorded:
+      continue
+    try:
+      obj = cell.cell_contents
+    except ValueError:
+      continue  # empty cell
+    if id(obj) in by_id:
+      cells[i] = _make_cell(by_id[id(obj)])
+  new_f = types.FunctionType(
+    f.__code__, f.__globals__, f.__name__,
+    f.__defaults__, tuple(cells),
+  )
+  # The constructor call above passes positional defaults only; copy the
+  # keyword-only defaults so calls relying on them do not fail.
+  new_f.__kwdefaults__ = f.__kwdefaults__
+  new_f.__qualname__ = f.__qualname__
+  new_f.__dict__.update(f.__dict__)
+  return new_f
+
+
 def filter_kwargs(f, **kwargs):
   sig = inspect.signature(f)
   has_var_keyword = any(
diff --git a/flax/nnx/transforms/autodiff.py b/flax/nnx/transforms/autodiff.py
index 222f9e439..f13616c50 100644
--- a/flax/nnx/transforms/autodiff.py
+++ b/flax/nnx/transforms/autodiff.py
@@ -72,6 +72,7 @@ class SimpleGradFn:
   f: tp.Callable[..., tp.Any]
   has_aux: bool
   graph: bool
+  captured_info: extract.CapturedInfo = extract._EMPTY_CAPTURED
 
   def __post_init__(self):
     functools.update_wrapper(self, self.f, updated=())
@@ -81,7 +82,14 @@ def __call__(self, *args, **kwargs):
     current, snapshot = extract.snapshot(labeled(args=args, kwargs=kwargs))
     if self.graph:
       args, kwargs = extract.from_tree2((args, kwargs))
-    out = self.f(*args, **kwargs)
+    n_captured = len(self.captured_info)
+    user_args = args[n_captured:]
+    if self.captured_info:
+      f = extract.replace_closure_cells(
+        self.f, self.captured_info, args[:n_captured])
+    else:
+      f = self.f
+    out = f(*user_args, **kwargs)
     if self.graph:
       out = extract.to_tree2(out)
     extract.check_no_aliases('grad', **current, out=out, check=['out'])
@@ -151,8
+159,14 @@ def _grad_general( if not graph or not graph_updates: + captured_info = extract.find_captured_nodes(f) + + # argnums index into all_args = (*captured_nodes, *explicit_args). + # Captured nodes occupy positions 0..n_cap-1; explicit args follow. + # No shift is applied: argnums=0 refers to the first captured node (e.g. the + # model), which is the common case for closure-based training functions. gradded_fn = transform( - SimpleGradFn(f, has_aux, graph=graph), + SimpleGradFn(f, has_aux, graph=graph, captured_info=captured_info), argnums=argnums, # type: ignore[arg-type] has_aux=True, holomorphic=holomorphic, @@ -160,18 +174,26 @@ def _grad_general( ) def tree_grad_wrapper(*args, **kwargs): + all_args = (*captured_info.nodes, *args) if graph: - diff_argnums = (argnums,) if isinstance(argnums, int) else argnums - args_prefix = tuple( - i in diff_argnums for i in range(len(args)) + n_cap = len(captured_info) + diff_argnum_ints = frozenset( + a.argnum if isinstance(a, DiffState) else a + for a in ((argnums,) if isinstance(argnums, (int, DiffState)) else argnums) ) - args, kwargs = extract.to_tree2( - (args, kwargs), prefix=(args_prefix, False), + args_prefix = tuple(j in diff_argnum_ints for j in range(n_cap)) + tuple( + (n_cap + i) in diff_argnum_ints for i in range(len(args)) + ) + all_args, kwargs = extract.to_tree2( + (all_args, kwargs), prefix=(args_prefix, False), ) - variables = extract.check_no_aliases('grad', args=args, kwargs=kwargs) + variables = extract.check_no_aliases( + 'grad', args=all_args, kwargs=kwargs, + capture_ids=captured_info.variable_ids, + ) - fn_out = gradded_fn(*args, **kwargs) + fn_out = gradded_fn(*all_args, **kwargs) if return_value: if has_aux: diff --git a/flax/nnx/transforms/compilation.py b/flax/nnx/transforms/compilation.py index 52ec0205a..b94328789 100644 --- a/flax/nnx/transforms/compilation.py +++ b/flax/nnx/transforms/compilation.py @@ -453,6 +453,7 @@ class SimpleJitFn: donate_argnames: frozenset[str] graph: bool 
update_shardings: tuple[tp.Any, ...] + captured_info: extract.CapturedInfo = extract._EMPTY_CAPTURED def __post_init__(self): functools.update_wrapper(self, self.f, updated=()) @@ -464,7 +465,14 @@ def __call__(self, *args, **kwargs): ) if self.graph: args, kwargs = extract.from_tree2((args, kwargs)) - out = self.f(*args, **kwargs) + n_captured = len(self.captured_info) + user_args = args[n_captured:] + if self.captured_info: + f = extract.replace_closure_cells( + self.f, self.captured_info, args[:n_captured]) + else: + f = self.f + out = f(*user_args, **kwargs) if self.graph: out = extract.to_tree2(out, prefix=self.out_shardings) extract.check_no_aliases('jit', **current, out=out, check=['out']) @@ -508,6 +516,9 @@ def __init__( self.partial_args = partial_args self.graph = graph + # Capture closure nodes once at construction time + self._captured_info = extract.find_captured_nodes(fun) + resolved = _resolve_argnums(fun, static_argnums, static_argnames) if isinstance(in_shardings, (tuple, list)) and resolved: expanded = list(in_shardings) @@ -517,6 +528,13 @@ def __init__( else: self.in_shardings = in_shardings + # Prepend None shardings for captured nodes + n_captured = len(self._captured_info) + if n_captured > 0 and isinstance(in_shardings, (tuple, list)): + jit_in_shardings = (None,) * n_captured + tuple(in_shardings) + else: + jit_in_shardings = in_shardings + donate_argnums_set = frozenset( (donate_argnums,) if isinstance(donate_argnums, int) else donate_argnums or () @@ -528,14 +546,15 @@ def __init__( self.jitted_fn = jax.jit( SimpleJitFn( fun, - self.in_shardings, + jit_in_shardings if n_captured > 0 else self.in_shardings, out_shardings, donate_argnums_set, donate_argnames_set, graph, tuple(update_shardings), + captured_info=self._captured_info, ), - in_shardings=in_shardings, + in_shardings=jit_in_shardings, out_shardings=(out_shardings, update_shardings), static_argnums=static_argnums, static_argnames=static_argnames, @@ -565,9 +584,13 @@ def 
_maybe_from_tree(self, out): def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: args = (*self.partial_args, *args) # type: ignore[assignment] - args, kwargs = self._maybe_to_tree(args, kwargs) - variables = extract.check_no_aliases('jit', args=args, kwargs=kwargs) - out, updates = self.jitted_fn(*args, **kwargs) + all_args = (*self._captured_info.nodes, *args) + all_args, kwargs = self._maybe_to_tree(all_args, kwargs) + variables = extract.check_no_aliases( + 'jit', args=all_args, kwargs=kwargs, + capture_ids=self._captured_info.variable_ids, + ) + out, updates = self.jitted_fn(*all_args, **kwargs) extract.apply_updates(variables, updates) return self._maybe_from_tree(out) @@ -578,26 +601,38 @@ def __get__(self, obj, objtype=None): def eval_shape(self, *args, **kwargs): args = (*self.partial_args, *args) - args, kwargs = self._maybe_to_tree(args, kwargs) + all_args = (*self._captured_info.nodes, *args) + all_args, kwargs = self._maybe_to_tree(all_args, kwargs) if not self.graph: - extract.check_no_aliases('jit', args=args, kwargs=kwargs) - out, updates = self.jitted_fn.eval_shape(*args, **kwargs) + extract.check_no_aliases( + 'jit', args=all_args, kwargs=kwargs, + capture_ids=self._captured_info.variable_ids, + ) + out, updates = self.jitted_fn.eval_shape(*all_args, **kwargs) return self._maybe_from_tree(out) def trace(self, *args, **kwargs): args = (*self.partial_args, *args) - args, kwargs = self._maybe_to_tree(args, kwargs) + all_args = (*self._captured_info.nodes, *args) + all_args, kwargs = self._maybe_to_tree(all_args, kwargs) if not self.graph: - extract.check_no_aliases('jit', args=args, kwargs=kwargs) - traced = self.jitted_fn.trace(*args, **kwargs) + extract.check_no_aliases( + 'jit', args=all_args, kwargs=kwargs, + capture_ids=self._captured_info.variable_ids, + ) + traced = self.jitted_fn.trace(*all_args, **kwargs) return SimpleTraced(traced, self) def lower(self, *args, **kwargs): args = (*self.partial_args, *args) - args, kwargs = 
self._maybe_to_tree(args, kwargs) + all_args = (*self._captured_info.nodes, *args) + all_args, kwargs = self._maybe_to_tree(all_args, kwargs) if not self.graph: - extract.check_no_aliases('jit', args=args, kwargs=kwargs) - lowered = self.jitted_fn.lower(*args, **kwargs) + extract.check_no_aliases( + 'jit', args=all_args, kwargs=kwargs, + capture_ids=self._captured_info.variable_ids, + ) + lowered = self.jitted_fn.lower(*all_args, **kwargs) return SimpleLowered(lowered, self) def jit_partial( fun: tp.Callable[..., R], diff --git a/flax/nnx/transforms/iteration.py b/flax/nnx/transforms/iteration.py index a13516600..0522d6e4a 100644 --- a/flax/nnx/transforms/iteration.py +++ b/flax/nnx/transforms/iteration.py @@ -273,6 +273,7 @@ class SimpleVmapFn: in_axes: tp.Any out_axes: tp.Any update_axes: tuple[tp.Any, ...] + captured_info: extract.CapturedInfo = extract._EMPTY_CAPTURED def __post_init__(self): functools.update_wrapper(self, self.f, updated=()) @@ -283,7 +284,14 @@ def __call__(self, *args, **kwargs): ) if self.graph: args, kwargs = extract.from_tree2((args, kwargs)) - out = self.f(*args, **kwargs) + n_captured = len(self.captured_info) + user_args = args[n_captured:] + if self.captured_info: + f = extract.replace_closure_cells( + self.f, self.captured_info, args[:n_captured]) + else: + f = self.f + out = f(*user_args, **kwargs) if self.graph: out = extract.to_tree2(out, prefix=self.out_axes) extract.check_no_aliases( @@ -522,15 +530,29 @@ def vmap( if not (graph and graph_updates): + captured_info = extract.find_captured_nodes(f_unbound) + n_captured = len(captured_info) + + # Prepend None (broadcast) axes for captured nodes + full_in_axes: tp.Any + if n_captured > 0 and isinstance(in_axes, (tuple, list)): + full_in_axes = (None,) * n_captured + tuple(in_axes) + else: + full_in_axes = in_axes + + if n_captured > 0: + update_axes[None] = None # captured nodes are broadcast + vmapped_fn = jax.vmap( SimpleVmapFn( f_unbound, graph=graph, - in_axes=in_axes, + 
in_axes=full_in_axes, out_axes=out_axes, update_axes=tuple(update_axes), + captured_info=captured_info, ), - in_axes=in_axes, + in_axes=full_in_axes, out_axes=(out_axes, update_axes), axis_name=axis_name, axis_size=axis_size, @@ -539,13 +561,17 @@ def vmap( @functools.wraps(f_unbound) def simple_vmap_wrapper(*args, **kwargs): + all_args = (*captured_info.nodes, *args) if graph: - args, kwargs = extract.to_tree2( - (args, kwargs), - prefix=(in_axes, 0), + all_args, kwargs = extract.to_tree2( + (all_args, kwargs), + prefix=(full_in_axes, 0), ) - variables = extract.check_no_aliases('vmap', args=args, kwargs=kwargs) - out, updates = vmapped_fn(*args, **kwargs) + variables = extract.check_no_aliases( + 'vmap', args=all_args, kwargs=kwargs, + capture_ids=captured_info.variable_ids, + ) + out, updates = vmapped_fn(*all_args, **kwargs) extract.apply_updates(variables, updates) if graph: out = extract.from_tree2(out) @@ -1410,6 +1436,7 @@ class SimpleScanFn: carry_idx: int | None carry_out_idx: int | None update_axes: tuple[tp.Any, ...] 
+ captured_info: extract.CapturedInfo = extract._EMPTY_CAPTURED def __post_init__(self): functools.update_wrapper(self, self.f, updated=()) @@ -1425,7 +1452,14 @@ def __call__(self, full_carry: tp.Any, args: tp.Any): carry = extract.from_tree2(carry) args = extract.replace_at(args, self.carry_idx, carry) - out = self.f(*args) + n_captured = len(self.captured_info) + user_args = args[n_captured:] + if self.captured_info: + f = extract.replace_closure_cells( + self.f, self.captured_info, args[:n_captured]) + else: + f = self.f + out = f(*user_args) if self.carry_idx is None: # has carry carry_out = None @@ -1673,16 +1707,27 @@ def _simple_scan( was_carry = in_axes is Carry if in_axes is Carry: in_axes = (Carry,) - carry_idx = extract.find(in_axes, Carry) + + captured_info = extract.find_captured_nodes(f_unbound) + n_captured = len(captured_info) + + # Prepend None (broadcast) axes for captured nodes + if n_captured > 0 and isinstance(in_axes, tuple): + full_in_axes = (None,) * n_captured + in_axes + else: + full_in_axes = in_axes + + carry_idx = extract.find(full_in_axes, Carry) carry_out_idx = extract.find(out_axes, Carry) simple_scan_fn = SimpleScanFn( f_unbound, graph=graph, - in_axes=in_axes, out_axes=out_axes, + in_axes=full_in_axes, out_axes=out_axes, out_is_tuple=out_is_tuple, carry_idx=carry_idx, carry_out_idx=carry_out_idx, update_axes=tuple(updates_axes), + captured_info=captured_info, ) @functools.wraps(f) @@ -1692,23 +1737,29 @@ def simple_scan_wrapper(*args): 'When in_axes=Carry, the function must take exactly one argument, ' f'got {len(args)} arguments.' 
       )
+
+    all_args = (*captured_info.nodes, *args)
+
     if graph:
       # check consistent aliasing
-      extract.to_tree2(args, prefix=in_axes)
+      extract.to_tree2(all_args, prefix=full_in_axes)
-    carry, args = extract.mask_at(args, carry_idx)
+    carry, all_args = extract.mask_at(all_args, carry_idx)
     if graph:
-      args = extract.to_tree2(args, prefix=in_axes)
+      all_args = extract.to_tree2(all_args, prefix=full_in_axes)
       carry = extract.to_tree2(carry)
-    variables = extract.check_no_aliases('scan', args=args, carry=carry)
-    args, broadcasts = extract.extract(
+    variables = extract.check_no_aliases(
+      'scan', args=all_args, carry=carry,
+      capture_ids=captured_info.variable_ids,
+    )
+    all_args, broadcasts = extract.extract(
       lambda _, axes, x: axes is None,
-      in_axes, args,
+      full_in_axes, all_args,
       is_leaf=lambda x: isinstance(x, variablelib.Variable),
       prefix_leaf=lambda x: x is None,
     )
     args_t = _move_axis(
       lambda ax, leaf: jnp.moveaxis(leaf, ax, 0),
-      in_axes, args,
+      full_in_axes, all_args,
     )
     (carry_out, final_broadcasts), (ys, updates) = jax.lax.scan(
       simple_scan_fn,
diff --git a/tests/nnx/transforms_test.py b/tests/nnx/transforms_test.py
index a1795caf0..75e24acad 100644
--- a/tests/nnx/transforms_test.py
+++ b/tests/nnx/transforms_test.py
@@ -6352,7 +6352,7 @@ def f(x):
     if graph_updates:
       error_regex = 'Cannot extract graph node from different trace level'
     else:
-      error_regex = 'Cannot return captured Variable'
+      error_regex = 'Cannot return captured Variable|Duplicate'
     with self.assertRaisesRegex(ValueError, error_regex):
       f(jnp.zeros((4,)))
 
@@ -8016,5 +8016,245 @@ def forward_block(rng_state, rest_state, x):
     assert not jnp.allclose(y, y2)
 
 
+class TestClosureCapture(parameterized.TestCase):
+  """Tests for auto-capture of closure Variables in tree-mode transforms."""
+
+  @parameterized.parameters(
+    (True, True), (True, False), (False, False),  # (graph, graph_updates)
+  )
+  def test_jit_closure_basic(self, graph, graph_updates):
+    """Closure output matches explicit-arg version."""
+    model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
+
+    @nnx.jit(graph=graph, graph_updates=graph_updates)
+    def forward_closure(x):
+      return model(x)
+
+    @nnx.jit(graph=graph, graph_updates=graph_updates)
+    def forward_explicit(m, x):
+      return m(x)
+
+    x = jnp.ones((4, 2))
+    np.testing.assert_allclose(forward_closure(x), forward_explicit(model, x))
+
+  def test_jit_closure_state_update(self):
+    """Mutations to captured Variables match explicit-arg behavior."""
+    model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
+    count_closure = nnx.Variable(jnp.array(0))
+    count_explicit = nnx.Variable(jnp.array(0))
+
+    @nnx.jit(graph=False)
+    def forward_closure(x):
+      count_closure[...] += 1
+      return model(x)
+
+    @nnx.jit(graph=False)
+    def forward_explicit(count_var, m, x):
+      count_var[...] += 1
+      return m(x)
+
+    x = jnp.ones((1, 2))
+    out_closure = forward_closure(x)
+    out_explicit = forward_explicit(count_explicit, model, x)
+    self.assertEqual(count_closure[...], count_explicit[...])
+    np.testing.assert_allclose(out_closure, out_explicit)
+
+    out_closure = forward_closure(x)  # second call: the two counters must still agree
+    out_explicit = forward_explicit(count_explicit, model, x)
+    self.assertEqual(count_closure[...], count_explicit[...])
+    np.testing.assert_allclose(out_closure, out_explicit)
+
+  def test_jit_closure_param_update_reflected(self):
+    """External param changes are reflected in both closure and explicit-arg versions."""
+    model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
+
+    @nnx.jit(graph=False)
+    def forward_closure(x):
+      return model(x)
+
+    @nnx.jit(graph=False)
+    def forward_explicit(m, x):
+      return m(x)
+
+    x = jnp.ones((1, 2))
+    model.kernel[...] = jnp.zeros((2, 3))  # mutated after decoration, before the first call
+    model.bias[...] = jnp.zeros((3,))
+
+    y_closure = forward_closure(x)
+    y_explicit = forward_explicit(model, x)
+    np.testing.assert_allclose(y_closure, np.zeros((1, 3)))
+    np.testing.assert_allclose(y_closure, y_explicit)
+
+  def test_jit_closure_explicit_arg(self):
+    """Explicit-arg jit output matches closure version (the explicit-arg call is evaluated first)."""
+    model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
+
+    @nnx.jit(graph=False)
+    def forward_explicit(m, x):
+      return m(x)
+
+    @nnx.jit(graph=False)
+    def forward_closure(x):
+      return model(x)
+
+    x = jnp.ones((4, 2))
+    np.testing.assert_allclose(forward_explicit(model, x), forward_closure(x))
+
+  def test_jit_closure_deduplication(self):
+    """Variable captured in closure AND passed explicitly does not alias-error."""
+    bias = nnx.Variable(jnp.ones((3,)))
+
+    @nnx.jit(graph=False)
+    def forward_dedup(bias_arg, x):
+      # bias is closed over; bias_arg is the same Variable passed explicitly.
+      _ = bias  # force bias into closure
+      return x + bias_arg[...]
+
+    @nnx.jit(graph=False)
+    def forward_explicit(bias_arg, x):
+      return x + bias_arg[...]
+
+    x = jnp.ones((4, 3))
+    np.testing.assert_allclose(forward_dedup(bias, x), forward_explicit(bias, x))
+
+  def test_jit_closure_no_closure(self):
+    """Function with no closure still works."""
+    @nnx.jit(graph=False)
+    def add(a, b):
+      return a + b
+
+    self.assertEqual(add(jnp.array(1), jnp.array(2)), 3)
+
+  @parameterized.parameters(
+    (True, True), (True, False), (False, False),  # (graph, graph_updates)
+  )
+  def test_vmap_closure_broadcast(self, graph, graph_updates):
+    """Captured module is broadcast (not vmapped), output matches explicit-arg version."""
+    model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
+
+    @nnx.vmap(in_axes=(0,), out_axes=0, graph=graph, graph_updates=graph_updates)
+    def forward_closure(x):
+      return model(x)
+
+    @nnx.vmap(in_axes=(None, 0), out_axes=0, graph=graph, graph_updates=graph_updates)
+    def forward_explicit(m, x):
+      return m(x)
+
+    x = jnp.ones((5, 2))
+    np.testing.assert_allclose(forward_closure(x), forward_explicit(model, x))
+
+  def test_vmap_closure_state_update(self):
+    """State updates from vmapped closure match explicit-arg behavior."""
+    model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
+    count_closure = nnx.Variable(jnp.array(0))
+    count_explicit = nnx.Variable(jnp.array(0))
+
+    @nnx.vmap(in_axes=(0,), out_axes=0, graph=False)
+    def forward_closure(x):
+      count_closure[...] += 1
+      return model(x)
+
+    @nnx.vmap(in_axes=(None, None, 0), out_axes=0, graph=False)
+    def forward_explicit(count_var, m, x):
+      count_var[...] += 1
+      return m(x)
+
+    x = jnp.ones((5, 2))
+    forward_closure(x)
+    forward_explicit(count_explicit, model, x)
+    self.assertEqual(count_closure[...], count_explicit[...])
+
+  def test_grad_closure(self):
+    """Grad where data arrays are captured but model is explicit."""
+    x = jnp.ones((4, 2))
+    y = jnp.ones((4, 3))
+
+    @nnx.value_and_grad
+    def loss_fn(model):
+      return jnp.mean((model(x) - y) ** 2)
+
+    model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
+    loss, grads = loss_fn(model)  # grads is not inspected below; only the loss shape is asserted
+    self.assertEqual(loss.shape, ())
+
+  def test_grad_closure_captured_model(self):
+    """nnx.grad with captured model: argnums=0 differentiates wrt the captured model."""
+    model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
+    x = jnp.ones((4, 2))
+
+    @nnx.grad(graph=True, graph_updates=False)
+    def compute_loss_closure(x):
+      return jnp.mean((model(x)) ** 2)
+
+    @nnx.grad(graph=True, graph_updates=False)
+    def compute_loss_explicit(m, x):
+      return jnp.mean((m(x)) ** 2)
+
+    grads_closure = compute_loss_closure(x)
+    grads_explicit = compute_loss_explicit(model, x)
+
+    # Both should produce parameter gradients with the same values
+    np.testing.assert_allclose(grads_closure.kernel[...], grads_explicit.kernel[...])
+    np.testing.assert_allclose(grads_closure.bias[...], grads_explicit.bias[...])
+
+  def test_scan_closure_broadcast(self):
+    """Captured module is broadcast in scan, output matches explicit-arg version."""
+    model = nnx.Linear(3, 3, rngs=nnx.Rngs(0))
+
+    @nnx.scan(in_axes=(0, nnx.Carry), out_axes=(0, nnx.Carry), graph=False)
+    def step_closure(xs, carry):
+      return model(xs), model(carry)
+
+    @nnx.scan(in_axes=(None, 0, nnx.Carry), out_axes=(0, nnx.Carry), graph=False)
+    def step_explicit(m, xs, carry):
+      return m(xs), m(carry)
+
+    xs = jnp.ones((4, 3))
+    carry = jnp.ones((3,))
+    ys_closure, final_closure = step_closure(xs, carry)
+    ys_explicit, final_explicit = step_explicit(model, xs, carry)
+
+    np.testing.assert_allclose(ys_closure, ys_explicit)
+    np.testing.assert_allclose(final_closure, final_explicit)
+
+  def test_nested_jit_vmap_closure(self):
+    """nnx.jit over nnx.vmap: closure output matches explicit-arg version."""
+    model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
+
+    @nnx.jit(graph=False)
+    def batched_closure(x):
+      @nnx.vmap(in_axes=(0,), out_axes=0, graph=False)
+      def forward(x):
+        return model(x)
+      return forward(x)
+
+    @nnx.jit(graph=False)
+    def batched_explicit(m, x):
+      @nnx.vmap(in_axes=(None, 0), out_axes=0, graph=False)
+      def forward(m, x):
+        return m(x)
+      return forward(m, x)
+
+    x = jnp.ones((5, 2))
+    np.testing.assert_allclose(batched_closure(x), batched_explicit(model, x))
+
+  def test_jit_closure_no_recompilation(self):
+    """nnx.jit with a captured variable does not recompile on repeated calls."""
+    model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
+
+    @nnx.jit(graph=False)
+    def forward(x):
+      return model(x)
+
+    x = jnp.ones((4, 2))
+    forward(x)
+    cache_size_after_first = forward.jitted_fn._cache_size()  # underscore-prefixed (non-public) method
+    forward(x)
+    cache_size_after_second = forward.jitted_fn._cache_size()
+
+    self.assertEqual(cache_size_after_first, cache_size_after_second,
+                     'nnx.jit recompiled on the second call with a closure capture')
+
+
 if __name__ == '__main__':
   absltest.main()