From a2ddec5cc2fc48618e48bc89209d42836bc037bf Mon Sep 17 00:00:00 2001 From: Sumukh Chaluvaraju Date: Thu, 4 Jun 2026 17:50:29 +0100 Subject: [PATCH 1/2] feat(nnx): add arrays=True to nnx.clone for independent buffer copies nnx.clone() currently uses copy-on-write semantics: new Variable wrapper objects are created but the underlying jax.Array buffers are shared with the original (Variable.copy() uses jax.tree.map(lambda x: x, value) which is structurally a copy but physically a no-op at the buffer level). This works for the common mutation-after-clone pattern (model.bias[...] += 1 stores a newly computed array in the original's Variable while the clone's Variable keeps referencing the old buffer, so the two diverge), but breaks donate_argnums: JAX inspects buffer addresses before running user code, and sees the same physical buffer donated twice, raising: "the same buffer cannot be donated more than once" Add arrays=True (default False to preserve existing behaviour): cloned = nnx.clone(model, arrays=True) When arrays=True, after the standard clone a second pass replaces every jax.Array leaf in the cloned State with jnp.array(x), forcing a new physical allocation. The clone and the original are then fully independent at the buffer level, making donate_argnums safe. arrays=True has no effect when variables=False, because in that mode the Variable objects themselves are shared between the original and the clone. 
Fixes: google/flax#5461 --- flax/nnx/graphlib.py | 50 ++++++++++++++++++++++++++++++++++++---- tests/nnx/module_test.py | 37 +++++++++++++++++++++++++++++ 2 files changed, 83 insertions(+), 4 deletions(-) diff --git a/flax/nnx/graphlib.py b/flax/nnx/graphlib.py index a3a0aa8fb..44bf077ba 100644 --- a/flax/nnx/graphlib.py +++ b/flax/nnx/graphlib.py @@ -22,6 +22,7 @@ import builtins import jax.core +import jax.numpy as jnp from flax import config from flax.nnx import filterlib, reprlib, traversals, variablelib @@ -2878,7 +2879,13 @@ def pop( return states -def clone(node: Node, variables: bool = True, *, graph: bool | None = None) -> Node: +def clone( + node: Node, + variables: bool = True, + *, + arrays: bool = False, + graph: bool | None = None, +) -> Node: """Create a deep copy of the given graph node. Example usage:: @@ -2890,10 +2897,32 @@ def clone(node: Node, variables: bool = True, *, graph: bool | None = None) -> N >>> model.bias[...] += 1 >>> assert (model.bias[...] != cloned_model.bias[...]).all() + When ``arrays=True``, the underlying JAX array buffers are also copied so + that the clone and the original share no physical memory. This is needed + when the cloned model will be passed to a function with + ``donate_argnums``—without it both the original and the clone share the + same buffer, causing JAX to reject the donation with + *"the same buffer cannot be donated more than once"*. + + Example with ``arrays=True``:: + + >>> import jax + >>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) + >>> cloned_model = nnx.clone(model, arrays=True) + >>> # The buffers are now independent — safe to donate separately. + >>> model_ptr = model.bias.value.unsafe_buffer_pointer() + >>> clone_ptr = cloned_model.bias.value.unsafe_buffer_pointer() + >>> assert model_ptr != clone_ptr + Args: node: A graph node object. - variables: If ``True`` (default) copies of the :class:`Variable` objects are created, - otherwise the Variables are shared between the original and cloned node. 
+ variables: If ``True`` (default) copies of the :class:`Variable` objects + are created, otherwise the Variables are shared between the original and + cloned node. + arrays: If ``True``, replaces all shared JAX array buffers in the clone + with independent physical copies (using :func:`jax.numpy.array`). This + has no effect when ``variables=False``. Defaults to ``False`` to + preserve the existing copy-on-write behaviour. graph: If ``True`` (default), uses graph-mode which supports the full NNX feature set including shared references. If ``False``, uses tree-mode which treats Modules as regular JAX pytrees, avoiding @@ -2902,7 +2931,20 @@ def clone(node: Node, variables: bool = True, *, graph: bool | None = None) -> N A deep copy of the :class:`Module` object. """ graphdef, state = split(node, graph=graph) - return merge(graphdef, state, copy=variables) + merged = merge(graphdef, state, copy=variables) + if arrays and variables: + # Variable.copy() currently uses jax.tree.map(lambda x: x, value), which + # reconstructs the pytree wrapper but reuses the underlying JAX array + # buffer (copy-on-write semantics). For donate_argnums compatibility we + # need each buffer to be physically independent. We do a second pass over + # the cloned state and replace every jax.Array leaf with a fresh copy. + _, cloned_state = split(merged, graph=graph) + deep_state = jax.tree.map( + lambda x: jnp.array(x) if isinstance(x, jax.Array) else x, + cloned_state, + ) + return merge(graphdef, deep_state, copy=False) + return merged def with_vars( diff --git a/tests/nnx/module_test.py b/tests/nnx/module_test.py index be350e8a9..fb79f7406 100644 --- a/tests/nnx/module_test.py +++ b/tests/nnx/module_test.py @@ -574,6 +574,43 @@ def test_clone(self): assert m.b.c.get_value() == m2.b.c.get_value() assert m.b.d.get_value() == m2.b.d.get_value() + def test_clone_arrays_true_creates_independent_buffers(self): + """Regression test for https://github.com/google/flax/issues/5461. 
+ + nnx.clone() with the default arrays=False uses copy-on-write semantics: + new Variable wrappers are created but the underlying jax.Array buffers + are shared. This breaks donate_argnums because JAX sees the same buffer + donated twice. + + With arrays=True, all array buffers are physically independent. + """ + model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) + cloned = nnx.clone(model, arrays=True) + + # Values must be equal. + np.testing.assert_array_equal( + model.kernel[...], cloned.kernel[...]) + np.testing.assert_array_equal( + model.bias[...], cloned.bias[...]) + + # Physical buffers must be distinct. + assert (model.kernel.get_value().unsafe_buffer_pointer() != + cloned.kernel.get_value().unsafe_buffer_pointer()), ( + 'arrays=True should produce independent buffers, but kernel still ' + 'shares memory with the original') + assert (model.bias.get_value().unsafe_buffer_pointer() != + cloned.bias.get_value().unsafe_buffer_pointer()), ( + 'arrays=True should produce independent buffers, but bias still ' + 'shares memory with the original') + + def test_clone_arrays_false_default_preserves_copy_on_write(self): + """Default clone (arrays=False) still diverges on mutation.""" + model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) + cloned = nnx.clone(model) # arrays=False is the default + model.bias[...] += 1 + assert (model.bias[...] != cloned.bias[...]).all(), ( + 'Default clone should diverge after in-place mutation') + def test_sow_existing_non_variable_field(self): class Foo(nnx.Module): def __init__(self) -> None: From d8e97b41110b51193482f2bb0ad2137e685114d7 Mon Sep 17 00:00:00 2001 From: Sumukh Chaluvaraju Date: Fri, 5 Jun 2026 13:03:21 +0100 Subject: [PATCH 2/2] fix(nnx): update clone doctest to use subscript instead of deprecated .value .value access is deprecated in newest Flax; use variable[...] for Variable[Array] instances to avoid UnexpectedException in doctest. 
--- flax/nnx/graphlib.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flax/nnx/graphlib.py b/flax/nnx/graphlib.py index 44bf077ba..8cef1d231 100644 --- a/flax/nnx/graphlib.py +++ b/flax/nnx/graphlib.py @@ -2910,8 +2910,8 @@ def clone( >>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) >>> cloned_model = nnx.clone(model, arrays=True) >>> # The buffers are now independent — safe to donate separately. - >>> model_ptr = model.bias.value.unsafe_buffer_pointer() - >>> clone_ptr = cloned_model.bias.value.unsafe_buffer_pointer() + >>> model_ptr = model.bias[...].unsafe_buffer_pointer() + >>> clone_ptr = cloned_model.bias[...].unsafe_buffer_pointer() >>> assert model_ptr != clone_ptr Args: