Hi folks, me again.
I keep playing around with nnx, and it seems that `nnx.Optimizer`, when creating the optimizer state for models intended to be used with `scan`, uses the sharding information of the original, non-stacked tensor without taking into account the extra dimension added by `vmap`.
# Minimal repro: nnx.Optimizer re-applies per-layer (unstacked) sharding
# metadata to vmap-stacked Params.
#
# Each layer's Param is created inside nnx.vmap with shape (16, 16) and logical
# sharding ("A", "B"). vmap stacks the values to (num_layers, 16, 16). When
# nnx.Optimizer builds its state, it constructs each OptVariable from
# `x.get_value()` and `x.get_metadata()` and re-shards the value with that
# metadata. The result is that the 2-D spec ('a', 'b') is applied to the 3-D
# array, mapping mesh axis 'a' onto the stacked layer dimension.
import jax
import jax.numpy as jnp
import flax.nnx as nnx
import optax

# A 2 x 4 device mesh, so 8 devices must be available. On CPU they can be
# simulated, e.g. with XLA_FLAGS=--xla_force_host_platform_device_count=8 set
# before jax is imported.
mesh1 = jax.make_mesh((2, 4), ("a", "b"))
# Logical axis name -> mesh axis name.
rules1 = (("A", "a"), ("B", "b"))


class Model(nnx.Module):
    """Holds `num_layers` (16, 16) Params stacked along a new leading axis."""

    def __init__(self, num_layers, rngs: nnx.Rngs):
        # NOTE(review): no `transform_metadata` is passed to nnx.vmap below, so
        # the per-layer sharding ("A", "B") appears to be kept as-is on the
        # stacked (num_layers, 16, 16) value. nnx.vmap has a
        # `transform_metadata` argument (e.g. {nnx.PARTITION_NAME: ...}) meant
        # to insert an axis into the sharding metadata. This is untested here;
        # confirm against the installed flax version.
        @nnx.split_rngs(splits=num_layers)  # one RNG split per layer
        @nnx.vmap(in_axes=(0,), out_axes=0)  # map over split rngs, stack outputs on axis 0
        def create_linear(rngs: nnx.Rngs):
            return nnx.Param(
                jnp.ones((16, 16)),
                sharding=("A", "B"),
                mesh=mesh1,
                sharding_rules=rules1,
            )

        self.linears = create_linear(rngs=rngs)


@nnx.jit
def init():
    """Build a one-layer Model and an Adam nnx.Optimizer over its Params.

    Raises:
        ValueError: from nnx.Optimizer (see the traceback in the report). The
            optimizer state is re-sharded with PartitionSpec('a', 'b') on a
            (1, 16, 16) value, and dim 0 (size 1) is not divisible by the size
            of mesh axis 'a' (2).
    """
    model = Model(num_layers=1, rngs=nnx.Rngs(params=0))
    optimizer = nnx.Optimizer(
        model,
        optax.adam(learning_rate=0.001),
        wrt=nnx.Param,
    )
    return model, optimizer


model, optimizer = init()
Traceback (most recent call last):
File "/papyrax/test_scan_axis.py", line 37, in <module>
model, optimizer = init()
^^^^^^
File "/usr/local/lib/python3.11/dist-packages/flax/nnx/transforms/compilation.py", line 474, in __call__
pure_args_out, pure_kwargs_out, pure_out = self.jitted_fn(
^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/flax/nnx/transforms/compilation.py", line 135, in __call__
out = self.f(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^
File "/papyrax/test_scan_axis.py", line 30, in init
optimizer = nnx.Optimizer(
^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/flax/nnx/pytreelib.py", line 400, in __call__
return _graph_node_meta_call(cls, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/flax/nnx/pytreelib.py", line 412, in _graph_node_meta_call
cls._pytree_meta_construct(node, *args, **kwargs)
File "/usr/local/lib/python3.11/dist-packages/flax/nnx/pytreelib.py", line 403, in _pytree_meta_construct
self.__init__(*args, **kwargs)
File "/usr/local/lib/python3.11/dist-packages/flax/nnx/training/optimizer.py", line 88, in _check_wrt_wrapper
return f(*args, wrt=wrt, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/flax/nnx/training/optimizer.py", line 160, in __init__
to_opt_state(tx.init(nnx.state(model, wrt)))
File "/usr/local/lib/python3.11/dist-packages/flax/nnx/training/optimizer.py", line 57, in to_opt_state
tree = jax.tree.map(
^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/jax/_src/tree.py", line 155, in map
return tree_util.tree_map(f, tree, *rest, is_leaf=is_leaf)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/flax/nnx/training/optimizer.py", line 52, in _to_opt_state
opt_state = OptVariable(x.get_value(), **x.get_metadata()) # type: ignore
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/flax/nnx/variablelib.py", line 904, in __call__
return cls._variable_meta_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/flax/nnx/variablelib.py", line 907, in _variable_meta_call
variable = super().__call__(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/flax/nnx/variablelib.py", line 1108, in __init__
value = core_spmd.shard_value(
^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/flax/core/spmd.py", line 49, in shard_value
return _apply_sharding(value, NamedSharding(mesh, pspec))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/flax/core/spmd.py", line 37, in _apply_sharding
return jax.jit(lambda x: x, out_shardings=sharding)(value)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: One of pjit outputs with pytree key path result was given the sharding of NamedSharding(mesh=Mesh('a': 2, 'b': 4, axis_types=(Auto, Auto)), spec=PartitionSpec('a', 'b'), memory_kind=device), which implies that the global size of its dimension 0 should be divisible by 2, but it is equal to 1 (full shape: (1, 16, 16))
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
Hi folks, me again.
I keep playing around with nnx, and it seems that `nnx.Optimizer`, when creating the optimizer state for models intended to be used with `scan`, uses the sharding information of the original, non-stacked tensor without taking into account the extra dimension added by `vmap`.
Repro script:
Output: