Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 46 additions & 4 deletions flax/nnx/graphlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import builtins

import jax.core
import jax.numpy as jnp

from flax import config
from flax.nnx import filterlib, reprlib, traversals, variablelib
Expand Down Expand Up @@ -2878,7 +2879,13 @@ def pop(
return states


def clone(node: Node, variables: bool = True, *, graph: bool | None = None) -> Node:
def clone(
node: Node,
variables: bool = True,
*,
arrays: bool = False,
graph: bool | None = None,
) -> Node:
"""Create a deep copy of the given graph node.

Example usage::
Expand All @@ -2890,10 +2897,32 @@ def clone(node: Node, variables: bool = True, *, graph: bool | None = None) -> N
>>> model.bias[...] += 1
>>> assert (model.bias[...] != cloned_model.bias[...]).all()

When ``arrays=True``, the underlying JAX array buffers are also copied so
that the clone and the original share no physical memory. This is needed
when the cloned model will be passed to a function with
``donate_argnums``—without it both the original and the clone share the
same buffer, causing JAX to reject the donation with
*"the same buffer cannot be donated more than once"*.

Example with ``arrays=True``::

>>> import jax
>>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
>>> cloned_model = nnx.clone(model, arrays=True)
>>> # The buffers are now independent — safe to donate separately.
>>> model_ptr = model.bias[...].unsafe_buffer_pointer()
>>> clone_ptr = cloned_model.bias[...].unsafe_buffer_pointer()
>>> assert model_ptr != clone_ptr

Args:
node: A graph node object.
variables: If ``True`` (default) copies of the :class:`Variable` objects are created,
otherwise the Variables are shared between the original and cloned node.
variables: If ``True`` (default) copies of the :class:`Variable` objects
are created, otherwise the Variables are shared between the original and
cloned node.
arrays: If ``True``, replaces all shared JAX array buffers in the clone
with independent physical copies (using :func:`jax.numpy.array`). This
has no effect when ``variables=False``. Defaults to ``False`` to
preserve the existing copy-on-write behaviour.
graph: If ``True`` (default), uses graph-mode which supports the full
NNX feature set including shared references. If ``False``, uses
tree-mode which treats Modules as regular JAX pytrees, avoiding
Expand All @@ -2902,7 +2931,20 @@ def clone(node: Node, variables: bool = True, *, graph: bool | None = None) -> N
A deep copy of the :class:`Module` object.
"""
graphdef, state = split(node, graph=graph)
return merge(graphdef, state, copy=variables)
merged = merge(graphdef, state, copy=variables)
if arrays and variables:
# Variable.copy() currently uses jax.tree.map(lambda x: x, value), which
# reconstructs the pytree wrapper but reuses the underlying JAX array
# buffer (copy-on-write semantics). For donate_argnums compatibility we
# need each buffer to be physically independent. We do a second pass over
# the cloned state and replace every jax.Array leaf with a fresh copy.
_, cloned_state = split(merged, graph=graph)
deep_state = jax.tree.map(
lambda x: jnp.array(x) if isinstance(x, jax.Array) else x,
cloned_state,
)
return merge(graphdef, deep_state, copy=False)
return merged


def with_vars(
Expand Down
37 changes: 37 additions & 0 deletions tests/nnx/module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,43 @@ def test_clone(self):
assert m.b.c.get_value() == m2.b.c.get_value()
assert m.b.d.get_value() == m2.b.d.get_value()

def test_clone_arrays_true_creates_independent_buffers(self):
  """Check that ``nnx.clone(..., arrays=True)`` yields separate buffers.

  Regression test for https://github.com/google/flax/issues/5461.

  A clone built with ``arrays=True`` must carry the same kernel and bias
  values as the source module, while each of those arrays reports a
  different ``unsafe_buffer_pointer()`` than its counterpart in the source.
  """
  source = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
  duplicate = nnx.clone(source, arrays=True)

  # The copy has to be numerically identical to the source.
  np.testing.assert_array_equal(source.kernel[...], duplicate.kernel[...])
  np.testing.assert_array_equal(source.bias[...], duplicate.bias[...])

  # Kernel: the two arrays must not report the same buffer address.
  source_kernel_ptr = source.kernel.get_value().unsafe_buffer_pointer()
  duplicate_kernel_ptr = duplicate.kernel.get_value().unsafe_buffer_pointer()
  assert source_kernel_ptr != duplicate_kernel_ptr, (
    'arrays=True should produce independent buffers, but kernel still '
    'shares memory with the original'
  )

  # Bias: same requirement, checked only after the kernel check passed.
  source_bias_ptr = source.bias.get_value().unsafe_buffer_pointer()
  duplicate_bias_ptr = duplicate.bias.get_value().unsafe_buffer_pointer()
  assert source_bias_ptr != duplicate_bias_ptr, (
    'arrays=True should produce independent buffers, but bias still '
    'shares memory with the original'
  )

def test_clone_arrays_false_default_preserves_copy_on_write(self):
  """A clone made without ``arrays`` diverges once the source is mutated."""
  source = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
  # No ``arrays`` argument is passed, so the default clone path is exercised.
  duplicate = nnx.clone(source)

  # Mutate the source in place; the clone's bias is expected to stay put.
  source.bias[...] += 1

  every_entry_differs = (source.bias[...] != duplicate.bias[...]).all()
  assert every_entry_differs, (
    'Default clone should diverge after in-place mutation')

def test_sow_existing_non_variable_field(self):
class Foo(nnx.Module):
def __init__(self) -> None:
Expand Down