Handle captures in jax transforms #5485
Open
samanklesaria wants to merge 1 commit into
Auto-capture of closure Variables in NNX transforms
Enables functions passed to `nnx.jit`, `nnx.vmap`, `nnx.grad`, and `nnx.scan` to close over Modules and Variables, with the transform automatically detecting and managing them. Previously, all stateful objects had to be passed as explicit function arguments.

Before:

After:

State mutations on captured Variables propagate back, external parameter changes are reflected in subsequent calls, and all three mode combinations (`graph=False`, `graph=True`/`graph_updates=False`, `graph=True`/`graph_updates=True`) are supported.

Design
The implementation works in three layers:
1. Discovery (`extract.py`)

`find_captured_nodes(f)` inspects `f.__closure__` and returns a `CapturedInfo` containing the top-level graph nodes/Variables found as closure freevars, along with their freevar cell indices. `CapturedInfo.exclude_from_args()` deduplicates against explicit arguments (by `id()`), walking into graph nodes to find nested Variables, so the same module can safely appear in both the closure and the call arguments.

2. Outer wrapper (each transform's wrapper function)
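For the discovery layer (layer 1), the following is a minimal plain-Python sketch of the idea, not the PR's code. `Variable`, `Captured`, `find_captured`, and the standalone `exclude_from_args` are hypothetical stand-ins for the NNX types and for `CapturedInfo`/`find_captured_nodes`; unlike the real helper, this sketch does not walk into graph nodes to find nested Variables.

```python
from dataclasses import dataclass, field


class Variable:
    """Hypothetical stand-in for a stateful NNX object; not flax's class."""

    def __init__(self, value):
        self.value = value


@dataclass
class Captured:
    """Hypothetical analogue of the PR's CapturedInfo."""

    nodes: list = field(default_factory=list)
    freevar_indices: list = field(default_factory=list)


def find_captured(f, node_types=(Variable,)):
    """Collect closure freevars of `f` that are instances of `node_types`."""
    info = Captured()
    for i, cell in enumerate(f.__closure__ or ()):
        try:
            obj = cell.cell_contents
        except ValueError:  # the cell has not been assigned yet
            continue
        if isinstance(obj, node_types):
            info.nodes.append(obj)
            info.freevar_indices.append(i)
    return info


def exclude_from_args(info, args):
    """Drop captured nodes that are also passed explicitly (compared by id)."""
    seen = {id(a) for a in args}
    kept = [
        (node, idx)
        for node, idx in zip(info.nodes, info.freevar_indices)
        if id(node) not in seen
    ]
    return Captured([n for n, _ in kept], [i for _, i in kept])


def make_step():
    count = Variable(0)
    scale = 2.0  # plain Python value: not a node, so discovery ignores it

    def step(x):
        count.value += 1
        return x * scale

    return step, count


step, count = make_step()
info = find_captured(step)
```

Here `info.nodes` holds only `count`, and `info.freevar_indices` records which closure cell it came from, which is what a later cell-replacement step would need. Passing `count` explicitly as well, as in `exclude_from_args(info, (count, 1.0))`, leaves nothing to capture.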
The outer wrapper (`SimpleJitWrapped.__call__`, `simple_vmap_wrapper`, `tree_grad_wrapper`, `simple_scan_wrapper`) prepends captured nodes to the positional args before passing them to the inner JAX-transformed function. This makes them dynamic inputs to `jax.jit`/`jax.vmap`/etc. rather than trace-time constants. Default prefixes are prepended to match: `None` shardings for jit, `None` (broadcast) axes for vmap/scan, non-differentiable for grad (argnums are shifted by `n_captured`).

3. Inner function (each transform's `Simple*Fn.__call__`)

Inside the traced function, after `treemap_copy_args` creates mutable copies of all args, `replace_closure_cells()` creates a new version of the user function whose closure cells point to the dynamic copies instead of the originals. This is the key mechanism -- without it, the user function would see the original Variables (which are trace-level constants and can't be mutated). The new function is called with only the user's explicit args (captured args are stripped). The existing snapshot/update machinery then operates on all args (including captured), detects mutations, and `apply_updates` writes changes back to the original Variables.

`replace_closure_cells` uses `types.FunctionType` to construct a new function object with modified closure cells, targeting only the cells identified by `CapturedInfo.freevar_indices`.

Edge cases handled
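The outer-wrapper and inner-function layers (2 and 3) can be illustrated without JAX or Flax. The sketch below is an assumption-laden toy, not the PR's implementation: `Variable`, `replace_cells`, and `with_captures` are hypothetical names, `copy.copy` stands in for `treemap_copy_args`, the manual write-back loop stands in for `apply_updates`, and `types.CellType` (Python 3.8+) is used because the PR's `_make_cell()` helper is not shown here.

```python
import copy
import types


class Variable:
    """Hypothetical stand-in for a mutable NNX Variable; not flax's class."""

    def __init__(self, value):
        self.value = value


def replace_cells(f, freevar_indices, replacements):
    """Return a copy of `f` whose selected closure cells hold `replacements`."""
    cells = list(f.__closure__ or ())
    for idx, new_obj in zip(freevar_indices, replacements):
        cells[idx] = types.CellType(new_obj)  # fresh cell; f's own cell is untouched
    g = types.FunctionType(
        f.__code__, f.__globals__, f.__name__, f.__defaults__, tuple(cells)
    )
    g.__kwdefaults__ = f.__kwdefaults__
    return g


def with_captures(f, captured, freevar_indices):
    """Toy outer wrapper: pass captured objects as leading positional inputs."""
    n_captured = len(captured)

    def inner(*all_args):
        # Stand-in for the traced function: operate on copies, not originals.
        copies = [copy.copy(a) for a in all_args[:n_captured]]
        user_args = all_args[n_captured:]  # captured args are stripped
        out = replace_cells(f, freevar_indices, copies)(*user_args)
        return out, copies

    def wrapper(*args):
        out, copies = inner(*captured, *args)
        # Write mutations back to the original objects.
        for original, updated in zip(captured, copies):
            original.value = updated.value
        return out

    return wrapper


def make_step():
    count = Variable(0)

    def step(x):
        count.value += 1
        return x + count.value

    return step, count


step, count = make_step()
count_index = step.__code__.co_freevars.index("count")
wrapped = with_captures(step, [count], [count_index])
first = wrapped(10)
second = wrapped(10)
```

Each call mutates only the copy seen by the rebuilt function, and the wrapper then propagates the new value to the original `count`, so the second call observes the first call's update. The original `step` keeps its own closure cell throughout, since `types.FunctionType` builds a separate function object around the same code.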
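- Closures that capture no graph nodes or Variables (for example, only `jnp`) are unaffected -- `CapturedInfo` is empty and all code paths short-circuit
- Default prefixes are prepended to `in_shardings`/`in_axes` tuples only when they're tuple-typed, preserving the existing behavior for scalar prefixes

Files changed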
| File | Changes |
| --- | --- |
| `flax/nnx/extract.py` | `CapturedInfo`, `find_captured_nodes()`, `replace_closure_cells()`, `_make_cell()` |
| `flax/nnx/transforms/compilation.py` | `SimpleJitFn` and `SimpleJitWrapped` use closure capture + cell replacement |
| `flax/nnx/transforms/iteration.py` | `SimpleVmapFn`/vmap wrapper and `SimpleScanFn`/scan wrapper |
| `flax/nnx/transforms/autodiff.py` | `SimpleGradFn`/grad wrapper, with argnum shifting |
| `tests/nnx/transforms_test.py` | `TestClosureCapture` (14 tests); updated one error regex in existing test |

Test coverage
The `TestClosureCapture` class covers: