Cute Dsl kernel for Wgrad for Fused MOE Layer #2869
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
Greptile Summary
This PR routes the weight-gradient (wgrad) GEMM for the fused MXFP8 MOE backward pass through a new CuTe DSL kernel (
Confidence Score: 4/5
Safe to merge after addressing the missing ImportError guard in
One P1 finding: the wgrad kernel method has no try/except, so a symbol-not-found condition would crash the backward pass with no cuBLAS fallback. The fix is a two-line try/except wrapper. All other findings are P2 style concerns.
transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py
Important Files Changed
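The try/except wrapper the review asks for could look roughly like the sketch below. This is not the code from the PR: `load_optional_kernel` and its arguments are hypothetical names chosen for illustration, and the only assumption carried over from the review is that a failed import should yield `None` so the caller can take the cuBLAS path.

```python
import importlib
from typing import Callable, Optional


def load_optional_kernel(module_name: str, symbol: str) -> Optional[Callable]:
    """Return `symbol` from `module_name`, or None if it cannot be loaded.

    Hypothetical helper: the real kernel accessor in the PR may be
    structured differently.
    """
    try:
        module = importlib.import_module(module_name)
        return getattr(module, symbol)
    except (ImportError, AttributeError):
        # Missing package or missing symbol: return None so the caller
        # can fall back to another implementation instead of crashing.
        return None
```

A caller would then treat a `None` result the same way as an unsupported-version result and select the fallback GEMM.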
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[fuser_backward] --> B[_compute_grad_params FC2 wgrad]
A --> C[_compute_grad_params FC1 wgrad]
B --> D{cudnn_wgrad_kernel_fn not None?}
C --> D2{cudnn_wgrad_kernel_fn not None?}
D -->|Yes| E[functools.partial _cudnn_compute_wgrad]
D -->|No| F[functools.partial cuBLAS fallback]
D2 -->|Yes| E2[functools.partial _cudnn_compute_wgrad]
D2 -->|No| F2[functools.partial cuBLAS fallback]
E --> G{delay_wgrad?}
G -->|Yes| H[wgrad_store.put deferred]
G -->|No| I[_cudnn_compute_wgrad]
I --> J{single_grouped_weight?}
J -->|Yes| K[output_mode=dense]
J -->|No| L[output_mode=discrete]
subgraph version_gate[grouped_gemm_wgrad_kernel]
M{supports_wgrad?} -->|False| N[return None]
M -->|True| O[import grouped_gemm_wgrad_wrapper_sm100]
end
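The dispatch in the flowchart can be sketched in plain Python as follows. Every function here is a hypothetical stand-in (the real `_cudnn_compute_wgrad`, `wgrad_store`, and cuBLAS path live in the PR and are not reproduced), and applying the deferral to both branches is an assumption of this sketch, since the chart only draws it for the CuTe DSL branch.

```python
import functools
from typing import Callable, List, Optional


def cublas_wgrad(x, dy):
    # Stand-in for the cuBLAS fallback GEMM (hypothetical).
    return ("cublas", x, dy)


def cudnn_wgrad(kernel_fn: Callable, x, dy, *, single_grouped_weight: bool):
    # A single grouped weight maps to dense output, otherwise discrete.
    output_mode = "dense" if single_grouped_weight else "discrete"
    return kernel_fn(x, dy, output_mode=output_mode)


def schedule_wgrad(
    x,
    dy,
    *,
    kernel_fn: Optional[Callable],
    delay_wgrad: bool,
    single_grouped_weight: bool,
    store: List[Callable],
):
    # Pick the implementation once and bind its arguments with partial.
    if kernel_fn is not None:
        fn = functools.partial(
            cudnn_wgrad, kernel_fn, x, dy,
            single_grouped_weight=single_grouped_weight,
        )
    else:
        fn = functools.partial(cublas_wgrad, x, dy)

    if delay_wgrad:
        # Deferred: queue the bound callable to run later.
        store.append(fn)
        return None
    return fn()
```

Binding arguments with `functools.partial` keeps the immediate and deferred paths identical: the same callable is either invoked right away or queued for later.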
Reviews (4): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."
…/TransformerEngine into users/vthumbe/wgrad_cute_dsl
/te-ci pytorch