Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions docs_nnx/hijax/hijax.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
"@jax.jit\n",
"def train_step(x, y):\n",
" loss_fn = lambda m: jnp.mean((m(x) - y) ** 2)\n",
" loss, grads = jax.value_and_grad(loss_fn)(nnx.vars_as(model, mutable=False)) # tmp fix for jax.grad\n",
" loss, grads = jax.value_and_grad(loss_fn)(nnx.with_vars(model, mutable=False)) # tmp fix for jax.grad\n",
" optimizer.update(model, grads)\n",
" return loss\n",
"\n",
Expand Down Expand Up @@ -297,8 +297,8 @@
"\n",
"model = Linear(1, 3, rngs=nnx.Rngs(0))\n",
"\n",
"print(f\"{nnx.vars_as(model, mutable=False) = !s}\")\n",
"print(f\"{nnx.vars_as(model, mutable=True) = !s}\")"
"print(f\"{nnx.with_vars(model, mutable=False) = !s}\")\n",
"print(f\"{nnx.with_vars(model, mutable=True) = !s}\")"
]
},
{
Expand All @@ -317,7 +317,7 @@
],
"source": [
"v = nnx.Variable(jnp.array(0))\n",
"v_immut = nnx.vars_as(v, mutable=False)\n",
"v_immut = nnx.with_vars(v, mutable=False)\n",
"assert not v_immut.mutable\n",
"\n",
"try:\n",
Expand Down Expand Up @@ -355,7 +355,7 @@
],
"source": [
"v = nnx.Variable(jnp.array(0))\n",
"v_ref = nnx.vars_as(v, ref=True)\n",
"v_ref = nnx.with_vars(v, ref=True)\n",
"assert v_ref.ref\n",
"print(v_ref)\n",
"print(v_ref.get_raw_value())"
Expand Down Expand Up @@ -386,11 +386,11 @@
}
],
"source": [
"v_immut = nnx.vars_as(v_ref, mutable=False)\n",
"v_immut = nnx.with_vars(v_ref, mutable=False)\n",
"assert not v_immut.ref\n",
"print(\"immutable =\", v_immut)\n",
"\n",
"v_ref = nnx.vars_as(v_immut, mutable=True)\n",
"v_ref = nnx.with_vars(v_immut, mutable=True)\n",
"assert v_ref.ref\n",
"print(\"mutable =\", v_ref)"
]
Expand Down Expand Up @@ -458,7 +458,7 @@
" model = nnx.merge(graphdef, params, nondiff)\n",
" return ((model(x) - y) ** 2).mean()\n",
"\n",
" loss, grads = jax.value_and_grad(loss_fn)(nnx.vars_as(params, mutable=False)) # immutable for jax.grad\n",
" loss, grads = jax.value_and_grad(loss_fn)(nnx.with_vars(params, mutable=False)) # immutable for jax.grad\n",
" optimizer.update(model, grads)\n",
"\n",
" return loss\n",
Expand Down Expand Up @@ -563,9 +563,9 @@
"source": [
"@jax.jit\n",
"def create_model(rngs):\n",
" return nnx.vars_as((Block(2, 64, 3, rngs=rngs)), hijax=False)\n",
" return nnx.with_vars((Block(2, 64, 3, rngs=rngs)), hijax=False)\n",
"\n",
"model = nnx.vars_as(create_model(nnx.Rngs(0)), hijax=True)\n",
"model = nnx.with_vars(create_model(nnx.Rngs(0)), hijax=True)\n",
"\n",
"print(\"model.linear =\", model.linear)"
]
Expand Down
20 changes: 10 additions & 10 deletions docs_nnx/hijax/hijax.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ optimizer = nnx.Optimizer(model, optax.adamw(1e-2), wrt=nnx.Param)
@jax.jit
def train_step(x, y):
loss_fn = lambda m: jnp.mean((m(x) - y) ** 2)
loss, grads = jax.value_and_grad(loss_fn)(nnx.vars_as(model, mutable=False)) # tmp fix for jax.grad
loss, grads = jax.value_and_grad(loss_fn)(nnx.with_vars(model, mutable=False)) # tmp fix for jax.grad
optimizer.update(model, grads)
return loss

Expand Down Expand Up @@ -112,13 +112,13 @@ class Linear(nnx.Module):

model = Linear(1, 3, rngs=nnx.Rngs(0))

print(f"{nnx.vars_as(model, mutable=False) = !s}")
print(f"{nnx.vars_as(model, mutable=True) = !s}")
print(f"{nnx.with_vars(model, mutable=False) = !s}")
print(f"{nnx.with_vars(model, mutable=True) = !s}")
```

```{code-cell} ipython3
v = nnx.Variable(jnp.array(0))
v_immut = nnx.vars_as(v, mutable=False)
v_immut = nnx.with_vars(v, mutable=False)
assert not v_immut.mutable

try:
Expand All @@ -131,18 +131,18 @@ except Exception as e:

```{code-cell} ipython3
v = nnx.Variable(jnp.array(0))
v_ref = nnx.vars_as(v, ref=True)
v_ref = nnx.with_vars(v, ref=True)
assert v_ref.ref
print(v_ref)
print(v_ref.get_raw_value())
```

```{code-cell} ipython3
v_immut = nnx.vars_as(v_ref, mutable=False)
v_immut = nnx.with_vars(v_ref, mutable=False)
assert not v_immut.ref
print("immutable =", v_immut)

v_ref = nnx.vars_as(v_immut, mutable=True)
v_ref = nnx.with_vars(v_immut, mutable=True)
assert v_ref.ref
print("mutable =", v_ref)
```
Expand Down Expand Up @@ -176,7 +176,7 @@ def train_step(model, optimizer, x, y):
model = nnx.merge(graphdef, params, nondiff)
return ((model(x) - y) ** 2).mean()

loss, grads = jax.value_and_grad(loss_fn)(nnx.vars_as(params, mutable=False)) # immutable for jax.grad
loss, grads = jax.value_and_grad(loss_fn)(nnx.with_vars(params, mutable=False)) # immutable for jax.grad
optimizer.update(model, grads)

return loss
Expand Down Expand Up @@ -226,9 +226,9 @@ except Exception as e:
```{code-cell} ipython3
@jax.jit
def create_model(rngs):
return nnx.vars_as((Block(2, 64, 3, rngs=rngs)), hijax=False)
return nnx.with_vars((Block(2, 64, 3, rngs=rngs)), hijax=False)

model = nnx.vars_as(create_model(nnx.Rngs(0)), hijax=True)
model = nnx.with_vars(create_model(nnx.Rngs(0)), hijax=True)

print("model.linear =", model.linear)
```
Expand Down
2 changes: 1 addition & 1 deletion examples/nnx_toy_examples/hijax_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def loss_fn(params):
model = nnx.merge(graphdef, params, nondiff)
return jnp.mean((y - model(x)) ** 2)

grads = jax.grad(loss_fn)(nnx.vars_as(params, is_mutable=False))
grads = jax.grad(loss_fn)(nnx.with_vars(params, is_mutable=False))
optimizer.update(model, grads)

@jax.jit
Expand Down
2 changes: 1 addition & 1 deletion examples/nnx_toy_examples/hijax_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def loss_fn(params):

# For the time being we have to use 'immutable'
# as 'jax.grad' doesn't support QDD types yet.
grads = jax.grad(loss_fn)(nnx.vars_as(params, is_mutable=False))
grads = jax.grad(loss_fn)(nnx.with_vars(params, is_mutable=False))
# 'update' mutates the optimizer's state and the params in place
# so we don't need to return anything 🚀
optimizer.update(params, grads)
Expand Down
1 change: 1 addition & 0 deletions flax/nnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
from .graphlib import MergeContext as MergeContext
from .graphlib import merge_context as merge_context
from .graphlib import variables as variables
from .graphlib import with_vars as with_vars
from .graphlib import vars_as as vars_as
from .graphlib import as_pure as as_pure
from .graphlib import pure as pure
Expand Down
18 changes: 11 additions & 7 deletions flax/nnx/graphlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -2759,7 +2759,12 @@ def map(
A :class:`State` with the mapped values.
"""
graphdef, state = split(node, graph=graph)
state = statelib.map_state(f, state)
if isinstance(state, statelib.State):
state = statelib.map_state(f, state)
else:
# If node is a Variable, then state isn't actually a State. It stays a variable.
# This handles that very unintuitive behavior.
state = f((), state)
return merge(graphdef, state, recreate_variables=recreate_variables)


Expand Down Expand Up @@ -2899,7 +2904,7 @@ def clone(node: Node, variables: bool = True, *, graph: bool | None = None) -> N
return merge(graphdef, state, copy=variables)


def vars_as(
def with_vars(
node: A,
/,
*,
Expand Down Expand Up @@ -2937,18 +2942,17 @@ def _different_vars(path, x):
duplicates_strs += '\n ---'
raise ValueError(f'Found duplicate at paths:{duplicates_strs}')

def _to_refs(jax_path, x):
if predicate(jax_to_nnx_path(jax_path), x):
def _to_refs(path, x):
if predicate(path, x):
assert isinstance(x, Variable)
variable = x.copy(**new_attrs)
return variable
return x

node = jax.tree.map_with_path(
_to_refs, node, is_leaf=lambda x: isinstance(x, Variable)
)
node = map(_to_refs, node, recreate_variables=False)
return node

vars_as = deprecated(with_vars)

def as_pure(tree: A) -> A:
"""Returns a new tree with all ``Variable`` objects replaced with inner values.
Expand Down
2 changes: 1 addition & 1 deletion flax/nnx/spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def as_abstract(
When creating models with :func:`eval_shape`, Variables are abstract
(backed by ``jax.ShapeDtypeStruct``) and may not carry sharding
information, especially when using meshes with
:attr:`jax.sharding.AxisType.Auto` axes. ``abstract_with_sharding`` inspects each
:attr:`jax.sharding.AxisType.Auto` axes. ``as_abstract`` inspects each
Variable in ``tree`` and, if it has ``out_sharding`` metadata but no
sharding already set, attaches a :class:`jax.sharding.NamedSharding`
derived from the Variable's ``out_sharding`` and either its ``mesh``
Expand Down
2 changes: 1 addition & 1 deletion tests/nnx/integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ def loss_fn(params):
return ((model(x) - y) ** 2).mean() # call methods directly

loss, grads = jax.value_and_grad(loss_fn)(
nnx.vars_as(params, hijax=False)
nnx.with_vars(params, hijax=False)
)
optimizer.update(model, grads) # in-place updates

Expand Down
Loading
Loading