nnx.metrics.Average does not work well with nnx.vmap.
Consider code like this:
from flax import nnx
import optax

# Autoencoder is my own nnx.Module; its definition is omitted here.

@nnx.vmap(in_axes=(0, None, None, None, None, None), out_axes=0)
def parallel_init_models(
    seeds, input_dim, n_components, scaled_hyperparameters, init_strategy, lr
):
    # Only the seeds are mapped; every other argument is shared across models.
    return init_single_model(
        seeds, input_dim, n_components, scaled_hyperparameters, init_strategy, lr
    )

def init_single_model(
    s, input_dim, n_components, scaled_hyperparameters, init_strategy, lr
):
    rngs = nnx.Rngs(s)
    model = Autoencoder(input_dim, n_components, rngs)
    # Initialize model parameters using shared hyperparameters
    optimizer = nnx.Optimizer(model, optax.adam(lr), wrt=nnx.Param)
    metrics = nnx.MultiMetric(
        loss=nnx.metrics.Average("loss"),
        recon_loss=nnx.metrics.Average("recon_loss"),
        weight_loss=nnx.metrics.Average("weight_loss"),
        codes_loss=nnx.metrics.Average("codes_loss"),
        mse=nnx.metrics.Average("mse"),
    )
    return model, optimizer, metrics
parallel_init_models is used to create multiple models, so that I can train several seeds on the same input.
Right after initialization, the shapes of the metrics, models and optimizers all look correct: each state carries a leading axis of size 3, one entry per seed.
Calling metrics.reset() breaks the metrics, though.
Before reset:
MultiMetric( # MetricState: 30 (120 B)
  codes_loss=Average( # MetricState: 6 (24 B)
    argname='codes_loss',
    count=MetricState( # 3 (12 B)
      value=Array(shape=(3,), dtype=dtype('int32'))
    ),
    total=MetricState( # 3 (12 B)
      value=Array(shape=(3,), dtype=dtype('float32'))
    )
  )
  ...
After calling metrics.reset(), it becomes:
MultiMetric( # MetricState: 10 (40 B)
  codes_loss=Average( # MetricState: 2 (8 B)
    argname='codes_loss',
    count=MetricState( # 1 (4 B)
      value=Array(0, dtype=int32)
    ),
    total=MetricState( # 1 (4 B)
      value=Array(0., dtype=float32)
    )
  ),
  ...
Note the shape difference: count and total go from shape (3,) to scalars. The reset code appears to assign a fixed scalar zero array regardless of the current shape: https://github.com/google/flax/blob/main/flax/nnx/training/metrics.py#L97
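To make the shape difference concrete, here is a minimal pure-JAX sketch. It does not call Flax at all: the state dict and the helpers reset_to_scalar and reset_like are hypothetical names made up for illustration, and the sketch assumes the vmapped metric state is simply a pair of arrays with a leading axis of size 3. It contrasts a reset that writes a scalar zero with one that zeroes each array in place of its current shape.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the vmapped Average state: one entry per model.
state = {
    "total": jnp.array([1.5, 2.0, 0.5], dtype=jnp.float32),
    "count": jnp.array([3, 4, 1], dtype=jnp.int32),
}

def reset_to_scalar(state):
    # Mirrors the behaviour reported above: every leaf becomes a rank-0 zero,
    # so the leading (per-model) axis is lost.
    return {
        "total": jnp.array(0, dtype=jnp.float32),
        "count": jnp.array(0, dtype=jnp.int32),
    }

def reset_like(state):
    # Zero each leaf while keeping its existing shape and dtype.
    return jax.tree_util.tree_map(jnp.zeros_like, state)

scalar_state = reset_to_scalar(state)
batched_state = reset_like(state)

print(scalar_state["total"].shape)   # ()
print(batched_state["total"].shape)  # (3,)
```

A shape-preserving reset along these lines might be one way to keep the state compatible with vmap; whether that fits the library's design is a question for the maintainers.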
Apologies if this is not a bug and users are instead expected to maintain a separate metrics object for each model.
System information
Tested in a colab notebook
PRETTY_NAME="Ubuntu 22.04.5 LTS"
NAME="Ubuntu"
VERSION_ID="22.04"
VERSION="22.04.5 LTS (Jammy Jellyfish)"
VERSION_CODENAME=jammy
ID=ubuntu
ID_LIKE=debian
HOME_URL="https://www.ubuntu.com/"
SUPPORT_URL="https://help.ubuntu.com/"
BUG_REPORT_URL="https://bugs.launchpad.net/ubuntu/"
PRIVACY_POLICY_URL="https://www.ubuntu.com/legal/terms-and-policies/privacy-policy"
UBUNTU_CODENAME=jammy
- Flax, jax, jaxlib versions (obtain with pip show flax jax jaxlib):
Name: flax
Version: 0.12.0
Summary: Flax: A neural network library for JAX designed for flexibility
Home-page: https://github.com/google/flax
Author:
Author-email: Flax team <flax-dev@google.com>
License:
Location: /usr/local/lib/python3.12/dist-packages
Requires: jax, msgpack, numpy, optax, orbax-checkpoint, PyYAML, rich, tensorstore, treescope, typing_extensions
Required-by: dopamine_rl, pt-to-api
---
Name: jax
Version: 0.7.2
Summary: Differentiate, compile, and transform Numpy code.
Home-page: https://github.com/jax-ml/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /usr/local/lib/python3.12/dist-packages
Requires: jaxlib, ml_dtypes, numpy, opt_einsum, scipy
Required-by: dopamine_rl, flax, optax, orbax-checkpoint, pt-to-api
---
Name: jaxlib
Version: 0.7.2
Summary: XLA library for JAX
Home-page: https://github.com/jax-ml/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /usr/local/lib/python3.12/dist-packages
Requires: ml_dtypes, numpy, scipy
Required-by: dopamine_rl, jax, optax
- Python version:
Python 3.12.13
- GPU/TPU model and memory:
nvidia-smi
Thu Jun 4 18:46:48 2026
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.82.07              Driver Version: 580.82.07      CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   53C    P0             30W /   70W |   11317MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A            1132      C   /usr/bin/python3                      11314MiB |
+-----------------------------------------------------------------------------------------+
What you expected to happen:
Metrics should keep working with vmap after a reset. If that is not supported, some documentation saying so would be helpful.
Logs, error messages, etc:
ValueError: vmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())
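The same kind of error can be reproduced with plain jax.vmap, without Flax. The sketch below is only an illustration under assumptions: the update function is a hypothetical single-model running-average step, not the library's implementation. Mapping it over axis 0 fails when the state is rank 0 (as it is after the reset) and works when the state keeps its per-model axis.

```python
import jax
import jax.numpy as jnp

def update(total, count, value):
    # Hypothetical running-average update, written for a single model.
    return total + value, count + 1

# Map every argument along its leading (per-model) axis.
batched_update = jax.vmap(update, in_axes=(0, 0, 0))

values = jnp.array([1.0, 2.0, 3.0])  # one loss value per model

# State as it looks after a scalar reset: rank 0, so there is no axis 0.
scalar_total = jnp.array(0, dtype=jnp.float32)
scalar_count = jnp.array(0, dtype=jnp.int32)
try:
    batched_update(scalar_total, scalar_count, values)
    error_message = None
except ValueError as err:
    error_message = str(err)
print(error_message)

# State that keeps the per-model axis: the mapped update goes through.
total, count = batched_update(
    jnp.zeros(3, dtype=jnp.float32), jnp.zeros(3, dtype=jnp.int32), values
)
print(total.shape, count.shape)
```

This suggests the failure comes from the state losing its leading axis rather than from the update step itself.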