diff --git a/flax/nnx/__init__.py b/flax/nnx/__init__.py index ac0fd2391..37eba3f10 100644 --- a/flax/nnx/__init__.py +++ b/flax/nnx/__init__.py @@ -146,6 +146,7 @@ from .rnglib import fork_rngs as fork_rngs from .rnglib import reseed as reseed from .rnglib import split_rngs as split_rngs +from .rnglib import with_rngs as with_rngs from .rnglib import restore_rngs as restore_rngs from .spmd import PARTITION_NAME as PARTITION_NAME from .spmd import get_partition_spec as get_partition_spec diff --git a/flax/nnx/rnglib.py b/flax/nnx/rnglib.py index 6e7d0c9db..37e3aa0c0 100644 --- a/flax/nnx/rnglib.py +++ b/flax/nnx/rnglib.py @@ -725,6 +725,106 @@ def __enter__(self): def __exit__(self, *args): restore_rngs(self) +def with_rngs(tree, split=None, fork=None, only=True, graph=False): + """Returns a copy of ``tree`` with ``RngStream`` objects replaced according to + ``split`` and ``fork`` rules. + + ``split`` controls which streams are **split** — after splitting, each call + to the stream produces one key from an array of pre-generated keys rather + than a single key. ``fork`` controls which of the remaining streams are + **forked** — each call to a forked stream produces a unique key derived from + the parent counter. Streams that match neither rule are returned unchanged. + + Args: + tree: A pytree that may contain ``RngStream`` objects (e.g. an ``Rngs`` + instance, a module, or any nested structure). + split: Specifies which streams to split and into what shape. Can be: + + * An ``int`` or ``tuple[int, ...]`` — split *all* streams into this + shape, equivalent to ``{...: split}``. + * A :class:`~flax.nnx.filterlib.Filter`-keyed mapping where each value + is an ``int`` or ``tuple[int, ...]``. The first matching filter wins. + + fork: A :class:`~flax.nnx.filterlib.Filter` selecting which streams not + already handled by ``split`` should be forked. Pass ``...`` to fork all + remaining streams. + graph: If ``True``, uses graph-mode which supports the full + NNX feature set including shared references. If ``False``, uses + tree-mode which treats Modules as regular JAX pytrees, avoiding + the overhead of the graph protocol. + + Returns: + A new tree of the same structure as ``tree`` with ``RngStream`` objects + replaced by split or forked copies as specified. + + Example — split all streams:: + + >>> from flax import nnx + ... + >>> rngs = nnx.Rngs(params=0, dropout=1) + >>> new_rngs = nnx.with_rngs(rngs, split=4) + >>> new_rngs.params.key.shape + (4,) + >>> new_rngs.dropout.key.shape + (4,) + + Example — split some streams, fork the rest:: + + >>> rngs = nnx.Rngs(params=0, dropout=1) + >>> new_rngs = nnx.with_rngs(rngs, split={'params': 4}, fork=...) + >>> new_rngs.params.key.shape + (4,) + >>> new_rngs.dropout.key.shape # forked: scalar key, advanced counter + () + + Example — per-filter split shapes:: + + >>> rngs = nnx.Rngs(params=0, dropout=1, noise=2) + >>> new_rngs = nnx.with_rngs(rngs, split={ + ... 'params': 4, # split params into 4 keys + ... ...: (2, 4), # split anything else into 2×4 keys + ... }) + >>> new_rngs.params.key.shape + (4,) + >>> new_rngs.noise.key.shape + (2, 4) + + """ + if split is None: + split = {} + elif isinstance(split, (int, tuple)): + split = {...: split} + if isinstance(fork, str) or not isinstance(fork, tp.Sequence): + fork = [fork] + split_predicates = [(k, filterlib.to_predicate(k), v) for k, v in split.items()] + fork_predicates = [(p, filterlib.to_predicate(p)) for p in fork] + only_predicate = filterlib.to_predicate(only) + + def f(path, val): + if isinstance(val, RngStream) and only_predicate(path, val): + results = {} + for (filter, predicate, num_splits) in split_predicates: + if predicate(path, val): + results['split'] = (filter, num_splits) + break + for (filter, predicate) in fork_predicates: + if predicate(path, val): + results['fork'] = (filter,) + break + if len(results) > 1: + fork_filter = results['fork'][0] + if fork_filter not in (..., True): + rule_descriptions = '\n'.join(f' - {rule} matches filter {info[0]!r}' for rule, info in results.items()) + raise ValueError( + f"RngStream at path {path} matches multiple rules:\n{rule_descriptions}" + ) + if 'split' in results: + return val.split(results['split'][1]) + if 'fork' in results: + return val.fork() + return val + + return graphlib.recursive_map(f, tree, graph=graph) @tp.overload def split_rngs( diff --git a/tests/nnx/rngs_test.py b/tests/nnx/rngs_test.py index 9f01582bd..7ff2fc01d 100644 --- a/tests/nnx/rngs_test.py +++ b/tests/nnx/rngs_test.py @@ -14,6 +14,7 @@ from functools import partial from typing import Any +import re import jax import jax.numpy as jnp @@ -239,5 +240,94 @@ def test_random_helpers(self): ) np.testing.assert_allclose(x_nnx, x_jax) +class TestWithRngs(parameterized.TestCase): + def test_split_int_splits_all_streams(self): + rngs = nnx.Rngs(params=0, dropout=1) + new_rngs = nnx.with_rngs(rngs, split=4) + + self.assertEqual(new_rngs.params.key.shape, (4,)) + self.assertEqual(new_rngs['dropout'].key.shape, (4,)) + + def test_split_tuple_splits_all_streams(self): + rngs = nnx.Rngs(params=0, dropout=1) + new_rngs = nnx.with_rngs(rngs, split=(2, 3)) + + self.assertEqual(new_rngs.params.key.shape, (2, 3)) + self.assertEqual(new_rngs['dropout'].key.shape, (2, 3)) + + def test_fork_forks_all_streams(self): + rngs = nnx.Rngs(params=0, dropout=1) + original_params_key = rngs.params.key[...] + original_dropout_key = rngs['dropout'].key[...] + + new_rngs = nnx.with_rngs(rngs, fork=...) + + # Forked keys are scalar and differ from originals + self.assertEqual(new_rngs.params.key.shape, ()) + self.assertEqual(new_rngs['dropout'].key.shape, ()) + self.assertFalse(jnp.array_equal(new_rngs.params.key[...], original_params_key)) + self.assertFalse(jnp.array_equal(new_rngs['dropout'].key[...], original_dropout_key)) + + def test_split_mapping_applies_per_filter(self): + rngs = nnx.Rngs(params=0, dropout=1, noise=2) + new_rngs = nnx.with_rngs(rngs, split={'params': 4, ...: (2, 3)}) + + self.assertEqual(new_rngs.params.key.shape, (4,)) + self.assertEqual(new_rngs['dropout'].key.shape, (2, 3)) + self.assertEqual(new_rngs.noise.key.shape, (2, 3)) + + def test_split_mapping_first_matching_filter_wins(self): + rngs = nnx.Rngs(params=0, dropout=1) + # 'params' filter comes before '...' so it should match first + new_rngs = nnx.with_rngs(rngs, split={'params': 4, ...: 8}) + + self.assertEqual(new_rngs.params.key.shape, (4,)) + self.assertEqual(new_rngs['dropout'].key.shape, (8,)) + + def test_split_some_fork_rest(self): + rngs = nnx.Rngs(params=0, dropout=1) + new_rngs = nnx.with_rngs(rngs, split={'params': 4}, fork=...) + + self.assertEqual(new_rngs.params.key.shape, (4,)) + # dropout not matched by split → forked (scalar) + self.assertEqual(new_rngs['dropout'].key.shape, ()) + + def test_original_base_key_not_replaced(self): + # nnx.with_rngs advances the original stream's counter (consuming one step to + # derive the new keys) but does not replace the original's base key. + rngs = nnx.Rngs(params=0, dropout=1) + original_key_var = rngs.params.key + + nnx.with_rngs(rngs, split=4) + + self.assertIs(rngs.params.key, original_key_var) + self.assertEqual(rngs.params.key.shape, ()) + + def test_unmatched_streams_returned_unchanged(self): + rngs = nnx.Rngs(params=0, dropout=1) + # Only fork 'params'; 'dropout' matches neither split nor fork + new_rngs = nnx.with_rngs(rngs, fork='params') + + self.assertIsNot(new_rngs['dropout'], rngs['dropout']) # new tree, but... + self.assertTrue(jnp.array_equal(new_rngs['dropout'].key[...], rngs['dropout'].key[...])) + self.assertEqual(new_rngs['dropout'].key.shape, ()) + + def test_split_and_fork_same_stream_raises(self): + rngs = nnx.Rngs(params=0, dropout=1) + with self.assertRaisesRegex(ValueError, re.compile(r"multiple rules")): + nnx.with_rngs(rngs, split={'params': 4}, fork='params') + + def test_works_on_plain_pytree(self): + params_stream = nnx.RngStream(0, tag='params') + dropout_stream = nnx.RngStream(1, tag='dropout') + tree = {'a': params_stream, 'b': dropout_stream} + + new_tree = nnx.with_rngs(tree, split=4) + + self.assertEqual(new_tree['a'].key.shape, (4,)) + self.assertEqual(new_tree['b'].key.shape, (4,)) + # Originals unchanged + self.assertEqual(params_stream.key.shape, ()) + if __name__ == '__main__': absltest.main()