From a0a9e9d1cd07adc14a50eecc2125356b83d1250d Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Tue, 21 Apr 2026 16:23:23 -0700 Subject: [PATCH] allow captures in nnx.grad PiperOrigin-RevId: 903485800 --- flax/nnx/extract.py | 5 ++ flax/nnx/transforms/autodiff.py | 52 +++++++++++++--- flax/nnx/variablelib.py | 106 ++++++++++++++++++++++++++++++-- tests/nnx/transforms_test.py | 82 ++++++++++++++++++++++++ 4 files changed, 230 insertions(+), 15 deletions(-) diff --git a/flax/nnx/extract.py b/flax/nnx/extract.py index cc816075b..3f18420fc 100644 --- a/flax/nnx/extract.py +++ b/flax/nnx/extract.py @@ -1137,3 +1137,8 @@ def filter_kwargs(f, **kwargs): } filtered_kwargs = {k: v for k, v in kwargs.items() if k in named_params} return filtered_kwargs + +def update_captures(captures: list[variablelib.Capture]): + for capture in captures: + variable, update = capture.ref.obj, capture.variable + variable.update_from_state(update) \ No newline at end of file diff --git a/flax/nnx/transforms/autodiff.py b/flax/nnx/transforms/autodiff.py index 222f9e439..ce008a314 100644 --- a/flax/nnx/transforms/autodiff.py +++ b/flax/nnx/transforms/autodiff.py @@ -78,20 +78,24 @@ def __post_init__(self): @extract.treemap_copy_args def __call__(self, *args, **kwargs): + ctx = variablelib.current_capture_context(True) + ctx.set_trace_state() current, snapshot = extract.snapshot(labeled(args=args, kwargs=kwargs)) if self.graph: args, kwargs = extract.from_tree2((args, kwargs)) out = self.f(*args, **kwargs) if self.graph: out = extract.to_tree2(out) - extract.check_no_aliases('grad', **current, out=out, check=['out']) + extract.check_no_aliases( + 'grad', **current, captures=ctx.captures, out=out, check=['out'] + ) updates = extract.get_updates(current, snapshot) if self.has_aux: loss, aux = out - return loss, (updates, aux) + return loss, (updates, ctx.captures, aux) else: - return out, updates + return out, (updates, ctx.captures) @dataclasses.dataclass(eq=False) @@ -159,7 +163,7 @@ def _grad_general( allow_int=allow_int, ) - def tree_grad_wrapper(*args, **kwargs): + def simple_grad_wrapper(*args, **kwargs): if graph: diff_argnums = (argnums,) if isinstance(argnums, int) else argnums args_prefix = tuple( @@ -171,31 +175,33 @@ def tree_grad_wrapper(*args, **kwargs): variables = extract.check_no_aliases('grad', args=args, kwargs=kwargs) - fn_out = gradded_fn(*args, **kwargs) + with variablelib.capture_context(variables): + fn_out = gradded_fn(*args, **kwargs) if return_value: if has_aux: - (loss, (updates, aux)), grads = fn_out + (loss, (updates, captures, aux)), grads = fn_out if graph: grads, aux = extract.from_tree2((grads, aux)) result = (loss, aux), grads else: - (loss, updates), grads = fn_out + (loss, (updates, captures)), grads = fn_out if graph: grads = extract.from_tree2(grads) result = loss, grads else: if has_aux: - grads, (updates, aux) = fn_out + grads, (updates, captures, aux) = fn_out if graph: grads, aux = extract.from_tree2((grads, aux)) result = grads, aux else: - grads, updates = fn_out + grads, (updates, captures) = fn_out if graph: grads = extract.from_tree2(grads) result = grads + extract.update_captures(captures) extract.apply_updates(variables, updates) return result - return tree_grad_wrapper + return simple_grad_wrapper jax_argnums: int | tuple[int, ...] if isinstance(argnums, (int, DiffState)): @@ -288,6 +294,8 @@ def process_out(pure_out: A, /) -> A: @tp.overload + + def grad( f: tp.Callable[..., tp.Any], *, @@ -300,6 +308,8 @@ def grad( graph_updates: bool | None = None, ) -> tp.Callable[..., tp.Any]: ... @tp.overload + + def grad( *, argnums: int | DiffState | tp.Sequence[int | DiffState] = 0, @@ -437,6 +447,8 @@ def grad( @tp.overload + + def value_and_grad( f: tp.Callable[..., tp.Any], *, @@ -449,6 +461,8 @@ def value_and_grad( graph_updates: bool | None = None, ) -> tp.Callable[..., tp.Any]: ... @tp.overload + + def value_and_grad( *, argnums: int | DiffState | tp.Sequence[int | DiffState] = 0, @@ -581,6 +595,8 @@ def __call__(self, *primals): @tp.overload + + def vjp( f: tp.Callable[..., tp.Any], *primals: tp.Any, @@ -590,6 +606,8 @@ def vjp( graph_updates: bool | None = None, ) -> tuple[tp.Any, tp.Callable] | tuple[tp.Any, tp.Callable, tp.Any]: ... @tp.overload + + def vjp( *, has_aux: bool = False, @@ -747,6 +765,8 @@ def __call__(self, *primals): @tp.overload + + def jvp( f: tp.Callable[..., tp.Any], primals: tuple[tp.Any, ...], @@ -757,6 +777,8 @@ def jvp( graph_updates: bool | None = None, ) -> tuple[tp.Any, ...]: ... @tp.overload + + def jvp( *, has_aux: bool = False, @@ -764,6 +786,8 @@ def jvp( graph_updates: bool | None = None, ) -> tp.Callable[[tp.Callable[..., tp.Any]], tp.Callable[..., tp.Any]]: ... @tp.overload + + def jvp( f: tp.Callable[..., tp.Any], *, @@ -1338,6 +1362,8 @@ def defvjp( @tp.overload + + def custom_vjp( fun: tp.Callable[..., A], *, @@ -1346,6 +1372,8 @@ def custom_vjp( graph_updates: bool | None = None, ) -> CustomVjp[A] | SimpleCustomVjp[A]: ... @tp.overload + + def custom_vjp( *, nondiff_argnums: tuple[int | DiffState, ...] = (), @@ -1595,6 +1623,8 @@ def __call__(self, *args, **kwargs): return out, updates @tp.overload + + def remat( *, prevent_cse: bool = True, @@ -1604,6 +1634,8 @@ def remat( graph_updates: bool | None = None, ) -> tp.Callable[[F], F]: ... @tp.overload + + def remat( f: F, *, diff --git a/flax/nnx/variablelib.py b/flax/nnx/variablelib.py index 80021d5c4..ee418d6f4 100644 --- a/flax/nnx/variablelib.py +++ b/flax/nnx/variablelib.py @@ -14,6 +14,7 @@ # pytype: skip-file from __future__ import annotations +import contextlib import dataclasses import functools from functools import partial @@ -62,17 +63,86 @@ # JAX v0.7.0 or older from jax.experimental import MutableArray as Ref # type: ignore[no-redef] +@dataclasses.dataclass(eq=False, slots=True) +class Reference(tp.Generic[A]): + obj: A + +@jtu.register_dataclass +@dataclasses.dataclass(frozen=True, slots=True) +class Capture: + ref: Reference[Variable] = jax.tree.static() + variable: Variable + + @classmethod + def create(cls, value: A): + return cls(Reference(value), value) + +@dataclasses.dataclass(slots=True) +class VariableCaptureContext: + input_variables: dict[int, tuple] # id -> path + trace_state: tracers.TraceState | None + captures: list[Capture] + + def set_trace_state(self): + assert self.trace_state is None + self.trace_state = tracers.TraceState() + + def is_valid(self): + return self.trace_state is not None and self.trace_state.is_valid() + + def capture(self, variable: Variable): + if id(variable) in self.input_variables: + path = self.input_variables[id(variable)] + raise ValueError( + f'Cannot mutate captured Variable {variable!r} because it is an alias of ' + f'input Variable at {jtu.keystr(path)}.' + ) + self.captures.append(Capture.create(variable)) @dataclasses.dataclass class VariableContext(threading.local): variable_hijax_stack: list[bool] = dataclasses.field(default_factory=list) variable_ref_stack: list[bool] = dataclasses.field(default_factory=list) eager_shard_stack: list[bool] = dataclasses.field(default_factory=list) + capture_stack: list[VariableCaptureContext] = dataclasses.field(default_factory=list) VARIABLE_CONTEXT = VariableContext() +@tp.overload +def current_capture_context(check: tp.Literal[True]) -> VariableCaptureContext: + ... + +@tp.overload +def current_capture_context( + check: bool = False, +) -> VariableCaptureContext | None: + ... + +def current_capture_context( + check: bool = False, +) -> VariableCaptureContext | None: + if not VARIABLE_CONTEXT.capture_stack: + if check: + raise ValueError('No capture context found.') + return None + return VARIABLE_CONTEXT.capture_stack[-1] + + +@contextlib.contextmanager +def capture_context(variables: dict[jtu.KeyPath, Variable]): + ctx = VariableCaptureContext( + input_variables={id(v): path for path, v in variables.items()}, + trace_state=None, + captures=[], + ) + VARIABLE_CONTEXT.capture_stack.append(ctx) + try: + yield ctx + finally: + VARIABLE_CONTEXT.capture_stack.pop() + class use_eager_sharding(BaseConfigContext): """Sets whether Variables should use eager sharding by default or not. @@ -147,10 +217,18 @@ def __len__(self) -> int: @tp.overload + + + + def var_defaults() -> VarDefaults: ... @tp.overload + + + + def var_defaults( *, hijax: bool | None = None, ref: bool | None = None ) -> VarDefaultsContext: ... @@ -1388,16 +1466,26 @@ def __init__( @property def _can_update(self) -> bool: """Whether the Variable can be updated in-place in the current trace context.""" - if self.hijax: - return True - else: - return self._trace_state.is_valid() + can_update, _ = self._can_update_and_captured() + return can_update def _check_can_update(self): - if not self.hijax and not self._trace_state.is_valid(): + can_update, got_captured = self._can_update_and_captured() + if not can_update: raise errors.TraceContextError( f'Cannot mutate {type(self).__name__} from a different trace level' ) + if got_captured: + current_capture_context(True).capture(self) + + def _can_update_and_captured(self): + if self.hijax: + return True, False + elif self._trace_state.is_valid(): + return True, False + elif (capture_ctx := current_capture_context()) and capture_ctx.is_valid(): + return True, True + return False, False def __getattr__(self, name: str) -> tp.Any: if name in object.__getattribute__(self, '_var_metadata'): @@ -2381,6 +2469,10 @@ def variable_name_from_type( @tp.overload + + + + def register_variable_name( name: str, typ: type[Variable[tp.Any]], @@ -2390,6 +2482,10 @@ def register_variable_name( @tp.overload + + + + def register_variable_name( name: str, *, diff --git a/tests/nnx/transforms_test.py b/tests/nnx/transforms_test.py index a1795caf0..4dd4d737a 100644 --- a/tests/nnx/transforms_test.py +++ b/tests/nnx/transforms_test.py @@ -1791,6 +1791,88 @@ def f(m: nnx.Dict): assert m['c'] == 7 assert m['d'] == 5.0 + def test_grad_captures(self): + v = nnx.Variable(jnp.array(1.0)) + n = 0 + + @nnx.grad(graph=True, graph_updates=False) + def g(x): + nonlocal n + n += 1 + v[...] += 1.0 + return x * v[...] + + x = jnp.array(2.0) + grads = g(x) + + self.assertEqual(n, 1) + self.assertEqual(v.get_value(), 2.0) + self.assertEqual(grads, 2.0) + + grads = g(x) + self.assertEqual(n, 2) + self.assertEqual(v.get_value(), 3.0) + self.assertEqual(grads, 3.0) + + @nnx.set_graph_mode(False) + @nnx.set_graph_updates(False) + def test_unpack_capture(self): + class Foo(nnx.Module): + def __init__(self, rngs): + self.linear = nnx.Linear(2, 2, rngs=rngs) + self.bn = nnx.BatchNorm(2, use_running_average=False, rngs=rngs) + self.count = nnx.Variable(jnp.array(0)) + + def __call__(self, x): + self.count[...] += 1 + return nnx.relu(self.bn(self.linear(x))) + + model = Foo(nnx.Rngs(0)) + optimizer = nnx.Optimizer(model, optax.adamw(0.1), wrt=nnx.Param) + + @nnx.jit + def train_step(model: Foo, optimizer, x, y): + params, nondiff = nnx.unpack(model, nnx.Param, ...) + + def loss_fn(params): + model = nnx.merge(params, nondiff) + return jnp.mean((model(x) - y) ** 2) + + grads = nnx.grad(loss_fn)(params) + optimizer.update(grads, model) + + x, y = jnp.ones((1,2)), jnp.ones((1,2)) + train_step(model, optimizer, x, y) + self.assertEqual(model.count[...], 1) + train_step(model, optimizer, x, y) + self.assertEqual(model.count[...], 2) + + def test_grad_capture_alias_input_error(self): + v = nnx.Variable(jnp.array(1.0)) + + @nnx.grad(graph=True, graph_updates=False) + def g(x): + v[...] += 1.0 + return x * v[...] + + with self.assertRaisesRegex( + ValueError, 'Cannot mutate captured Variable' + ): + g(v) + + def test_grad_capture_alias_output_error(self): + v = nnx.Variable(jnp.array(1.0)) + + @nnx.grad(graph=True, graph_updates=False) + def g(x): + v[...] += 1.0 + return x, v + + with self.assertRaisesRegex( + ValueError, 'Duplicate' + ): + g(jnp.array(1.0)) + @parameterized.parameters(True, False) def test_grad_with_multiple_ref_types(self, graph_updates: bool): m = nnx.Dict(