Problem Description
Currently, flax.nnx lacks native object-oriented pooling modules such as MaxPool, AvgPool, and GlobalAveragePool. Users migrating from frameworks like Keras or PyTorch, or even transitioning from flax.linen, are forced to mix functional API calls (for example max_pool or avg_pool) into the otherwise object-oriented NNX structure. This creates an inconsistent developer experience and requires manual boilerplate for common operations like Global Average Pooling, which has to be written by hand as a jnp.mean over the spatial axes.
Proposed Feature
Introduce a dedicated pooling module suite within nnx that mirrors the ergonomic design of other NNX layers. This includes:
- Subsampling Modules: MaxPool, AvgPool, and MinPool.
- Global Pooling: A dedicated GlobalAveragePool module to replace manual jnp.mean calls.
Implementation Status
I have already implemented these modules and exposed them in the nnx namespace.
See Pull Request: #5201
Justification & Benefits
- API Consistency: Maintains the OO flow of NNX without dropping back to the functional pooling helpers from flax.linen.
- Framework Parity: Lowers the barrier for users migrating from Keras/PyTorch.
- Readability: Simplifies model definitions, especially for standard CNN architectures.