[PyTorch] Avoid autograd's gradient accumulation in grouped MLP if possible#2871

Open
ksivaman wants to merge 1 commit into NVIDIA:main from ksivaman:rm_unnecessary_wgrad_kernels

Conversation

@ksivaman
Member

Description

If the .grad field of a param is not None (which happens in many cases, such as externally maintained grads for various optimizations, multiple microbatches, etc.), PyTorch by default accumulates gradients into it, resulting in expensive copy and add kernels.
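The behavior described above can be seen in a minimal toy example (this is illustrative code, not TE's implementation): when a parameter's .grad is already populated, autograd accumulates any gradient returned from backward into it with an extra add, whereas returning None skips that work entirely.

```python
import torch

class Scale(torch.autograd.Function):
    """Toy op computing x * w, used to show autograd's grad accumulation."""

    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(x, w)
        return x * w

    @staticmethod
    def backward(ctx, grad_out):
        x, w = ctx.saved_tensors
        # Returning a real tensor here triggers `w.grad += grad_out * x`
        # when w.grad is already non-None; returning None would skip it.
        return grad_out * w, grad_out * x

w = torch.ones(4, requires_grad=True)
w.grad = torch.full((4,), 10.0)  # externally maintained grad (e.g. microbatching)
x = torch.arange(4.0)
Scale.apply(x, w).sum().backward()
print(w.grad)  # tensor([10., 11., 12., 13.]): new wgrad added onto the existing 10s
```

If the wgrad has already been written into an externally maintained buffer (as with main_grad fusion), this accumulation is pure overhead, which is what the PR avoids by returning None.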

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring
  • Performance

Changes

  • Return None gradient when possible in grouped MLP backward.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@ksivaman ksivaman requested a review from timmoon10 April 13, 2026 20:35
@ksivaman
Member Author

/te-ci pytorch L0

@greptile-apps
Contributor

greptile-apps bot commented Apr 13, 2026

Greptile Summary

This PR introduces a _compute_grad_params helper function to the fused CuDNN BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8 backward pass, adding the accumulate_into_main_grad / delay_wgrad optimizations that already existed in the non-fused GroupedLinear.fuser_backward. The intent is to return None to autograd when gradients are accumulated directly into main_grad, avoiding PyTorch's expensive in-place grad += wgrad operation.

  • In the single_grouped_weight=True + accumulate_into_main_grad=True path, when the weight parameter lacks the grad_added_to_main_grad attribute, packed_wgrad is set to a live view of main_grad and returned to autograd rather than None. This is the opposite of what the non-fused GroupedLinear.fuser_backward does in the same branch, and can cause double-accumulation of the weight gradient across microbatches (weight.grad += main_grad when weight.grad is already non-None).

Confidence Score: 3/5

Not safe to merge as-is — the single_grouped_weight=True + accumulate_into_main_grad=True path can return a live main_grad view to autograd instead of None, causing double gradient accumulation across microbatches.

One P1 correctness bug: the else: packed_wgrad = None branch is missing in _compute_grad_params for the single_grouped_weight case, diverging from the established non-fused GroupedLinear.fuser_backward behavior and introducing the exact gradient-accumulation problem the PR is trying to solve.

transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py — specifically lines 144–156 in _compute_grad_params.

Important Files Changed

Filename Overview
transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py Adds _compute_grad_params helper for the fused CuDNN grouped MLP backward; has a divergence from the non-fused path: when single_grouped_weight=True, accumulate_into_main_grad=True, and grad_added_to_main_grad is absent on the param, a live view of main_grad is returned to autograd instead of None, risking double gradient accumulation.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["_compute_grad_params called"] --> B{weight_requires_grad?}
    B -- No --> Z["w_list = [None]"]
    B -- Yes --> C{single_grouped_weight?}
    C -- Yes --> D{_accumulate_into_main_grad?}
    D -- No --> E["Create fresh GroupedTensor"]
    E --> F["GEMM → grouped_wgrad"]
    F --> G["packed_wgrad = rowwise_data.view(...)"]
    G --> H["w_list = [packed_wgrad]"]
    D -- Yes --> I["Get main_grad\naccumulate_into_main_grad = not overwrite_main_grad"]
    I --> J{accumulate_into_main_grad?}
    J -- Yes --> K["grouped_wgrad from main_grad\nGEMM accumulate=True"]
    K --> L["packed_wgrad = rowwise_data.view(...)"]
    L --> M{hasattr weight_param\ngrad_added_to_main_grad?}
    M -- Yes --> N["grad_added_to_main_grad=True\npacked_wgrad = dummy_wgrad\nw_list = [dummy]"]
    M -- No --> O["⚠️ packed_wgrad = main_grad view\nw_list = [main_grad view]\n← BUG: should be None"]
    J -- No --> P["Create fresh tensor\nGEMM accumulate=False"]
    P --> H
    C -- No --> Q{_accumulate_into_main_grad?}
    Q -- Yes --> R["w_list[idx]=main_grad[idx]"]
    R --> S["GEMM into main_grad\nw_list=[None]*n\nOverride with dummy if attr present"]
    Q -- No --> T["w_list[idx]=fresh tensors\nReturn w_list"]

Comments Outside Diff (1)

  1. transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py, line 144-156 (link)

    P1 Missing None return when accumulate_into_main_grad and no grad_added_to_main_grad

    When accumulate_into_main_grad=True but weight_param does not have the grad_added_to_main_grad attribute, packed_wgrad is set to grouped_wgrad.rowwise_data.view(...) — a live view into main_grad — and is returned unchanged to autograd. Autograd then accumulates it into weight.grad, double-counting the gradient that was already written directly into main_grad by the GEMM.

    The non-fused GroupedLinear.fuser_backward (grouped_linear.py lines 1006–1017) explicitly returns None in this branch:

    if accumulate_into_main_grad:
        …
        if hasattr(weight_param, "grad_added_to_main_grad"):
            …grad_weight = get_dummy_wgrad(…)
        else:
            grad_weight = None   # ← non-fused always returns None here

    The fused helper is missing the symmetric else branch that sets packed_wgrad = None. Adding it matches the non-fused path and avoids the spurious accumulation into .grad across microbatches.
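As a toy illustration of the contract at stake (names such as weight_param and the dummy-wgrad placeholder are taken from the review snippets, not from TE's actual API), the decision the missing branch should make can be sketched as:

```python
# Toy sketch of what backward should hand autograd for a weight once the
# wgrad GEMM has already accumulated into main_grad (names are assumptions
# based on the review discussion, not TE's real signatures).
class Param:
    """Stand-in for a weight parameter; may carry grad_added_to_main_grad."""

DUMMY_WGRAD = object()  # placeholder for the real get_dummy_wgrad(...) tensor

def wgrad_for_autograd(weight_param):
    if hasattr(weight_param, "grad_added_to_main_grad"):
        weight_param.grad_added_to_main_grad = True
        return DUMMY_WGRAD  # Mcore contract: dummy keeps grad callbacks firing
    return None             # else branch: never return the live main_grad view

plain = Param()
assert wgrad_for_autograd(plain) is None  # no double accumulation into .grad

mcore = Param()
mcore.grad_added_to_main_grad = False
assert wgrad_for_autograd(mcore) is DUMMY_WGRAD
assert mcore.grad_added_to_main_grad is True
```

Returning the live main_grad view instead of None in the plain case is exactly the double-accumulation bug flagged above.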

Reviews (1): Last reviewed commit: "Avoid grad accumulation when not needed"

Comment on lines +158 to +159
if delay_wgrad or accumulate_into_main_grad:
w_list = [None] * num_groups
Collaborator

@timmoon10 timmoon10 Apr 13, 2026


Is this correct? I recall we needed to return a dummy wgrad because that affects how grad callbacks are applied, and we got race conditions if we returned None.

I'm also wary about only making this change in this particular fused op. If we are changing our contract with Mcore, we should apply it consistently in the linear modules and linear ops.


Collaborator

@timmoon10 timmoon10 Apr 13, 2026


Doesn't that logic make the or accumulate_into_main_grad here redundant?

Member Author


We still want to return a None gradient unless Megatron explicitly sets the grad_added_to_main_grad attr on the weight, in which case we supply a dummy grad.

Collaborator


It's a bit convoluted, but it seems that the Linear module also returns None if delayed wgrad is enabled and fuse_wgrad_accumulation=False:

if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
    if (
        wgrad_gemm_kwargs["ub"] is not None
        or wgrad_gemm_kwargs["ub_type"] is not None
        or wgrad_gemm_kwargs["extra_output"] is not None
        or wgrad_gemm_kwargs["bulk_overlap"]
    ):
        raise NotImplementedError(
            "Delayed weight grad computation is not supported "
            "with Userbuffers (tensor-parallel communication overlapping)"
        )
    ctx.wgrad_store.put([inputmat_total, grad_output], wgrad_gemm)

The GroupedLinear op has similar behavior with the same combination:

if self.single_grouped_weight:
    grad_weight = None
    if ctx.weight_requires_grad:
        if delay_wgrad:
            grad_weight = None
        else:
            grad_weight = torch.stack(grad_weights, dim=0)
    final_weight_grads = [grad_weight]
else:
    if delay_wgrad and ctx.weight_requires_grad:
        final_weight_grads = [None] * num_groups
    else:
        final_weight_grads = grad_weights

The implementations are all inconsistent when delayed wgrad is enabled and main_grad accumulation is enabled (Linear module returns dummy tensor, GroupedLinear op returns None, this fused op returns dummy tensor).

Member Author


So it is not redundant.

Collaborator


Ok, I think this logic matches the Linear module. There is an edge case where the unfused GroupedLinear op is inconsistent.

| Config | Linear module | GroupedLinear op | Fused grouped MLP |
| --- | --- | --- | --- |
| Vanilla | Return wgrad | Return wgrad | Return wgrad |
| Delayed wgrad | Return None | Return None | Return None |
| Accumulate into main_grad, without grad_added_to_main_grad | Return None | Return None | Return None |
| Accumulate into main_grad, with grad_added_to_main_grad | Return dummy | Return dummy | Return dummy |
| Delayed wgrad + accumulate into main_grad, without grad_added_to_main_grad | Return None | Return None | Return None |
| Delayed wgrad + accumulate into main_grad, with grad_added_to_main_grad | Return dummy | Return None | Return dummy |
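The intended contract summarized above (the Linear module / fused grouped MLP behavior, which the discussion treats as correct) can be encoded as a small decision function. This is a sketch with assumed flag names, not TE code:

```python
# Sketch of the desired weight-grad return contract from the comparison
# above (flag names are assumptions, not TE's actual signatures).
def weight_grad_return(delay_wgrad: bool,
                       accumulate_into_main_grad: bool,
                       grad_added_to_main_grad: bool) -> str:
    if accumulate_into_main_grad and grad_added_to_main_grad:
        return "dummy"  # Mcore expects a dummy tensor for its grad callbacks
    if delay_wgrad or accumulate_into_main_grad:
        return "None"   # wgrad handled out-of-band; skip autograd accumulation
    return "wgrad"      # vanilla path: hand autograd the real gradient

# The rows above, in order (Linear module / fused grouped MLP columns; the
# unfused GroupedLinear op diverges only on the last row, returning None):
assert weight_grad_return(False, False, False) == "wgrad"
assert weight_grad_return(True,  False, False) == "None"
assert weight_grad_return(False, True,  False) == "None"
assert weight_grad_return(False, True,  True)  == "dummy"
assert weight_grad_return(True,  True,  False) == "None"
assert weight_grad_return(True,  True,  True)  == "dummy"
```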
