[PyTorch] Avoid autograd's gradient accumulation in grouped MLP if possible#2871
ksivaman wants to merge 1 commit into NVIDIA:main
Conversation
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
/te-ci pytorch L0
Greptile Summary

This PR returns a `None` gradient when possible in the grouped MLP backward, avoiding autograd's default gradient accumulation.

Confidence Score: 3/5. Not safe to merge as-is: one P1 correctness bug in `transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py` (specifically lines 144–156).
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["_compute_grad_params called"] --> B{weight_requires_grad?}
    B -- No --> Z["w_list = [None]"]
    B -- Yes --> C{single_grouped_weight?}
    C -- Yes --> D{_accumulate_into_main_grad?}
    D -- No --> E["Create fresh GroupedTensor"]
    E --> F["GEMM → grouped_wgrad"]
    F --> G["packed_wgrad = rowwise_data.view(...)"]
    G --> H["w_list = [packed_wgrad]"]
    D -- Yes --> I["Get main_grad\naccumulate_into_main_grad = not overwrite_main_grad"]
    I --> J{accumulate_into_main_grad?}
    J -- Yes --> K["grouped_wgrad from main_grad\nGEMM accumulate=True"]
    K --> L["packed_wgrad = rowwise_data.view(...)"]
    L --> M{hasattr weight_param\ngrad_added_to_main_grad?}
    M -- Yes --> N["grad_added_to_main_grad=True\npacked_wgrad = dummy_wgrad\nw_list = [dummy]"]
    M -- No --> O["⚠️ packed_wgrad = main_grad view\nw_list = [main_grad view]\n← BUG: should be None"]
    J -- No --> P["Create fresh tensor\nGEMM accumulate=False"]
    P --> H
    C -- No --> Q{_accumulate_into_main_grad?}
    Q -- Yes --> R["w_list[idx]=main_grad[idx]"]
    R --> S["GEMM into main_grad\nw_list=[None]*n\nOverride with dummy if attr present"]
    Q -- No --> T["w_list[idx]=fresh tensors\nReturn w_list"]
```
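For readers who skip the flowchart above, the single-weight branch reduces to a small decision function. This is a hypothetical, simplified reconstruction mirroring the flowchart, with GEMM calls and tensor plumbing replaced by placeholder strings; the names do not come from the real TransformerEngine internals:

```python
# Toy sketch of the single-weight branch of _compute_grad_params as drawn in
# the flowchart. `weight` is modeled as a dict; "fresh_wgrad", "dummy_wgrad",
# and "main_grad_view" stand in for the actual tensors.
def single_weight_wgrad(weight, accumulate_into_main_grad, overwrite_main_grad):
    if not weight.get("requires_grad", True):
        return [None]                             # B-No: no wgrad needed
    if not accumulate_into_main_grad:
        return ["fresh_wgrad"]                    # D-No: fresh GroupedTensor GEMM
    if overwrite_main_grad:
        return ["fresh_wgrad"]                    # J-No: GEMM with accumulate=False
    # J-Yes: GEMM accumulates into main_grad in place
    if "grad_added_to_main_grad" in weight:
        weight["grad_added_to_main_grad"] = True  # M-Yes: hand autograd a dummy
        return ["dummy_wgrad"]
    return ["main_grad_view"]                     # M-No: the flagged bug (should be [None])

w = {"grad_added_to_main_grad": False}
assert single_weight_wgrad(w, True, False) == ["dummy_wgrad"]
assert single_weight_wgrad({}, True, False) == ["main_grad_view"]  # P1 bug path
```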
```python
if delay_wgrad or accumulate_into_main_grad:
    w_list = [None] * num_groups
```
Is this correct? I recall we needed to return a dummy wgrad because that affects how grad callbacks are applied, and we got race conditions if we returned None.
I'm also wary about only making this change in this particular fused op. If we are changing our contract with Mcore, we should apply it consistently in the linear modules and linear ops.
Yes, that's correct, and it doesn't change here. In the very next line we have the same logic as before: https://github.com/ksivaman/TransformerEngine-1/blob/f7a28865b1e9e263ae92c6dc97604280093ebc8f/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py/#L160-L169
Doesn't that logic make the `or accumulate_into_main_grad` here redundant?
We still want to return a `None` gradient unless Megatron explicitly sets the `grad_added_to_main_grad` attr on the weight, in which case we supply a dummy grad.
It's a bit convoluted, but it seems that the Linear module also returns `None` if delayed wgrad is enabled and `fuse_wgrad_accumulation=False`:
TransformerEngine/transformer_engine/pytorch/module/linear.py
Lines 939 to 950 in dc92b39
The GroupedLinear op has similar behavior with the same combination:
TransformerEngine/transformer_engine/pytorch/ops/basic/grouped_linear.py
Lines 1039 to 1051 in dc92b39
The implementations are inconsistent when both delayed wgrad and main_grad accumulation are enabled: the Linear module returns a dummy tensor, the GroupedLinear op returns `None`, and this fused op returns a dummy tensor.
Ok, I think this logic matches the Linear module. There is an edge case where the unfused GroupedLinear op is inconsistent.
| Config | Linear module | GroupedLinear op | Fused grouped MLP |
|---|---|---|---|
| Vanilla | Return wgrad | Return wgrad | Return wgrad |
| Delayed wgrad | Return None | Return None | Return None |
| Accumulate into main_grad, without grad_added_to_main_grad | Return None | Return None | Return None |
| Accumulate into main_grad, with grad_added_to_main_grad | Return dummy | Return dummy | Return dummy |
| Delayed wgrad, accumulate into main_grad, without grad_added_to_main_grad | Return None | Return None | Return None |
| Delayed wgrad, accumulate into main_grad, with grad_added_to_main_grad | Return dummy | Return None | Return dummy |
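The main_grad contract summarized in the table can be sketched as a standalone helper. This is a hypothetical simplification (the `compute_wgrad` function and its plumbing are illustrative, not the real TE/Megatron API): when the framework attaches `main_grad` and `grad_added_to_main_grad` to a parameter, the wgrad is accumulated directly into `main_grad` and autograd receives a cheap dummy (or `None`) instead of the real gradient.

```python
import torch

# Hypothetical sketch of the wgrad contract: accumulate into main_grad when
# present, and return either a dummy tensor or None depending on whether the
# grad_added_to_main_grad attribute exists on the weight.
def compute_wgrad(x, grad_out, weight):
    wgrad = grad_out.t() @ x  # the real weight gradient
    if getattr(weight, "main_grad", None) is not None:
        weight.main_grad += wgrad  # fused accumulation into the fp32 buffer
        if hasattr(weight, "grad_added_to_main_grad"):
            weight.grad_added_to_main_grad = True
            # dummy keeps grad callbacks firing without an extra copy/add
            return torch.empty_like(weight, requires_grad=False)
        return None  # contract: skip autograd's accumulation entirely
    return wgrad     # vanilla path: let autograd handle accumulation

weight = torch.nn.Parameter(torch.randn(4, 3))
weight.main_grad = torch.zeros(4, 3)
weight.grad_added_to_main_grad = False

x, grad_out = torch.randn(2, 3), torch.randn(2, 4)
out = compute_wgrad(x, grad_out, weight)
assert weight.grad_added_to_main_grad and out.shape == weight.shape
```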
Description

If the `.grad` field of a param is not `None` (many cases, such as externally maintained grads for various optimizations, multiple microbatches, etc.), PyTorch by default accumulates gradients, resulting in expensive copies and adds.

Type of change

Changes

- Return a `None` gradient when possible in the grouped MLP backward.

Checklist:
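The accumulation cost described above can be demonstrated with a toy custom autograd Function (a minimal hypothetical sketch; `ToyLinear` and its `skip_wgrad` flag are illustrative, not TE's implementation). Returning `None` for an input from `backward` tells autograd to leave that input's `.grad` untouched, skipping the copy/add entirely:

```python
import torch

# A toy autograd Function that can either return a real wgrad or None.
# Returning None means autograd never touches weight.grad, so the
# accumulate (copy/add) step is skipped entirely.
class ToyLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, skip_wgrad):
        ctx.save_for_backward(x, weight)
        ctx.skip_wgrad = skip_wgrad
        return x @ weight.t()

    @staticmethod
    def backward(ctx, grad_out):
        x, weight = ctx.saved_tensors
        dgrad = grad_out @ weight
        wgrad = None if ctx.skip_wgrad else grad_out.t() @ x
        return dgrad, wgrad, None  # None grad for the bool flag as well

weight = torch.randn(4, 3, requires_grad=True)
x = torch.randn(2, 3)

# Normal path: autograd accumulates into weight.grad.
ToyLinear.apply(x, weight, False).sum().backward()
assert weight.grad is not None

# None path: weight.grad stays untouched (here it stays None).
weight.grad = None
ToyLinear.apply(x, weight, True).sum().backward()
assert weight.grad is None
```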