Skip to content

optimizer.update fails with optax.partition #5372

@slievens

Description

@slievens

I am quite new to flax, so I am not 100% sure that I didn't make an obvious mistake.
Small example included at the bottom of this report.

System information

  • OS Platform and Distribution: Alma linux
  • Flax, jax, jaxlib versions (obtain with pip show flax jax jaxlib): flax: 0.12.5, jax: 0.9.1, jaxlib: 0.9., optax: 0.2.6
  • Python version: 3.11.14
  • GPU/TPU model and memory:
  • CUDA version (if applicable): 13.1

Problem you have encountered:

The nnx training loop fails when I use an optax.partition as the optimizer. It works fine with another transformation. I can also use the optax.partition optimizer in a pure JAX training loop.

What you expected to happen:

Training loop should also work with nnx.

Logs, error messages, etc:

ValueError: Custom node type mismatch: expected type: <class 'flax.nnx.variablelib.Param'>, value: JitTracer(float32[2,1]).

Steps to reproduce:

Whenever possible, please provide a minimal example. Please consider submitting it as a Colab link.

import jax
import jax.numpy as jnp
import flax
import flax.nnx as nnx
import optax

# Report the library versions in use so the run can be reproduced.
print(
    "__jax_version__", jax.__version__,
    "__optax_version__", optax.__version__,
    "__flax_version__", flax.__version__,
)


BATCH_SIZE = 8

# Synthetic regression data: 160 samples with two random-normal features each;
# the target is the sum of both features plus 3.
X = jax.random.normal(jax.random.key(42), (160, 2))
y = X.sum(axis=1) + 3

# A very simple model with one linear layer
class SimpleModel(nnx.Module):
    """Minimal regression model: one linear layer from 2 features to 1 output."""

    def __init__(self, *, rngs):
        """Create the single linear layer; ``rngs`` is forwarded to ``nnx.Linear``."""
        layer = nnx.Linear(in_features=2, out_features=1, rngs=rngs)
        self.linear = layer

    def __call__(self, x):
        """Apply the linear layer to ``x`` and return its output."""
        prediction = self.linear(x)
        return prediction

# Instantiate the model; the nnx.Rngs(0) object is passed through to nnx.Linear.
model = SimpleModel(rngs=nnx.Rngs(0))

# Define the label function that will be used to assign labels to the parameters for partitioning
def label_fn(path, leaf):
    """Return the optimizer-group label for one parameter leaf.

    The second-to-last entry of ``path`` must expose the parameter name via a
    ``.key`` attribute. Names containing ``'bias'`` are labelled
    ``'non_decay_group'``; every other name is labelled ``'decay_group'``.
    ``leaf`` is part of the callback signature but is not inspected.
    """
    name = path[-2].key
    return 'non_decay_group' if 'bias' in name else 'decay_group'

# Build a pytree of group labels by mapping label_fn over the model's Param state.
param_labels_pytree = jax.tree.map_with_path(label_fn, nnx.state(model, nnx.Param))

# Route each labelled group to its own optimizer: AdamW for 'decay_group',
# plain Adam for 'non_decay_group'.
partitioned_chain = optax.partition(
    transforms=dict(
        decay_group=optax.adamw(learning_rate=0.05),
        non_decay_group=optax.adam(learning_rate=0.05),
    ),
    param_labels=param_labels_pytree,
)

jax_optimizer = partitioned_chain


# Training loop with pure JAX --> This works fine
def make_train_step(graphdef, solver):
    """Build a jitted training step for the functional (pure JAX) loop.

    Args:
        graphdef: Static graph definition (as returned by ``nnx.split``); it is
            merged with the parameter state to rebuild the model in the loss.
        solver: Optax gradient transformation whose ``update`` is applied.

    Returns:
        A function ``train_step(param_state, opt_state, x, y)`` that returns
        the updated ``(param_state, opt_state)`` pair.
    """
    def loss_fn(param_state, x, y):
        model = nnx.merge(graphdef, param_state)
        # Use the batch `x` that was passed in, not the module-level dataset,
        # and drop the trailing output axis so predictions line up with `y`
        # element-wise instead of broadcasting (B, 1) against (B,).
        predictions = jnp.squeeze(model(x), axis=-1)
        return jnp.mean((predictions - y) ** 2)

    @jax.jit
    def train_step(param_state, opt_state, x, y):
        # Only the gradients are needed here; the loss value is not returned.
        grads = jax.grad(loss_fn)(param_state, x, y)
        updates, opt_state = solver.update(grads, opt_state, param_state)
        param_state = optax.apply_updates(param_state, updates)
        return param_state, opt_state

    return train_step

