Skip to content

Cute Dsl kernel for Wgrad for Fused MOE Layer#2869

Merged
vthumbe1503 merged 10 commits intoNVIDIA:mainfrom
vthumbe1503:users/vthumbe/wgrad_cute_dsl
Apr 13, 2026
Merged

Cute Dsl kernel for Wgrad for Fused MOE Layer#2869
vthumbe1503 merged 10 commits intoNVIDIA:mainfrom
vthumbe1503:users/vthumbe/wgrad_cute_dsl

Conversation

@vthumbe1503
Copy link
Copy Markdown
Collaborator

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

vthumbe1503 and others added 3 commits April 13, 2026 02:46
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503 vthumbe1503 changed the title Users/vthumbe/wgrad cute dsl Cute Dsl for Wgrad for Fused MOE Layer Apr 13, 2026
@vthumbe1503 vthumbe1503 changed the title Cute Dsl for Wgrad for Fused MOE Layer Cute Dsl kernel for Wgrad for Fused MOE Layer Apr 13, 2026
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps bot commented Apr 13, 2026

Greptile Summary

This PR routes the weight-gradient (wgrad) GEMM for the fused MXFP8 MOE backward pass through a new CuTe DSL kernel (grouped_gemm_wgrad_wrapper_sm100) on SM100+ hardware, with an automatic cuBLAS fallback when the kernel is unavailable. The core additions are _cudnn_compute_wgrad (handles dense / discrete per-expert modes) and a refactored _compute_grad_params that selects between the CuTe kernel and the legacy general_grouped_gemm_for_grouped_tensor path.

  • P1 — unhandled ImportError on wgrad kernel import: grouped_gemm_wgrad_kernel() contains no try/except, so if grouped_gemm_wgrad_wrapper_sm100 is absent despite the >= 1.23.0 version check, the backward pass crashes instead of falling back to cuBLAS (see inline comment).

Confidence Score: 4/5

Safe to merge after addressing the missing ImportError guard in grouped_gemm_wgrad_kernel().

One P1 finding: the wgrad kernel method has no try/except, so a symbol-not-found condition would crash the backward pass with no cuBLAS fallback. The fix is a two-line try/except wrapper. All other findings are P2 style concerns.

transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py — grouped_gemm_wgrad_kernel() classmethod needs ImportError handling.

Important Files Changed

Filename Overview
transformer_engine/pytorch/ops/_common.py Adds _nvidia_cudnn_frontend_supports_wgrad() version-gate for the new wgrad kernel; its body is identical to the existing _scaled_clamped_qgeglu check (same >= 1.23.0 threshold), which is potentially fragile if the wgrad symbol was added at a different version.
transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py Introduces _cudnn_compute_wgrad and refactors _compute_grad_params to dispatch wgrad to the CuTe DSL kernel or fall back to cuBLAS. grouped_gemm_wgrad_kernel() lacks try/except around the import, so a symbol-not-found ImportError propagates to the backward pass with no fallback.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[fuser_backward] --> B[_compute_grad_params FC2 wgrad]
    A --> C[_compute_grad_params FC1 wgrad]
    B --> D{cudnn_wgrad_kernel_fn not None?}
    C --> D2{cudnn_wgrad_kernel_fn not None?}
    D -->|Yes| E[functools.partial _cudnn_compute_wgrad]
    D -->|No| F[functools.partial cuBLAS fallback]
    D2 -->|Yes| E2[functools.partial _cudnn_compute_wgrad]
    D2 -->|No| F2[functools.partial cuBLAS fallback]
    E --> G{delay_wgrad?}
    G -->|Yes| H[wgrad_store.put deferred]
    G -->|No| I[_cudnn_compute_wgrad]
    I --> J{single_grouped_weight?}
    J -->|Yes| K[output_mode=dense]
    J -->|No| L[output_mode=discrete]
    subgraph version_gate[grouped_gemm_wgrad_kernel]
        M{supports_wgrad?} -->|False| N[return None]
        M -->|True| O[import grouped_gemm_wgrad_wrapper_sm100]
    end
Loading

Reviews (4): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…/TransformerEngine into users/vthumbe/wgrad_cute_dsl
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
timmoon10
timmoon10 previously approved these changes Apr 13, 2026
Copy link
Copy Markdown
Collaborator

@timmoon10 timmoon10 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503
Copy link
Copy Markdown
Collaborator Author

/te-ci pytorch

@vthumbe1503 vthumbe1503 merged commit 72328b3 into NVIDIA:main Apr 13, 2026
21 of 24 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants