
Adding Transformer encoder and decoder layers to the Flax source, as in PyTorch #5176

@coder0143

Description

The PyTorch source includes ready-to-use wrapper modules for Transformer architectures.

Source: https://github.com/pytorch/pytorch/blob/v2.9.1/torch/nn/modules/transformer.py#L966

I would like to add a similar implementation to Flax for ease of use (UX). I will create a new file, flax/flax/nnx/nn/transformer.py, containing the following modules:

  • TransformerEncoderLayer
  • TransformerEncoder
  • TransformerDecoderLayer
  • TransformerDecoder
  • Transformer

I will keep the modules consistent with nnx.Linear and nnx.MultiHeadAttention and update the docs as well. If needed, and depending on review, I can also implement separate custom attention modules, such as MHSA (full multi-head self-attention) and GQA (grouped-query attention with a KV cache). Can I open a PR? @cgarciae @vfdev-5