# Separate the model into its static graph definition and its Param state,
# then initialise the optimizer state and build the jitted step function.
graphdef, param_state = nnx.split(model, nnx.Param)
opt_state = jax_optimizer.init(param_state)
train_step = make_train_step(graphdef, solver=jax_optimizer)


# Five epochs of mini-batch updates using the functional train step.
for _ in range(5):
    for _idx in range(0, X.shape[0], BATCH_SIZE):
        _x_batch, _y_batch = X[_idx:_idx+BATCH_SIZE], y[_idx:_idx+BATCH_SIZE]
        param_state, opt_state = train_step(param_state, opt_state, _x_batch, _y_batch)

    # Squeeze the (N, 1) predictions to (N,) so the residual is element-wise;
    # subtracting `y` directly would broadcast to an (N, N) matrix and report
    # a value that is not the mean squared error.
    _preds = jnp.squeeze(nnx.merge(graphdef, param_state)(X), axis=-1)
    print(f"MSE on train {jnp.mean((_preds - y)**2)}")

print("JAX training loop finished. Note. Model still unchanged.")


@nnx.jit
def train_step_nnx(model, optimizer, X, y):
    """Run one optimisation step on ``model`` and return the batch loss.

    Args:
        model: The nnx module being trained.
        optimizer: Optimizer object whose ``update(model, grads)`` is called.
        X: Batch of input features.
        y: Batch of targets, one value per row of ``X``.

    Returns:
        The mean squared error on the batch, computed before the update.
    """
    def loss_fn_nnx(model):
        # Drop the trailing output axis so predictions and targets are
        # compared element-wise rather than broadcast (B, 1) against (B,).
        predictions = jnp.squeeze(model(X), axis=-1)
        return jnp.mean((predictions - y) ** 2)

    loss, grads = nnx.value_and_grad(loss_fn_nnx)(model)
    optimizer.update(model, grads)
    return loss

# Build a fresh partitioned optimizer for the nnx loop, just to make sure
# nothing is shared with the one created earlier.
partitioned_chain = optax.partition(
    transforms=dict(
        decay_group=optax.adamw(learning_rate=0.05),
        non_decay_group=optax.adam(learning_rate=0.05),
    ),
    param_labels=param_labels_pytree,
)

# Choose one. Works with adamw, but not with the partitioned chain
#nnx_optimizer =  nnx.Optimizer(model, tx=optax.adamw(learning_rate=0.05), wrt=nnx.Param)
# NOTE(review): param_labels_pytree is built from nnx.state(model, nnx.Param);
# the reported "Custom node type mismatch" suggests its tree structure may
# differ from the trees nnx.Optimizer passes to tx.update -- confirm.
nnx_optimizer = nnx.Optimizer(model, tx=partitioned_chain, wrt=nnx.Param)

print("Starting nnx training loop")
# Five epochs of mini-batch updates using the nnx train step.
for _ in range(5):
    for _idx in range(0, X.shape[0], BATCH_SIZE):
        _x_batch, _y_batch = X[_idx:_idx+BATCH_SIZE], y[_idx:_idx+BATCH_SIZE]
        loss = train_step_nnx(model, nnx_optimizer, _x_batch, _y_batch)

    # Squeeze the (N, 1) predictions to (N,) so the residual is element-wise;
    # subtracting `y` directly would broadcast to an (N, N) matrix and report
    # a value that is not the mean squared error.
    _preds = jnp.squeeze(model(X), axis=-1)
    print(f"MSE on train {jnp.mean((_preds - y)**2)}")

print("nnx training loop finished")

Metadata

Metadata

Assignees

Labels

No labels
No labels

Type

No type
No fields configured for issues without a type.

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions