Handle captures in jax transforms #5485

Open
samanklesaria wants to merge 1 commit into google:main from samanklesaria:handle_captures

@samanklesaria samanklesaria commented Jun 4, 2026

Collaborator

Auto-capture of closure Variables in NNX transforms

Enables functions passed to nnx.jit, nnx.vmap, nnx.grad, and nnx.scan to close over Modules and Variables, with the transform automatically detecting and managing them. Previously, all stateful objects had to be passed as explicit function arguments.

Before:

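from flax import nnx
import jax.numpy as jnp

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
x = jnp.ones((1, 2))  # example input: one row with 2 features

# The module has to be passed as an explicit argument.
@nnx.jit(graph=False)
def forward(model, x):
    return model(x)

y = forward(model, x)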

After:

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))

# model is captured from the enclosing scope instead of being passed in.
@nnx.jit(graph=False)
def forward(x):
    return model(x)

y = forward(x)

State mutations on captured Variables propagate back to the original Variables, external parameter changes are reflected in subsequent calls, and all three mode combinations are supported: graph=False, graph=True with graph_updates=False, and graph=True with graph_updates=True.

Design

The implementation works in three layers:

1. Discovery (extract.py)

find_captured_nodes(f) inspects f.__closure__ and returns a CapturedInfo containing the top-level graph nodes/Variables found as closure freevars, along with their freevar cell indices. CapturedInfo.exclude_from_args() deduplicates against explicit arguments (by id()), walking into graph nodes to find nested Variables, so the same module can safely appear in both the closure and the call arguments.
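The discovery step can be illustrated in plain Python. This is a minimal sketch and not the PR's code: `Variable` is a hypothetical stand-in for an NNX Variable, and `find_captured` is an illustrative name that returns a pair of lists rather than a CapturedInfo.

```python
class Variable:
    """Hypothetical stand-in for an NNX Variable, used only in this sketch."""
    def __init__(self, value):
        self.value = value


def find_captured(f, node_types=(Variable,)):
    """Return (indices, nodes) for closure cells holding instances of node_types."""
    closure = getattr(f, "__closure__", None) or ()
    indices, nodes = [], []
    for i, cell in enumerate(closure):
        try:
            contents = cell.cell_contents
        except ValueError:
            # Reading an empty cell raises ValueError; skip it.
            continue
        if isinstance(contents, node_types):
            indices.append(i)
            nodes.append(contents)
    return indices, nodes


def make_forward():
    scale = Variable(2.0)  # captured, and treated as a graph node here
    offset = 1.0           # captured, but a plain float

    def forward(x):
        return scale.value * x + offset

    return forward, scale


forward, scale = make_forward()
indices, nodes = find_captured(forward)
```

The cells in `f.__closure__` line up with `f.__code__.co_freevars`, so each recorded index identifies which free variable was captured. In plain Python, names resolved as module-level globals are not closure cells and do not appear in `__closure__`.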

2. Outer wrapper (each transform's wrapper function)

The outer wrapper (SimpleJitWrapped.__call__, simple_vmap_wrapper, tree_grad_wrapper, simple_scan_wrapper) prepends captured nodes to the positional args before passing them to the inner JAX-transformed function. This makes them dynamic inputs to jax.jit/jax.vmap/etc. rather than trace-time constants. Default prefixes are prepended to match: None shardings for jit, None (broadcast) axes for vmap/scan, non-differentiable for grad (argnums are shifted by n_captured).
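The bookkeeping described above can be sketched without JAX. The helper names below (`shift_argnums`, `prepend_prefixes`, `make_outer`) are hypothetical and do not come from the PR; the sketch assumes non-negative argnums and treats only tuple-typed prefixes as needing extension.

```python
def shift_argnums(argnums, n_captured):
    """Shift user-facing argnums past the prepended captured arguments."""
    if isinstance(argnums, int):
        return argnums + n_captured
    return tuple(a + n_captured for a in argnums)


def prepend_prefixes(prefix, n_captured, default=None):
    """Prepend one default prefix per captured node, only for tuple prefixes."""
    if isinstance(prefix, tuple):
        return (default,) * n_captured + prefix
    # A scalar prefix is returned unchanged.
    return prefix


def make_outer(inner, captured):
    """Build a wrapper that passes captured nodes as leading positional args."""
    def outer(*args, **kwargs):
        return inner(*captured, *args, **kwargs)
    return outer


def inner(*all_args):
    # Stand-in for the transformed function: just report what it received.
    return all_args


outer = make_outer(inner, captured=("model",))
result = outer(1, 2)
```

Here `result` starts with the captured object followed by the user's own arguments, which is the ordering that the shifted argnums and extended prefixes have to match.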

3. Inner function (each transform's Simple*Fn.__call__)

Inside the traced function, after treemap_copy_args creates mutable copies of all args, replace_closure_cells() creates a new version of the user function whose closure cells point to the dynamic copies instead of the originals. This is the key mechanism: without it, the user function would see the original Variables, which are trace-time constants and can't be mutated. The new function is called with only the user's explicit args (captured args are stripped). The existing snapshot/update machinery then operates on all args, including the captured ones, and detects mutations, and apply_updates writes the changes back to the original Variables.

replace_closure_cells uses types.FunctionType to construct a new function object with modified closure cells, targeting only the cells identified by CapturedInfo.freevar_indices.
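A minimal sketch of the cell-replacement technique using only the standard library follows. `replace_cells` and `Box` are hypothetical names for illustration, not the PR's replace_closure_cells; the sketch relies on `types.CellType`, which is available from Python 3.8.

```python
import types


class Box:
    """Hypothetical mutable container standing in for a captured Variable."""
    def __init__(self, value):
        self.value = value


def replace_cells(f, replacements):
    """Return a copy of f whose closure cells at the given indices hold new objects.

    `replacements` maps a closure cell index to the object the new cell should hold.
    Cells at other indices are reused as-is, and f itself is left untouched.
    """
    if not f.__closure__:
        return f  # nothing to replace
    new_closure = tuple(
        types.CellType(replacements[i]) if i in replacements else cell
        for i, cell in enumerate(f.__closure__)
    )
    g = types.FunctionType(
        f.__code__, f.__globals__, f.__name__, f.__defaults__, new_closure
    )
    g.__kwdefaults__ = f.__kwdefaults__
    return g


def make_reader():
    box = Box(1)

    def read():
        return box.value

    return read


read = make_reader()
idx = read.__code__.co_freevars.index("box")
patched = replace_cells(read, {idx: Box(42)})
```

Calling `patched()` reads from the replacement `Box`, while `read()` still sees the original one, so the original function object is never modified.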

Edge cases handled

  • Deduplication: If a node appears in both the closure and the explicit args, it's excluded from the captured list to avoid aliasing check failures.
  • No closure: Functions without closures (or with only non-graph-node closures like jnp) are unaffected. CapturedInfo is empty and all code paths short-circuit.
  • Nested transforms: Each transform level independently discovers its own captures.
  • Shardings/axes consistency: Captured node prefixes are prepended to in_shardings/in_axes tuples only when they're tuple-typed, preserving the existing behavior for scalar prefixes.
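The deduplication case can be sketched as an identity-based filter. This is an illustration under simplifying assumptions, not the PR's exclude_from_args: `Node`, `iter_objects`, and `exclude_explicit` are hypothetical names, and nested objects are found by walking instance attributes rather than the NNX graph.

```python
class Node:
    """Hypothetical stand-in for a graph node whose attributes are its children."""
    def __init__(self, **children):
        self.__dict__.update(children)


def iter_objects(obj, seen=None):
    """Yield obj and every object reachable through instance attributes."""
    seen = set() if seen is None else seen
    if id(obj) in seen:
        return
    seen.add(id(obj))
    yield obj
    for child in getattr(obj, "__dict__", {}).values():
        yield from iter_objects(child, seen)


def exclude_explicit(captured, explicit_args):
    """Drop captured objects that are reachable from the explicit args, by id()."""
    explicit_ids = {id(o) for arg in explicit_args for o in iter_objects(arg)}
    return [c for c in captured if id(c) not in explicit_ids]


weight = Node()
model = Node(weight=weight)
counter = Node()

# model is both captured and passed explicitly; weight is nested inside model.
remaining = exclude_explicit([model, weight, counter], explicit_args=(model,))
```

Only `counter` remains in the captured list, since `model` and its nested `weight` already reach the transform through the explicit arguments.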

Files changed

| File | Change |
| --- | --- |
| flax/nnx/extract.py | Added CapturedInfo, find_captured_nodes(), replace_closure_cells(), _make_cell() |
| flax/nnx/transforms/compilation.py | SimpleJitFn and SimpleJitWrapped use closure capture + cell replacement |
| flax/nnx/transforms/iteration.py | Same pattern for SimpleVmapFn/vmap wrapper and SimpleScanFn/scan wrapper |
| flax/nnx/transforms/autodiff.py | Same pattern for SimpleGradFn/grad wrapper, with argnum shifting |
| tests/nnx/transforms_test.py | Added TestClosureCapture (14 tests); updated one error regex in existing test |

Test coverage

The TestClosureCapture class covers:

  • Basic closure capture for jit (all 3 graph/graph_updates modes), vmap (all 3 modes), grad, scan
  • State mutation propagation: counter Variable incremented inside jit/vmap, verified externally
  • External param changes reflected: zeroing model params after first call, verifying output changes
  • Deduplication: same module captured and passed explicitly
  • No-closure baseline: plain function still works
  • Nested transforms: jit wrapping vmap with shared closure

@samanklesaria samanklesaria force-pushed the handle_captures branch 6 times, most recently from 22ef0b8 to 7ee1421 Compare June 9, 2026 20:15