metrics does not work well with vmap #5483

@narang99

Description

nnx.metrics.Average does not work well with nnx.vmap.

Consider code like this:

import optax
from flax import nnx

@nnx.vmap(in_axes=(0, None, None, None, None, None), out_axes=0)
def parallel_init_models(
    seeds, input_dim, n_components, scaled_hyperparameters, init_strategy, lr
):
    return init_single_model(
        seeds, input_dim, n_components, scaled_hyperparameters, init_strategy, lr
    )


def init_single_model(
    s, input_dim, n_components, scaled_hyperparameters, init_strategy, lr
):
    rngs = nnx.Rngs(s)
    model = Autoencoder(input_dim, n_components, rngs)
    # Initialize model parameters using shared hyperparameters
    optimizer = nnx.Optimizer(model, optax.adam(lr), wrt=nnx.Param)
    metrics = nnx.MultiMetric(
        loss=nnx.metrics.Average("loss"),
        recon_loss=nnx.metrics.Average("recon_loss"),
        weight_loss=nnx.metrics.Average("weight_loss"),
        codes_loss=nnx.metrics.Average("codes_loss"),
        mse=nnx.metrics.Average("mse"),
    )
    return model, optimizer, metrics

parallel_init_models is used to create multiple models, so that I can train several seeds on the same input.
Right after initialization, the shapes of the metrics, models and optimizers are all as expected.

Calling metrics.reset() breaks the metrics, though.
Before reset:

MultiMetric( # MetricState: 30 (120 B)
  codes_loss=Average( # MetricState: 6 (24 B)
    argname='codes_loss',
    count=MetricState( # 3 (12 B)
      value=Array(shape=(3,), dtype=dtype('int32'))
    ),
    total=MetricState( # 3 (12 B)
      value=Array(shape=(3,), dtype=dtype('float32'))
    )
  )
...

After calling metrics.reset(), it becomes:

MultiMetric( # MetricState: 10 (40 B)
  codes_loss=Average( # MetricState: 2 (8 B)
    argname='codes_loss',
    count=MetricState( # 1 (4 B)
      value=Array(0, dtype=int32)
    ),
    total=MetricState( # 1 (4 B)
      value=Array(0., dtype=float32)
    )
  ),
  ...

Note the shape difference: count and total go from shape (3,) to scalars. reset() also seems to assign a fresh scalar zero array instead of zeroing the existing value: https://github.com/google/flax/blob/main/flax/nnx/training/metrics.py#L97

Apologies if this is not a bug and users are expected to maintain a separate metrics object for each model.
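
To make the shape problem concrete, here is a minimal sketch in plain JAX (no Flax involved). The dict-based state and the helpers init_state, update, scalar_reset and shape_preserving_reset are hypothetical stand-ins written for this illustration, not the actual nnx.metrics.Average implementation. It only shows why resetting to a rank-0 zero breaks a later vmap over axis 0, and that zeroing with zeros_like keeps the leading axis.

```python
import jax
import jax.numpy as jnp


def init_state(n_models):
    # Hypothetical stand-in for the metric state: one (total, count) per model.
    return {
        "total": jnp.zeros((n_models,), jnp.float32),
        "count": jnp.zeros((n_models,), jnp.int32),
    }


def update(state, loss):
    # Per-model update; vmap supplies one slice of the state and one loss.
    return {"total": state["total"] + loss, "count": state["count"] + 1}


batched_update = jax.vmap(update, in_axes=(0, 0))


def scalar_reset(state):
    # Assumed to mirror the behaviour reported above: rank-0 zeros replace the state.
    return {"total": jnp.array(0, jnp.float32), "count": jnp.array(0, jnp.int32)}


def shape_preserving_reset(state):
    # Zero every leaf while keeping its shape and dtype.
    return jax.tree_util.tree_map(jnp.zeros_like, state)


state = batched_update(init_state(3), jnp.array([1.0, 2.0, 3.0]))
good = shape_preserving_reset(state)
bad = scalar_reset(state)

# The shape-preserving state can still be mapped over axis 0.
after_good = batched_update(good, jnp.ones(3))

# The scalar state cannot: vmap rejects rank-0 inputs for in_axes=0.
try:
    batched_update(bad, jnp.ones(3))
    scalar_reset_failed = False
except ValueError:
    scalar_reset_failed = True
```

In this sketch, the scalar variant fails with the same kind of ValueError as the one in the logs below, while the zeros_like variant keeps the (3,) leading axis.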

System information

Tested in a colab notebook

  • OS info
PRETTY_NAME="Ubuntu 22.04.5 LTS"
NAME="Ubuntu"
VERSION_ID="22.04"
VERSION="22.04.5 LTS (Jammy Jellyfish)"
VERSION_CODENAME=jammy
ID=ubuntu
ID_LIKE=debian
HOME_URL="https://www.ubuntu.com/"
SUPPORT_URL="https://help.ubuntu.com/"
BUG_REPORT_URL="https://bugs.launchpad.net/ubuntu/"
PRIVACY_POLICY_URL="https://www.ubuntu.com/legal/terms-and-policies/privacy-policy"
UBUNTU_CODENAME=jammy
  • Flax, jax, jaxlib versions (obtain with pip show flax jax jaxlib):
Name: flax
Version: 0.12.0
Summary: Flax: A neural network library for JAX designed for flexibility
Home-page: https://github.com/google/flax
Author: 
Author-email: Flax team <flax-dev@google.com>
License: 
Location: /usr/local/lib/python3.12/dist-packages
Requires: jax, msgpack, numpy, optax, orbax-checkpoint, PyYAML, rich, tensorstore, treescope, typing_extensions
Required-by: dopamine_rl, pt-to-api
---
Name: jax
Version: 0.7.2
Summary: Differentiate, compile, and transform Numpy code.
Home-page: https://github.com/jax-ml/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /usr/local/lib/python3.12/dist-packages
Requires: jaxlib, ml_dtypes, numpy, opt_einsum, scipy
Required-by: dopamine_rl, flax, optax, orbax-checkpoint, pt-to-api
---
Name: jaxlib
Version: 0.7.2
Summary: XLA library for JAX
Home-page: https://github.com/jax-ml/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /usr/local/lib/python3.12/dist-packages
Requires: ml_dtypes, numpy, scipy
Required-by: dopamine_rl, jax, optax
  • Python version: Python 3.12.13
  • GPU/TPU model and memory:
    nvidia-smi
Thu Jun  4 18:46:48 2026       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.82.07              Driver Version: 580.82.07      CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   53C    P0             30W /   70W |   11317MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A            1132      C   /usr/bin/python3                      11314MiB |
+-----------------------------------------------------------------------------------------+

What you expected to happen:

Metrics should work with vmap. If they are not meant to, some documentation saying so would be helpful.

Logs, error messages, etc:

ValueError: vmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())
