Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions flax/nnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@
from .rnglib import fork_rngs as fork_rngs
from .rnglib import reseed as reseed
from .rnglib import split_rngs as split_rngs
from .rnglib import with_rngs as with_rngs
from .rnglib import restore_rngs as restore_rngs
from .spmd import PARTITION_NAME as PARTITION_NAME
from .spmd import get_partition_spec as get_partition_spec
Expand Down
100 changes: 100 additions & 0 deletions flax/nnx/rnglib.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,6 +725,106 @@ def __enter__(self):
def __exit__(self, *args):
restore_rngs(self)

def with_rngs(tree, split=None, fork=None, only=True, graph=False):
"""Returns a copy of ``tree`` with ``RngStream`` objects replaced according to
``split`` and ``fork`` rules.

``split`` controls which streams are **split** — after splitting, each call
to the stream produces one key from an array of pre-generated keys rather
than a single key. ``fork`` controls which of the remaining streams are
**forked** — each call to a forked stream produces a unique key derived from
the parent counter. Streams that match neither rule are returned unchanged.

Args:
tree: A pytree that may contain ``RngStream`` objects (e.g. an ``Rngs``
instance, a module, or any nested structure).
split: Specifies which streams to split and into what shape. Can be:

* An ``int`` or ``tuple[int, ...]`` — split *all* streams into this
shape, equivalent to ``{...: split}``.
* A :class:`~flax.nnx.filterlib.Filter`-keyed mapping where each value
is an ``int`` or ``tuple[int, ...]``. The first matching filter wins.

fork: A :class:`~flax.nnx.filterlib.Filter` selecting which streams not
already handled by ``split`` should be forked. Pass ``...`` to fork all
remaining streams.
graph: If ``True``, uses graph-mode which supports the full
NNX feature set including shared references. If ``False``, uses
tree-mode which treats Modules as regular JAX pytrees, avoiding
the overhead of the graph protocol.

Returns:
A new tree of the same structure as ``tree`` with ``RngStream`` objects
replaced by split or forked copies as specified.

Example — split all streams::

>>> from flax import nnx
...
>>> rngs = nnx.Rngs(params=0, dropout=1)
>>> new_rngs = nnx.with_rngs(rngs, split=4)
>>> new_rngs.params.key.shape
(4,)
>>> new_rngs.dropout.key.shape
(4,)

Example — split some streams, fork the rest::

>>> rngs = nnx.Rngs(params=0, dropout=1)
>>> new_rngs = nnx.with_rngs(rngs, split={'params': 4}, fork=...)
>>> new_rngs.params.key.shape
(4,)
>>> new_rngs.dropout.key.shape # forked: scalar key, advanced counter
()

Example — per-filter split shapes::

>>> rngs = nnx.Rngs(params=0, dropout=1, noise=2)
>>> new_rngs = nnx.with_rngs(rngs, split={
... 'params': 4, # split params into 4 keys
... ...: (2, 4), # split anything else into 2×4 keys
... })
>>> new_rngs.params.key.shape
(4,)
>>> new_rngs.noise.key.shape
(2, 4)

"""
if split is None:
split = {}
elif isinstance(split, (int, tuple)):
split = {...: split}
if isinstance(fork, str) or not isinstance(fork, tp.Sequence):
fork = [fork]
split_predicates = [(k, filterlib.to_predicate(k), v) for k, v in split.items()]
fork_predicates = [(p, filterlib.to_predicate(p)) for p in fork]
only_predicate = filterlib.to_predicate(only)

def f(path, val):
if isinstance(val, RngStream) and only_predicate(path, val):
results = {}
for (filter, predicate, num_splits) in split_predicates:
if predicate(path, val):
results['split'] = (filter, num_splits)
break
for (filter, predicate) in fork_predicates:
if predicate(path, val):
results['fork'] = (filter,)
break
if len(results) > 1:
fork_filter = results['fork'][0]
if fork_filter not in (..., True):
rule_descriptions = '\n'.join(f' - {rule} matches filter {info[0]!r}' for rule, info in results.items())
raise ValueError(
f"RngStream at path {path} matches multiple rules:\n{rule_descriptions}"
)
if 'split' in results:
return val.split(results['split'][1])
if 'fork' in results:
return val.fork()
return val

return graphlib.recursive_map(f, tree, graph=graph)

@tp.overload
def split_rngs(
Expand Down
90 changes: 90 additions & 0 deletions tests/nnx/rngs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from functools import partial
from typing import Any
import re

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -239,5 +240,94 @@ def test_random_helpers(self):
)
np.testing.assert_allclose(x_nnx, x_jax)

class TestWithRngs(parameterized.TestCase):
def test_split_int_splits_all_streams(self):
rngs = nnx.Rngs(params=0, dropout=1)
new_rngs = nnx.with_rngs(rngs, split=4)

self.assertEqual(new_rngs.params.key.shape, (4,))
self.assertEqual(new_rngs['dropout'].key.shape, (4,))

def test_split_tuple_splits_all_streams(self):
rngs = nnx.Rngs(params=0, dropout=1)
new_rngs = nnx.with_rngs(rngs, split=(2, 3))

self.assertEqual(new_rngs.params.key.shape, (2, 3))
self.assertEqual(new_rngs['dropout'].key.shape, (2, 3))

def test_fork_forks_all_streams(self):
rngs = nnx.Rngs(params=0, dropout=1)
original_params_key = rngs.params.key[...]
original_dropout_key = rngs['dropout'].key[...]

new_rngs = nnx.with_rngs(rngs, fork=...)

# Forked keys are scalar and differ from originals
self.assertEqual(new_rngs.params.key.shape, ())
self.assertEqual(new_rngs['dropout'].key.shape, ())
self.assertFalse(jnp.array_equal(new_rngs.params.key[...], original_params_key))
self.assertFalse(jnp.array_equal(new_rngs['dropout'].key[...], original_dropout_key))

def test_split_mapping_applies_per_filter(self):
rngs = nnx.Rngs(params=0, dropout=1, noise=2)
new_rngs = nnx.with_rngs(rngs, split={'params': 4, ...: (2, 3)})

self.assertEqual(new_rngs.params.key.shape, (4,))
self.assertEqual(new_rngs['dropout'].key.shape, (2, 3))
self.assertEqual(new_rngs.noise.key.shape, (2, 3))

def test_split_mapping_first_matching_filter_wins(self):
rngs = nnx.Rngs(params=0, dropout=1)
# 'params' filter comes before '...' so it should match first
new_rngs = nnx.with_rngs(rngs, split={'params': 4, ...: 8})

self.assertEqual(new_rngs.params.key.shape, (4,))
self.assertEqual(new_rngs['dropout'].key.shape, (8,))

def test_split_some_fork_rest(self):
rngs = nnx.Rngs(params=0, dropout=1)
new_rngs = nnx.with_rngs(rngs, split={'params': 4}, fork=...)

self.assertEqual(new_rngs.params.key.shape, (4,))
# dropout not matched by split → forked (scalar)
self.assertEqual(new_rngs['dropout'].key.shape, ())

def test_original_base_key_not_replaced(self):
# nnx.with_rngs advances the original stream's counter (consuming one step to
# derive the new keys) but does not replace the original's base key.
rngs = nnx.Rngs(params=0, dropout=1)
original_key_var = rngs.params.key

nnx.with_rngs(rngs, split=4)

self.assertIs(rngs.params.key, original_key_var)
self.assertEqual(rngs.params.key.shape, ())

def test_unmatched_streams_returned_unchanged(self):
rngs = nnx.Rngs(params=0, dropout=1)
# Only fork 'params'; 'dropout' matches neither split nor fork
new_rngs = nnx.with_rngs(rngs, fork='params')

self.assertIsNot(new_rngs['dropout'], rngs['dropout']) # new tree, but...
self.assertTrue(jnp.array_equal(new_rngs['dropout'].key[...], rngs['dropout'].key[...]))
self.assertEqual(new_rngs['dropout'].key.shape, ())

def test_split_and_fork_same_stream_raises(self):
rngs = nnx.Rngs(params=0, dropout=1)
with self.assertRaisesRegex(ValueError, re.compile(r"multiple rules")):
nnx.with_rngs(rngs, split={'params': 4}, fork='params')

def test_works_on_plain_pytree(self):
params_stream = nnx.RngStream(0, tag='params')
dropout_stream = nnx.RngStream(1, tag='dropout')
tree = {'a': params_stream, 'b': dropout_stream}

new_tree = nnx.with_rngs(tree, split=4)

self.assertEqual(new_tree['a'].key.shape, (4,))
self.assertEqual(new_tree['b'].key.shape, (4,))
# Originals unchanged
self.assertEqual(params_stream.key.shape, ())

if __name__ == '__main__':
absltest.main()
Loading