This repository was archived by the owner on Nov 19, 2025. It is now read-only.

Memory inefficiency when loading attention_mask, causing dataloader OOM with long context #488

@shensimeteor

Description

Is your feature request related to a problem? Please describe.

GPTSFTChatDataset::collate_fn appears to return a very large attention_mask tensor (code:
https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_chat_dataset.py#L380-L381).

The tensor has shape (n, 1, seq, seq), where n is global_batch_size divided by data_parallel_size (i.e., micro_batch_size * gradient_accumulation_steps, if my understanding is correct) and seq is the max context length in the batch. The tensor therefore becomes very large when the context length is long and gradient_accumulation_steps is large.

Much of this memory is wasted: all n subtensors of shape (1, seq, seq) are identical, so the tensor could be reduced to 1/n of its current size. That would help a lot with long-context training.
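To put rough numbers on this, here is a back-of-the-envelope calculation. The values of n and seq are hypothetical, and it assumes the mask is stored at 1 byte per element (as a boolean tensor would be); a wider dtype scales both figures up by the same factor.

```python
def attention_mask_bytes(n, seq, bytes_per_element=1):
    """Bytes needed for a dense attention mask of shape (n, 1, seq, seq)."""
    return n * 1 * seq * seq * bytes_per_element


# Hypothetical settings, not taken from any real config.
n = 16        # micro_batch_size * gradient_accumulation_steps
seq = 32768   # max context length in the batch

full = attention_mask_bytes(n, seq)    # what collate_fn returns today
shared = attention_mask_bytes(1, seq)  # a single shared (1, seq, seq) mask

print(f"(n, 1, seq, seq): {full / 2**30:.1f} GiB")
print(f"(1, seq, seq):    {shared / 2**30:.1f} GiB")
print(f"reduction factor: {full // shared}x")
```

With these numbers the full tensor takes 16 GiB per collated batch while a single shared mask takes 1 GiB, and the gap widens linearly with gradient_accumulation_steps.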

Describe the solution you'd like

I'm not very familiar with the codebase, so my idea could be wrong here. In GPTSFTChatDataset::collate_fn, we could store the attention_mask as a single tensor of shape (1, seq, seq). We would then need to modify or extend the get_iterator_k_split method used in GPTSFTModel (https://github.com/NVIDIA/NeMo-Aligner/blob/main/nemo_aligner/models/nlp/gpt/gpt_sft_model.py#L86). In the modified or extended version, instead of splitting attention_mask, we would construct the actual (micro_batch, 1, seq, seq) attention_mask for each micro-batch.
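A minimal sketch of the idea, not the actual NeMo or NeMo-Aligner code. It uses NumPy so it runs without a GPU stack; in PyTorch the analogous zero-copy operation would be `Tensor.expand`. The function names `collate_shared_mask` and `iter_k_split_shared_mask` are hypothetical, and the sketch assumes a plain causal mask in which True marks a masked position.

```python
import numpy as np


def collate_shared_mask(seq):
    """Build one causal mask of shape (1, seq, seq) shared by every sample.

    Assumption: True marks a position that must NOT be attended to.
    """
    causal = np.tril(np.ones((seq, seq), dtype=bool))
    return (~causal)[None, :, :]


def iter_k_split_shared_mask(batch, num_microbatches, mask_key="attention_mask"):
    """Hypothetical stand-in for a k-split iterator.

    Per-sample fields are sliced along dim 0; the shared mask is not split
    but broadcast to (micro_batch, 1, seq, seq) for each micro-batch.
    """
    shared = batch[mask_key]  # shape (1, seq, seq)
    per_sample = {k: v for k, v in batch.items() if k != mask_key}
    n = len(next(iter(per_sample.values())))
    if n % num_microbatches != 0:
        raise ValueError("batch size must be divisible by num_microbatches")
    micro = n // num_microbatches
    for i in range(num_microbatches):
        out = {k: v[i * micro:(i + 1) * micro] for k, v in per_sample.items()}
        # broadcast_to returns a read-only view, so no extra memory is allocated.
        out[mask_key] = np.broadcast_to(shared[None], (micro, 1) + shared.shape[1:])
        yield out


# Usage with toy sizes: 8 samples, sequence length 16, 4 micro-batches.
batch = {
    "tokens": np.zeros((8, 16), dtype=np.int64),
    "attention_mask": collate_shared_mask(16),
}
micro_batches = list(iter_k_split_shared_mask(batch, num_microbatches=4))
print(micro_batches[0]["tokens"].shape)          # (2, 16)
print(micro_batches[0]["attention_mask"].shape)  # (2, 1, 16, 16)
```

Because the broadcast mask is a view, the (micro_batch, 1, seq, seq) tensor only costs real memory if downstream code copies it or makes it contiguous, so that path would need checking in the real model code.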

Describe alternatives you've considered

No

Additional context

No
