
feat(nnx): Missing object-oriented pooling layers in NNX #5202

@divye-joshi

Description


Problem Description

Currently, flax.nnx lacks native object-oriented pooling modules such as MaxPool, AvgPool, and GlobalAveragePool. Users migrating from frameworks like Keras or PyTorch, or even from flax.linen, have to mix functional pooling calls into otherwise object-oriented NNX models. This makes the developer experience inconsistent and requires manual boilerplate for common operations such as global average pooling.

Proposed Feature

Introduce a dedicated suite of pooling modules in nnx that follows the same design as other NNX layers: configuration is passed to the constructor, and the module is then called on its input. This includes:

  • Subsampling Modules: MaxPool, AvgPool, and MinPool.
  • Global Pooling: A dedicated GlobalAveragePool module to replace manual jnp.mean calls over the spatial axes.

Implementation Status

I have already implemented these modules and exposed them in the nnx namespace.
See Pull Request: #5201

Justification & Benefits

  1. API Consistency: Keeps model code in NNX's object-oriented style instead of falling back to linen's functional pooling API.
  2. Framework Parity: Lowers the barrier for users migrating from Keras/PyTorch.
  3. Readability: Simplifies model definitions, especially for standard CNN architectures.

Metadata


Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
