diff --git a/flax/nnx/graphlib.py b/flax/nnx/graphlib.py index a3a0aa8fb..8cef1d231 100644 --- a/flax/nnx/graphlib.py +++ b/flax/nnx/graphlib.py @@ -22,6 +22,7 @@ import builtins import jax.core +import jax.numpy as jnp from flax import config from flax.nnx import filterlib, reprlib, traversals, variablelib @@ -2878,7 +2879,13 @@ def pop( return states -def clone(node: Node, variables: bool = True, *, graph: bool | None = None) -> Node: +def clone( + node: Node, + variables: bool = True, + *, + arrays: bool = False, + graph: bool | None = None, +) -> Node: """Create a deep copy of the given graph node. Example usage:: @@ -2890,10 +2897,32 @@ def clone(node: Node, variables: bool = True, *, graph: bool | None = None) -> N >>> model.bias[...] += 1 >>> assert (model.bias[...] != cloned_model.bias[...]).all() + When ``arrays=True``, the underlying JAX array buffers are also copied so + that the clone and the original share no physical memory. This is needed + when the cloned model will be passed to a function with + ``donate_argnums``—without it both the original and the clone share the + same buffer, causing JAX to reject the donation with + *"the same buffer cannot be donated more than once"*. + + Example with ``arrays=True``:: + + >>> import jax + >>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) + >>> cloned_model = nnx.clone(model, arrays=True) + >>> # The buffers are now independent — safe to donate separately. + >>> model_ptr = model.bias[...].unsafe_buffer_pointer() + >>> clone_ptr = cloned_model.bias[...].unsafe_buffer_pointer() + >>> assert model_ptr != clone_ptr + Args: node: A graph node object. - variables: If ``True`` (default) copies of the :class:`Variable` objects are created, - otherwise the Variables are shared between the original and cloned node. + variables: If ``True`` (default) copies of the :class:`Variable` objects + are created, otherwise the Variables are shared between the original and + cloned node. + arrays: If ``True``, replaces all shared JAX array buffers in the clone + with independent physical copies (using :func:`jax.numpy.array`). This + has no effect when ``variables=False``. Defaults to ``False`` to + preserve the existing copy-on-write behaviour. graph: If ``True`` (default), uses graph-mode which supports the full NNX feature set including shared references. If ``False``, uses tree-mode which treats Modules as regular JAX pytrees, avoiding @@ -2902,7 +2931,20 @@ def clone(node: Node, variables: bool = True, *, graph: bool | None = None) -> N A deep copy of the :class:`Module` object. """ graphdef, state = split(node, graph=graph) - return merge(graphdef, state, copy=variables) + merged = merge(graphdef, state, copy=variables) + if arrays and variables: + # Variable.copy() currently uses jax.tree.map(lambda x: x, value), which + # reconstructs the pytree wrapper but reuses the underlying JAX array + # buffer (copy-on-write semantics). For donate_argnums compatibility we + # need each buffer to be physically independent. We do a second pass over + # the cloned state and replace every jax.Array leaf with a fresh copy. + _, cloned_state = split(merged, graph=graph) + deep_state = jax.tree.map( + lambda x: jnp.array(x) if isinstance(x, jax.Array) else x, + cloned_state, + ) + return merge(graphdef, deep_state, copy=False) + return merged def with_vars( diff --git a/tests/nnx/module_test.py b/tests/nnx/module_test.py index be350e8a9..fb79f7406 100644 --- a/tests/nnx/module_test.py +++ b/tests/nnx/module_test.py @@ -574,6 +574,43 @@ def test_clone(self): assert m.b.c.get_value() == m2.b.c.get_value() assert m.b.d.get_value() == m2.b.d.get_value() + def test_clone_arrays_true_creates_independent_buffers(self): + """Regression test for https://github.com/google/flax/issues/5461. + + nnx.clone() with the default arrays=False uses copy-on-write semantics: + new Variable wrappers are created but the underlying jax.Array buffers + are shared. This breaks donate_argnums because JAX sees the same buffer + donated twice. + + With arrays=True, all array buffers are physically independent. + """ + model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) + cloned = nnx.clone(model, arrays=True) + + # Values must be equal. + np.testing.assert_array_equal( + model.kernel[...], cloned.kernel[...]) + np.testing.assert_array_equal( + model.bias[...], cloned.bias[...]) + + # Physical buffers must be distinct. + assert (model.kernel.get_value().unsafe_buffer_pointer() != + cloned.kernel.get_value().unsafe_buffer_pointer()), ( + 'arrays=True should produce independent buffers, but kernel still ' + 'shares memory with the original') + assert (model.bias.get_value().unsafe_buffer_pointer() != + cloned.bias.get_value().unsafe_buffer_pointer()), ( + 'arrays=True should produce independent buffers, but bias still ' + 'shares memory with the original') + + def test_clone_arrays_false_default_preserves_copy_on_write(self): + """Default clone (arrays=False) still diverges on mutation.""" + model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) + cloned = nnx.clone(model) # arrays=False is the default + model.bias[...] += 1 + assert (model.bias[...] != cloned.bias[...]).all(), ( + 'Default clone should diverge after in-place mutation') + def test_sow_existing_non_variable_field(self): class Foo(nnx.Module): def __init__(self) -> None: