Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions flax/nnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
from .graphlib import merge_context as merge_context
from .graphlib import variables as variables
from .graphlib import vars_as as vars_as
from .graphlib import as_pure as as_pure
from .graphlib import pure as pure
from .graphlib import cached_partial as cached_partial
from .graphlib import flatten as flatten
Expand Down
8 changes: 5 additions & 3 deletions flax/nnx/graphlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from flax import config
from flax.nnx import filterlib, reprlib, traversals, variablelib
from flax.nnx import statelib
from flax.nnx.deprecations import deprecated
from flax.nnx.proxy_caller import (
ApplyCaller,
CallableProxy,
Expand Down Expand Up @@ -2764,7 +2765,7 @@ def _to_refs(jax_path, x):
return node


def pure(tree: A) -> A:
def as_pure(tree: A) -> A:
"""Returns a new tree with all ``Variable`` objects replaced with inner values.

This can be used to remove Variable metadata when its is not needed for tasks like
Expand All @@ -2787,7 +2788,7 @@ def pure(tree: A) -> A:
value=(2, 3)
)
})
>>> pure_state = nnx.pure(state)
>>> pure_state = nnx.as_pure(state)
>>> jax.tree.map(jnp.shape, pure_state)
State({
'bias': (3,),
Expand All @@ -2803,7 +2804,7 @@ def pure(tree: A) -> A:

def _pure_fn(x):
if isinstance(x, Variable):
return pure(x.get_raw_value())
return as_pure(x.get_raw_value())
elif variablelib.is_array_ref(x):
return x[...]
return x
Expand All @@ -2814,6 +2815,7 @@ def _pure_fn(x):
is_leaf=lambda x: isinstance(x, Variable),
)

pure = deprecated(as_pure)

def call(
graphdef_state: tuple[GraphDef[A], GraphState], /
Expand Down
8 changes: 4 additions & 4 deletions flax/nnx/training/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,10 +207,10 @@ def update(self, model: M, grads, /, **kwargs) -> optax.Updates:
The updates PyTree containing the parameter updates applied to the model.
This matches the structure of the model parameters filtered by ``wrt``.
"""
param_arrays = nnx.pure(nnx.state(model, self.wrt))
grad_arrays = nnx.pure(nnx.state(grads, self.wrt))
opt_state_arrays = nnx.pure(self.opt_state)
kwargs_arrays = nnx.pure(kwargs)
param_arrays = nnx.as_pure(nnx.state(model, self.wrt))
grad_arrays = nnx.as_pure(nnx.state(grads, self.wrt))
opt_state_arrays = nnx.as_pure(self.opt_state)
kwargs_arrays = nnx.as_pure(kwargs)

updates, new_opt_state = self.tx.update(
grad_arrays, opt_state_arrays, param_arrays, **kwargs_arrays
Expand Down
8 changes: 4 additions & 4 deletions tests/nnx/optimizer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ def loss_fn(model):
self.assertIsNotNone(updates)

# Verify updates structure matches params structure
params = nnx.pure(nnx.state(model, nnx.Param))
params = nnx.as_pure(nnx.state(model, nnx.Param))

def check_structure(path, update_val, param_val):
self.assertEqual(update_val.shape, param_val.shape)
Expand All @@ -362,7 +362,7 @@ def test_updates_match_param_changes(self):
optimizer = nnx.Optimizer(model, optax.sgd(1.0), wrt=nnx.Param)

# Get initial params as pure arrays
initial_params = nnx.pure(nnx.state(model, nnx.Param))
initial_params = nnx.as_pure(nnx.state(model, nnx.Param))

def loss_fn(model):
params = nnx.state(model)
Expand All @@ -375,7 +375,7 @@ def loss_fn(model):
updates = optimizer.update(model, grads)

# Get new params as pure arrays
new_params = nnx.pure(nnx.state(model, nnx.Param))
new_params = nnx.as_pure(nnx.state(model, nnx.Param))

# Verify: new_params = initial_params + updates (within optax.apply_updates)
def check_update(old, update, new):
Expand Down Expand Up @@ -403,7 +403,7 @@ def loss_fn(model):
self.assertIsNotNone(updates)

# Verify updates has the expected structure
params = nnx.pure(nnx.state(model, nnx.Param))
params = nnx.as_pure(nnx.state(model, nnx.Param))
def check_structure(path, update_val, param_val):
self.assertEqual(update_val.shape, param_val.shape)

Expand Down
Loading