diff --git a/.github/workflows/build-test-linux-x86_64.yml b/.github/workflows/build-test-linux-x86_64.yml index ec2ed9a367..79a73a9a2c 100644 --- a/.github/workflows/build-test-linux-x86_64.yml +++ b/.github/workflows/build-test-linux-x86_64.yml @@ -153,7 +153,6 @@ jobs: python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_runtime_tests_results.xml runtime/test_000_* python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_partitioning_tests_results.xml partitioning/test_000_* python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_lowering_tests_results.xml lowering/ - python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_hlo_tests_results.xml hlo/ popd L0-py-core-tests: diff --git a/tests/py/dynamo/automatic_plugin/test_flashinfer_rmsnorm.py b/tests/py/dynamo/automatic_plugin/test_flashinfer_rmsnorm.py index 0cbd4b36ac..55dec54ff6 100644 --- a/tests/py/dynamo/automatic_plugin/test_flashinfer_rmsnorm.py +++ b/tests/py/dynamo/automatic_plugin/test_flashinfer_rmsnorm.py @@ -59,7 +59,7 @@ def forward(self, input, weight): torch.randn(weight_shape, device="cuda", dtype=data_type), ] - self.run_test(rmsnorm(), inputs, precision=dtype.f16) + self.run_test(rmsnorm(), inputs) if __name__ == "__main__": diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py index 051a08f083..6ff27de3da 100644 --- a/tests/py/dynamo/conversion/harness.py +++ b/tests/py/dynamo/conversion/harness.py @@ -272,7 +272,7 @@ def run_test_with_error(self, mod, inputs, interpreter, expect_error): cuda_inputs.append(i.cuda()) mod.eval() - interpreter.run(precision=torch.float) + interpreter.run() def assert_has_op(self, mod, ops): ops_in_mod = set() @@ -359,7 +359,6 @@ def run_test( inputs, rtol=RTOL, atol=ATOL, - precision=dtype.f32, check_dtype=True, use_dynamo_tracer=None, enable_passes=False, @@ -368,6 +367,8 @@ def run_test( immutable_weights=True, decompose_attention=False, attn_bias_is_causal=True, + require_full_compilation=False, + disable_tf32=False, ): # TODO: lan to remove this and set use_dynamo_traccer to True by default # once all the converter test files are moved to use_dynamo_tracer @@ -379,6 +380,8 @@ def run_test( immutable_weights=immutable_weights, decompose_attention=decompose_attention, attn_bias_is_causal=attn_bias_is_causal, + require_full_compilation=require_full_compilation, + disable_tf32=disable_tf32, ) mod = self.generate_graph( @@ -444,6 +447,13 @@ def run_test( compilation_settings=compilation_settings, ) + if require_full_compilation: + missing = interp.validate_conversion() + self.assertTrue( + len(missing) == 0, + f"require_full_compilation=True but the following ops don't have TRT converter: {missing}", + ) + super().run_test( mod, trt_inputs, @@ -460,7 +470,6 @@ def run_test_compare_tensor_attributes_only( inputs, expected_ops, comparators: List[Tuple[Callable, List]], - precision=torch.float, output_dtypes=None, use_dynamo_tracer=False, enable_passes=False, diff --git a/tests/py/dynamo/conversion/test_attention.py b/tests/py/dynamo/conversion/test_attention.py index 387f44065f..85204daa83 100644 --- a/tests/py/dynamo/conversion/test_attention.py +++ b/tests/py/dynamo/conversion/test_attention.py @@ -27,7 +27,6 @@ def forward(self, query, key, value): inputs, rtol=1e-2, atol=1e-2, - precision=torch.float16, enable_passes=True, decompose_attention=True, ) @@ -63,7 +62,6 @@ def forward(self, query, key, value): inputs, rtol=1e-2, atol=1e-2, - precision=torch.float16, enable_passes=True, decompose_attention=True, ) @@ -96,7 +94,6 @@ def forward(self, query, key, value): inputs, rtol=1e-2, atol=1e-2, - precision=torch.float16, enable_passes=True, decompose_attention=True, ) diff --git a/tests/py/dynamo/conversion/test_attention_aten.py b/tests/py/dynamo/conversion/test_attention_aten.py index 2225b59278..8205958339 100644 --- a/tests/py/dynamo/conversion/test_attention_aten.py +++ b/tests/py/dynamo/conversion/test_attention_aten.py @@ -139,7 +139,6 @@ def forward(self, query, key, value, attn_mask=None): inputs, rtol=1e-2, atol=1e-2, - precision=dtype, enable_passes=True, ) @@ -274,7 +273,6 @@ def forward(self, query, key, value, attn_mask=None): inputs, rtol=1e-2, atol=1e-2, - precision=dtype, enable_passes=True, ) @@ -513,7 +511,6 @@ def forward(self, query, key, value, attn_bias=None): inputs, rtol=1e-2, atol=1e-2, - precision=dtype, enable_passes=True, ) @@ -630,7 +627,6 @@ def forward(self, query, key, value, attn_bias=None): inputs, rtol=1e-2, atol=1e-2, - precision=dtype, enable_passes=True, ) diff --git a/tests/py/dynamo/conversion/test_casts.py b/tests/py/dynamo/conversion/test_casts.py index 550a9c1d45..0c924f9cfc 100644 --- a/tests/py/dynamo/conversion/test_casts.py +++ b/tests/py/dynamo/conversion/test_casts.py @@ -47,11 +47,10 @@ def forward(self, x): y = torch.ops.aten._to_copy.default(x, dtype=torch.half) return y - inputs = [torch.rand((1, 3, 10))] + inputs = [torch.rand((1, 3, 10), dtype=torch.half)] self.run_test( ToCopyHalf(), inputs, - precision=torch.half, ) def test_to_copy_float(self): @@ -60,11 +59,10 @@ def forward(self, x): y = torch.ops.aten._to_copy.default(x, dtype=torch.float) return y - inputs = [torch.rand((1, 3, 10)).half()] + inputs = [torch.rand((1, 3, 10), dtype=torch.float)] self.run_test( ToCopyFloat(), inputs, - precision=torch.float, ) def test_to_copy_bfloat16(self): @@ -74,11 +72,10 @@ def forward(self, x): y = y**2 return y - inputs = [torch.rand((1, 3, 10), dtype=torch.float32)] + inputs = [torch.rand((1, 3, 10), dtype=torch.bfloat16)] self.run_test( ToCopyBFloat16(), inputs, - precision=torch.float, ) def test_to_copy_i64b(self): @@ -102,11 +99,10 @@ def forward(self, x): z = torch.ops.aten._to_copy.default(x_1, dtype=torch.float) return y, z - inputs = [torch.rand((1, 3, 10))] + inputs = [torch.rand((1, 3, 10), dtype=torch.float)] self.run_test( ToCopyReturns(), inputs, - precision=torch.float, ) diff --git a/tests/py/dynamo/conversion/test_embedding_bag_aten.py b/tests/py/dynamo/conversion/test_embedding_bag_aten.py index e8019192b3..748bd43d88 100644 --- a/tests/py/dynamo/conversion/test_embedding_bag_aten.py +++ b/tests/py/dynamo/conversion/test_embedding_bag_aten.py @@ -146,7 +146,6 @@ def forward(self, weight, indices): self.run_test( TestEmbeddingBag(), inputs=[weight, indices], - precision=weight.dtype, enable_passes=True, propagate_shapes=True, immutable_weights=True, @@ -345,7 +344,6 @@ def forward(self, weight, indices, offsets): self.run_test( TestEmbeddingBag(), inputs=[weight, indices, offsets], - precision=weight.dtype, enable_passes=True, propagate_shapes=True, immutable_weights=True, @@ -410,7 +408,6 @@ def forward(self, weight, indices, offsets): self.run_test( TestEmbeddingBag(), inputs=[weight, indices, offsets], - precision=weight.dtype, enable_passes=True, propagate_shapes=True, immutable_weights=True, diff --git a/tests/py/dynamo/conversion/test_erf_aten.py b/tests/py/dynamo/conversion/test_erf_aten.py index 08b445a172..1691593d03 100644 --- a/tests/py/dynamo/conversion/test_erf_aten.py +++ b/tests/py/dynamo/conversion/test_erf_aten.py @@ -22,7 +22,7 @@ def forward(self, input): return torch.ops.aten.erf.default(input) inputs = [torch.randn(x, dtype=type)] - self.run_test(erf(), inputs, precision=type) + self.run_test(erf(), inputs) @parameterized.expand( [ diff --git a/tests/py/dynamo/conversion/test_group_norm_aten.py b/tests/py/dynamo/conversion/test_group_norm_aten.py index 46e66ecd9b..8c3dc0bfff 100644 --- a/tests/py/dynamo/conversion/test_group_norm_aten.py +++ b/tests/py/dynamo/conversion/test_group_norm_aten.py @@ -42,7 +42,6 @@ def forward(self, x, weight, bias): self.run_test( GroupNorm(), inputs, - precision=torch.half, use_dynamo_tracer=True, enable_passes=True, ) diff --git a/tests/py/dynamo/conversion/test_hard_sigmoid_aten.py b/tests/py/dynamo/conversion/test_hard_sigmoid_aten.py index 7b4cc8a9e2..5877948b66 100644 --- a/tests/py/dynamo/conversion/test_hard_sigmoid_aten.py +++ b/tests/py/dynamo/conversion/test_hard_sigmoid_aten.py @@ -51,11 +51,10 @@ class TestModule(nn.Module): def forward(self, x): return torch.ops.aten.hardsigmoid.default(x) - inputs = [torch.randn(1, 10)] + inputs = [torch.randn(1, 10, dtype=torch.float16)] self.run_test( TestModule(), inputs, - precision=torch.half, check_dtype=False, ) diff --git a/tests/py/dynamo/conversion/test_neg_aten.py b/tests/py/dynamo/conversion/test_neg_aten.py index a0439c02bc..993afc2420 100644 --- a/tests/py/dynamo/conversion/test_neg_aten.py +++ b/tests/py/dynamo/conversion/test_neg_aten.py @@ -21,8 +21,8 @@ class neg(nn.Module): def forward(self, input): return torch.ops.aten.neg.default(input) - inputs = [torch.randn(x, dtype=type)] - self.run_test(neg(), inputs, precision=type) + inputs = [torch.randn(x, dtype=type).cuda()] + self.run_test(neg(), inputs) @parameterized.expand( [ diff --git a/tests/py/dynamo/conversion/test_sigmoid_aten.py b/tests/py/dynamo/conversion/test_sigmoid_aten.py index b8cb27574e..60c68c561d 100644 --- a/tests/py/dynamo/conversion/test_sigmoid_aten.py +++ b/tests/py/dynamo/conversion/test_sigmoid_aten.py @@ -49,11 +49,10 @@ class TestModule(nn.Module): def forward(self, x): return torch.ops.aten.sigmoid.default(x) - inputs = [torch.randn(1, 10)] + inputs = [torch.randn(1, 10, dtype=torch.float16)] self.run_test( TestModule(), inputs, - precision=torch.half, check_dtype=False, ) diff --git a/tests/py/dynamo/hlo/test_attention.py b/tests/py/dynamo/hlo/test_attention.py index 2f8c29c026..f787380536 100644 --- a/tests/py/dynamo/hlo/test_attention.py +++ b/tests/py/dynamo/hlo/test_attention.py @@ -13,12 +13,6 @@ or upgrade to TRT 11.0 or later. TODO: @Evan to verify the version of TensorRT-RTX that resolves this bug. - PyTorch 2.12.0 (resolved in PyTorch 2.13.0): - PyTorch 2.12.0's core_aten decomposition expands scaled_dot_product_attention - into matmul + _safe_softmax before the TRT converter runs. No converter - is registered for _safe_softmax, so FP32 GQA requires decompose_attention=True. - To resolve this issue, please upgrade to PyTorch 2.13.0 or later. - Notes on attn_bias_is_causal ----------------------------- Default True: the force_causal_efficient_attention lowering pass strips @@ -70,10 +64,10 @@ "Flash attention requires Ampere (SM80) or higher", ) -# skip RTX on Windows -_TRT_RTX_WINDOWS_SKIP = unittest.skipIf( - torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx and sys.platform == "win32", - "This test is skipped on TensorRT-RTX on Windows", +# skip on Windows +_WINDOWS_SKIP = unittest.skipIf( + sys.platform == "win32", + "This test is skipped on Windows because USE_FLASH_ATTENTION was not enabled for build", ) @@ -191,7 +185,6 @@ def forward(self, q, k, v): [q, k, v], rtol=1e-2, atol=test_atol, - precision=dtype, enable_passes=True, decompose_attention=use_decompose, ) @@ -235,7 +228,6 @@ def forward(self, q, k, v): [q, k, v], rtol=1e-2, atol=1e-2, - precision=dtype, enable_passes=True, decompose_attention=False, ) @@ -293,7 +285,6 @@ def forward(self, q, k, v, mask): [q, k, v, mask], rtol=1e-2, atol=1e-2, - precision=dtype, enable_passes=True, decompose_attention=use_decompose, ) @@ -347,7 +338,6 @@ def forward(self, q, k, v, mask): [q, k, v, mask], rtol=1e-2, atol=test_atol, - precision=dtype, enable_passes=True, decompose_attention=use_decompose, ) @@ -361,7 +351,7 @@ def forward(self, q, k, v, mask): ("gqa_32q_8kv_s2048_bf16", 1, 32, 8, 2048, 128, True, torch.bfloat16,False), # large causal in bf16 ("gqa_16q_4kv_s128_fp16", 2, 16, 4, 128, 64, True, torch.float16, False), ("gqa_8q_2kv_nc_fp16", 2, 8, 2, 64, 64, False, torch.float16, False), - ("gqa_8q_4kv_fp32", 2, 8, 4, 64, 64, False, torch.float32, False), # decomposed to _safe_softmax + matmul in torch 2.12.0 but not in 2.13.0 + ("gqa_8q_4kv_fp32", 2, 8, 4, 64, 64, False, torch.float32, False), ("gqa_24q_8kv_fp16", 1, 24, 8, 128, 128, True, torch.float16, False), # Llama-3.2-3B ("gqa_14q_2kv_fp16", 1, 14, 2, 128, 64, True, torch.float16, False), # Qwen2.5-0.5B # MQA (kv_heads = 1) @@ -401,7 +391,6 @@ def forward(self, q, k, v): [q, k, v], rtol=1e-2, atol=1e-2, - precision=dtype, enable_passes=True, decompose_attention=use_decompose, ) @@ -413,7 +402,7 @@ def forward(self, q, k, v): @_FLASH_ATTN_SKIP -@_TRT_RTX_WINDOWS_SKIP +@_WINDOWS_SKIP class TestFlashAttention(DispatchTestCase): """_scaled_dot_product_flash_attention kernel (Ampere+ required). @@ -492,7 +481,6 @@ def forward(self, q, k, v): [q, k, v], rtol=1e-2, atol=atol, - precision=dtype, enable_passes=True, decompose_attention=use_decompose, ) @@ -535,7 +523,6 @@ def forward(self, q, k, v): [q, k, v], rtol=1e-2, atol=atol, - precision=dtype, enable_passes=True, decompose_attention=False, ) @@ -592,7 +579,6 @@ def forward(self, q, k, v): [q, k, v], rtol=1e-2, atol=atol, - precision=dtype, enable_passes=True, decompose_attention=use_decompose, ) @@ -677,7 +663,6 @@ def forward(self, q, k, v): [q, k, v], rtol=1e-2, atol=atol, - precision=dtype, enable_passes=True, decompose_attention=True, ) @@ -725,7 +710,6 @@ def forward(self, q, k, v, bias): [q, k, v, bias], rtol=1e-2, atol=atol, - precision=dtype, enable_passes=True, decompose_attention=False, attn_bias_is_causal=False, @@ -767,7 +751,6 @@ def forward(self, q, k, v, bias): [q, k, v, bias], rtol=1e-2, atol=atol, - precision=dtype, enable_passes=True, decompose_attention=False, attn_bias_is_causal=False, @@ -808,7 +791,6 @@ def forward(self, q, k, v, bias): [q, k, v, bias], rtol=1e-2, atol=atol, - precision=dtype, enable_passes=True, decompose_attention=False, attn_bias_is_causal=False, @@ -850,7 +832,6 @@ def forward(self, q, k, v, bias): [q, k, v, bias], rtol=1e-2, atol=atol, - precision=dtype, enable_passes=True, decompose_attention=False, attn_bias_is_causal=True, diff --git a/tests/py/dynamo/hlo/test_moe.py b/tests/py/dynamo/hlo/test_moe.py new file mode 100644 index 0000000000..7ddad3990f --- /dev/null +++ b/tests/py/dynamo/hlo/test_moe.py @@ -0,0 +1,1050 @@ +"""Comprehensive MoE subgraph tests for TRT converter bug discovery. + +Covers all Mixture-of-Experts routing and dispatch variants found in popular +open-source models: Mixtral, Llama4, Qwen2-MoE, Qwen3-MoE, DeepSeek-V2/V3, +and NVIDIA Nemotron-H. Each test class instantiates a self-contained MoE +block and validates TRT output against PyTorch reference. + +Routing variants covered +------------------------ + Softmax routing (Mixtral, Qwen2, Qwen3): + softmax(gate_logits) → topk → optional renormalization + Sigmoid routing (Llama4, DeepSeek-V3/R1, Nemotron): + sigmoid(gate_logits) → topk + +Group-limited greedy selection (DeepSeek, Nemotron): + Two sub-variants are represented: + max-per-group (DeepSeek-V2): group score = max expert score in group + top2-sum-per-group (DeepSeek-V3 / Nemotron): + group score = sum of top-2 expert scores; e_score_correction_bias added + +Shared expert variants +---------------------- + None (Mixtral, Qwen3): all computation goes through the routed experts only + Always-on, unweighted (Llama4, DeepSeek, Nemotron): + shared output always added to routed output + Sigmoid-gated scalar (Qwen2): + shared output weighted by sigmoid(Linear(hidden, 1)) per token + +Expert MLP styles +----------------- + SwiGLU / gated MLP (all except Nemotron): + output = down_proj(act(gate_proj(x)) * up_proj(x)) + Plain 2-layer MLP (Nemotron): + output = down_proj(act(up_proj(x))) + +Dispatch mechanism +------------------ + Scatter-based dense dispatch (used in all test classes here): + Build routing_weight_matrix [T, N] via scatter_, run every expert on all + tokens, accumulate weighted outputs. This is the only dispatch pattern + compatible with torch.export + static shapes. + + The original models use three dispatch patterns that are NOT directly + exportable and are therefore approximated: + index_add dispatch (Mixtral, Qwen2, Qwen3, Nemotron): + torch.where(expert_mask) returns dynamic-size indices; the subsequent + hidden_states[top_x] index is data-dependent → rejected by torch.export. + Sort-based dispatch (DeepSeek moe_infer): + tokens_per_expert.cpu().numpy() is a device sync + Python loop over + dynamic counts → rejected by torch.export. + Dense-broadcast dispatch (Llama4): + hidden.repeat(N, 1) + sigmoid mask → zero-weight experts contribute ~0; + tested as-is since it IS export-friendly. + + The scatter-based approximation in all non-Llama4 classes computes the + identical numerical result as the original index_add dispatch; it is a + mathematical equivalence, not a compromise. + +Known limitations +----------------- + FP32 MoE with large token counts: accumulated rounding in the routing + scatter + expert matmul chain causes larger divergence than FP16; tests + use atol=1e-3 for FP32 cases. + +Test classes +------------ + TestMixtralStyleMoE softmax routing, SwiGLU, no shared expert + TestQwen2StyleMoE softmax routing, SwiGLU, sigmoid-gated shared expert + TestQwen3StyleMoE softmax routing + optional norm_topk_prob, SwiGLU, no shared + TestLlama4StyleMoE sigmoid routing, dense broadcast, batched bmm experts, shared expert + TestDeepSeekV2StyleMoE sigmoid/softmax + group_limited_greedy (max-per-group), shared expert + TestDeepSeekV3StyleMoE sigmoid + group_limited_greedy (top2-sum + bias), shared expert + TestNemotronStyleMoE sigmoid + group_limited_greedy (top2-sum + bias), shared expert, plain MLP +""" + +import unittest + +import pytest +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch_tensorrt +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from ..conversion.harness import DispatchTestCase + +_BF16_SKIP = unittest.skipIf( + not torch.cuda.is_available() + or torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8, + "BF16 requires Ampere (SM80) or higher", +) + + +# --------------------------------------------------------------------------- +# Shared building blocks +# --------------------------------------------------------------------------- + + +class SwiGLUExpert(nn.Module): + """Gated MLP used by Mixtral, Qwen, DeepSeek (gate_proj * act + up_proj → down_proj).""" + + def __init__(self, hidden_size: int, ffn_dim: int): + super().__init__() + self.gate_proj = nn.Linear(hidden_size, ffn_dim, bias=False) + self.up_proj = nn.Linear(hidden_size, ffn_dim, bias=False) + self.down_proj = nn.Linear(ffn_dim, hidden_size, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) + + +class PlainMLPExpert(nn.Module): + """Non-gated MLP used by Nemotron-H (up_proj → act → down_proj, no gate).""" + + def __init__(self, hidden_size: int, ffn_dim: int): + super().__init__() + self.up_proj = nn.Linear(hidden_size, ffn_dim, bias=False) + self.down_proj = nn.Linear(ffn_dim, hidden_size, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(F.silu(self.up_proj(x))) + + +def _scatter_dispatch( + hidden: torch.Tensor, + experts: nn.ModuleList, + routing_weights: torch.Tensor, + selected_experts: torch.Tensor, + num_experts: int, +) -> torch.Tensor: + """Export-friendly dense dispatch: one_hot routing weights, run all experts. + + Mathematically equivalent to the index_add dispatch used in Mixtral/Qwen; + avoids data-dependent torch.where indexing that torch.export rejects, and + avoids torch.zeros(T, N) which fails when T is an FX Proxy. + + Args: + hidden: [T, hidden_size] + routing_weights: [T, top_k] + selected_experts: [T, top_k] (int indices) + num_experts: N + + Returns: + [T, hidden_size] + """ + # one_hot: [T, top_k, N]; multiply by weights then sum over top_k → [T, N] + one_hot_mask = F.one_hot(selected_experts.long(), num_classes=num_experts).to( + routing_weights.dtype + ) # [T, top_k, N] + weight_matrix = (one_hot_mask * routing_weights.unsqueeze(-1)).sum(dim=1) # [T, N] + + final = torch.zeros_like(hidden) + for i, expert in enumerate(experts): + expert_out = expert(hidden) # [T, hidden_size] + final = final + expert_out * weight_matrix[:, i : i + 1].to(hidden.dtype) + return final + + +# --------------------------------------------------------------------------- +# TestMixtralStyleMoE +# --------------------------------------------------------------------------- + + +class MixtralStyleMoE(nn.Module): + """Softmax-routed MoE without shared expert (Mixtral, Qwen3 baseline). + + Routing: softmax(gate) → topk → always renormalize to sum=1. + Dispatch: scatter-based dense (export-friendly equivalent of index_add). + """ + + def __init__( + self, + hidden_size: int, + ffn_dim: int, + num_experts: int, + top_k: int, + norm_topk_prob: bool = True, + ): + super().__init__() + self.num_experts = num_experts + self.top_k = top_k + self.norm_topk_prob = norm_topk_prob + self.gate = nn.Linear(hidden_size, num_experts, bias=False) + self.experts = nn.ModuleList( + [SwiGLUExpert(hidden_size, ffn_dim) for _ in range(num_experts)] + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + B, S, H = hidden_states.shape + hidden = hidden_states.view(-1, H) + router_logits = self.gate(hidden) + routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float32) + routing_weights, selected_experts = torch.topk( + routing_weights, self.top_k, dim=-1 + ) + if self.norm_topk_prob: + routing_weights = routing_weights / routing_weights.sum( + dim=-1, keepdim=True + ) + routing_weights = routing_weights.to(hidden.dtype) + out = _scatter_dispatch( + hidden, self.experts, routing_weights, selected_experts, self.num_experts + ) + return out.view(B, S, H) + + +class TestMixtralStyleMoE(DispatchTestCase): + # fmt: off + @parameterized.expand( + [ + # (name, batch, seq, hidden, ffn, num_experts, top_k, norm, dtype, atol) + # --- Basic FP16 --- + ("b1_s32_e4_k2_fp16", 1, 32, 64, 128, 4, 2, True, torch.float16, 1e-2), + ("b2_s64_e4_k2_fp16", 2, 64, 64, 128, 4, 2, True, torch.float16, 1e-2), + ("b1_s128_e8_k2_fp16", 1, 128, 64, 128, 8, 2, True, torch.float16, 1e-2), + # top_k=1 (single-expert routing) + ("b1_s32_e4_k1_fp16", 1, 32, 64, 128, 4, 1, True, torch.float16, 1e-2), + # norm_topk_prob=False (Qwen3 config with norm disabled) + ("b1_s32_e4_k2_nonorm_fp16", 1, 32, 64, 128, 4, 2, False, torch.float16, 1e-2), + # Larger hidden_size + ("b1_s32_e4_k2_h128_fp16", 1, 32, 128, 256, 4, 2, True, torch.float16, 1e-2), + # --- FP32 --- + ("b1_s32_e4_k2_fp32", 1, 32, 64, 128, 4, 2, True, torch.float32, 1e-3), + ("b1_s64_e8_k2_fp32", 1, 64, 64, 128, 8, 2, True, torch.float32, 1e-3), + # Mixtral-realistic (small proxy): 8 experts, top-2 + ("mixtral_proxy_fp16", 1, 64, 64, 128, 8, 2, True, torch.float16, 1e-2), + # Qwen3-realistic (small proxy): 8 experts, top-2, no norm + ("qwen3_proxy_fp16", 1, 64, 64, 128, 8, 2, False, torch.float16, 1e-2), + ] + ) + # fmt: on + def test_mixtral_style( + self, name, batch, seq, hidden, ffn, n_exp, top_k, norm, dtype, atol + ): + mod = ( + MixtralStyleMoE(hidden, ffn, n_exp, top_k, norm_topk_prob=norm) + .eval() + .cuda() + .to(dtype) + ) + x = torch.randn(batch, seq, hidden, dtype=dtype) + if dtype == torch.float32: + # small diff between tf32 and float32 may cause the topk function to choose different experts + disable_tf32 = True + else: + disable_tf32 = False + self.run_test( + mod, + [x], + rtol=1e-2, + atol=atol, + enable_passes=True, + use_dynamo_tracer=True, + require_full_compilation=True, + disable_tf32=disable_tf32, + ) + + +# --------------------------------------------------------------------------- +# TestQwen2StyleMoE +# --------------------------------------------------------------------------- + + +class Qwen2StyleMoE(nn.Module): + """Softmax-routed MoE with sigmoid-gated scalar shared expert (Qwen2-MoE). + + Routing: softmax(gate) → topk → optional renorm. + Shared expert: shared_output * sigmoid(Linear(hidden, 1)) added to routed output. + """ + + def __init__( + self, + hidden_size: int, + ffn_dim: int, + shared_ffn_dim: int, + num_experts: int, + top_k: int, + norm_topk_prob: bool = False, + ): + super().__init__() + self.num_experts = num_experts + self.top_k = top_k + self.norm_topk_prob = norm_topk_prob + self.gate = nn.Linear(hidden_size, num_experts, bias=False) + self.experts = nn.ModuleList( + [SwiGLUExpert(hidden_size, ffn_dim) for _ in range(num_experts)] + ) + self.shared_expert = SwiGLUExpert(hidden_size, shared_ffn_dim) + self.shared_expert_gate = nn.Linear(hidden_size, 1, bias=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + B, S, H = hidden_states.shape + hidden = hidden_states.view(-1, H) + router_logits = self.gate(hidden) + routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float32) + routing_weights, selected_experts = torch.topk( + routing_weights, self.top_k, dim=-1 + ) + if self.norm_topk_prob: + routing_weights = routing_weights / routing_weights.sum( + dim=-1, keepdim=True + ) + routing_weights = routing_weights.to(hidden.dtype) + routed_out = _scatter_dispatch( + hidden, self.experts, routing_weights, selected_experts, self.num_experts + ) + shared_out = self.shared_expert(hidden) + shared_gate = torch.sigmoid(self.shared_expert_gate(hidden)) # [T, 1] + out = routed_out + shared_gate * shared_out + return out.view(B, S, H) + + +class TestQwen2StyleMoE(DispatchTestCase): + # fmt: off + @parameterized.expand( + [ + # (name, batch, seq, hidden, ffn, shared_ffn, num_experts, top_k, norm, dtype, atol) + ("b1_s32_e4_k2_fp16", 1, 32, 64, 128, 256, 4, 2, False, torch.float16, 1e-2), + ("b2_s64_e4_k2_fp16", 2, 64, 64, 128, 256, 4, 2, False, torch.float16, 1e-2), + ("b1_s32_e8_k2_fp16", 1, 32, 64, 128, 256, 8, 2, False, torch.float16, 1e-2), + # norm_topk_prob=True + ("b1_s32_e4_k2_norm_fp16", 1, 32, 64, 128, 256, 4, 2, True, torch.float16, 1e-2), + # Larger shared expert intermediate size + ("b1_s32_e4_k2_bigshared_fp16", 1, 32, 64, 128, 512, 4, 2, False, torch.float16, 1e-2), + # FP32 (xfail on TRT-RTX only — handled in test body) + ("b1_s32_e4_k2_fp32", 1, 32, 64, 128, 256, 4, 2, False, torch.float32, 1e-3), + # Qwen2-realistic proxy + ("qwen2_proxy_fp16", 1, 64, 64, 128, 256, 8, 2, False, torch.float16, 1e-2), + ] + ) + # fmt: on + def test_qwen2_style( + self, name, batch, seq, hidden, ffn, shared_ffn, n_exp, top_k, norm, dtype, atol + ): + mod = ( + Qwen2StyleMoE(hidden, ffn, shared_ffn, n_exp, top_k, norm_topk_prob=norm) + .eval() + .cuda() + .to(dtype) + ) + x = torch.randn(batch, seq, hidden, dtype=dtype) + if dtype == torch.float32: + # small diff between tf32 and float32 may cause the topk function to choose different experts + disable_tf32 = True + else: + disable_tf32 = False + try: + self.run_test( + mod, + [x], + rtol=1e-2, + atol=atol, + enable_passes=True, + use_dynamo_tracer=True, + require_full_compilation=True, + disable_tf32=disable_tf32, + ) + except AssertionError: + if ( + name == "b1_s32_e4_k2_fp32" + and torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx + ): + pytest.xfail( + "TF32 vs FP32 MoE top-k divergence on TensorRT-RTX; " + "disable_tf32 is not effective on TRT-RTX." + ) + raise + + +# --------------------------------------------------------------------------- +# TestQwen3StyleMoE +# --------------------------------------------------------------------------- + + +class Qwen3StyleMoE(nn.Module): + """Softmax-routed MoE without shared expert, configurable norm (Qwen3-MoE). + + Identical structure to MixtralStyleMoE but captures Qwen3's specific + combination of optional norm_topk_prob and moe_intermediate_size. + """ + + def __init__( + self, + hidden_size: int, + moe_intermediate_size: int, + num_experts: int, + top_k: int, + norm_topk_prob: bool = True, + ): + super().__init__() + self.num_experts = num_experts + self.top_k = top_k + self.norm_topk_prob = norm_topk_prob + self.gate = nn.Linear(hidden_size, num_experts, bias=False) + self.experts = nn.ModuleList( + [ + SwiGLUExpert(hidden_size, moe_intermediate_size) + for _ in range(num_experts) + ] + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + B, S, H = hidden_states.shape + hidden = hidden_states.view(-1, H) + router_logits = self.gate(hidden) + routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float32) + routing_weights, selected_experts = torch.topk( + routing_weights, self.top_k, dim=-1 + ) + if self.norm_topk_prob: + routing_weights = routing_weights / routing_weights.sum( + dim=-1, keepdim=True + ) + routing_weights = routing_weights.to(hidden.dtype) + out = _scatter_dispatch( + hidden, self.experts, routing_weights, selected_experts, self.num_experts + ) + return out.view(B, S, H) + + +class TestQwen3StyleMoE(DispatchTestCase): + # fmt: off + @parameterized.expand( + [ + # (name, batch, seq, hidden, moe_ffn, num_experts, top_k, norm, dtype, atol) + ("b1_s32_e4_k2_norm_fp16", 1, 32, 64, 128, 4, 2, True, torch.float16, 1e-2), + ("b1_s32_e4_k2_nonorm_fp16", 1, 32, 64, 128, 4, 2, False, torch.float16, 1e-2), + ("b2_s64_e8_k2_norm_fp16", 2, 64, 64, 128, 8, 2, True, torch.float16, 1e-2), + ("b1_s128_e4_k1_fp16", 1, 128, 64, 128, 4, 1, True, torch.float16, 1e-2), + ("b1_s32_e4_k2_fp32", 1, 32, 64, 128, 4, 2, True, torch.float32, 1e-3), + # Qwen3-MoE-0.6B proxy: 64 experts, top-8 → scaled down to 8 experts, top-2 + ("qwen3_moe_proxy_fp16", 1, 64, 64, 128, 8, 2, True, torch.float16, 1e-2), + ] + ) + # fmt: on + def test_qwen3_style( + self, name, batch, seq, hidden, moe_ffn, n_exp, top_k, norm, dtype, atol + ): + mod = ( + Qwen3StyleMoE(hidden, moe_ffn, n_exp, top_k, norm_topk_prob=norm) + .eval() + .cuda() + .to(dtype) + ) + x = torch.randn(batch, seq, hidden, dtype=dtype) + self.run_test( + mod, + [x], + rtol=1e-2, + atol=atol, + enable_passes=True, + use_dynamo_tracer=True, + require_full_compilation=True, + ) + + +# --------------------------------------------------------------------------- +# TestLlama4StyleMoE +# --------------------------------------------------------------------------- + + +class Llama4StyleExperts(nn.Module): + """All experts fused into batched matmuls — Llama4TextExperts pattern. + + Weights shape: [N, hidden, 2*ffn] (gate+up fused) and [N, ffn, hidden] (down). + Input is tiled [N*T, hidden], reshaped to [N, T, hidden] for bmm. + """ + + def __init__(self, num_experts: int, hidden_size: int, ffn_dim: int): + super().__init__() + self.num_experts = num_experts + self.hidden_size = hidden_size + self.ffn_dim = ffn_dim + self.gate_up_proj = nn.Parameter( + torch.empty(num_experts, hidden_size, 2 * ffn_dim) + ) + self.down_proj = nn.Parameter(torch.empty(num_experts, ffn_dim, hidden_size)) + nn.init.normal_(self.gate_up_proj, std=0.02) + nn.init.normal_(self.down_proj, std=0.02) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # hidden_states: [N*T, hidden_size] + h = hidden_states.view(self.num_experts, -1, self.hidden_size) + gate_up = torch.bmm(h, self.gate_up_proj) # [N, T, 2*ffn] + gate, up = gate_up.chunk(2, dim=-1) + out = torch.bmm(up * F.silu(gate), self.down_proj) # [N, T, hidden] + return out.view(-1, self.hidden_size) + + +class Llama4StyleMoE(nn.Module): + """Sigmoid-routed MoE with dense broadcast dispatch and always-on shared expert (Llama4). + + Routing: topk(logits) → scatter back to full expert space → sigmoid. + Dispatch: tile all tokens N times; zero out non-selected via sigmoid(-inf)≈0. + Shared expert: always-on Llama4TextMLP, output added unconditionally. + """ + + def __init__( + self, + hidden_size: int, + ffn_dim: int, + shared_ffn_dim: int, + num_experts: int, + top_k: int, + ): + super().__init__() + self.num_experts = num_experts + self.top_k = top_k + self.hidden_size = hidden_size + self.router = nn.Linear(hidden_size, num_experts, bias=False) + self.experts = Llama4StyleExperts(num_experts, hidden_size, ffn_dim) + self.shared_expert = SwiGLUExpert(hidden_size, shared_ffn_dim) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + B, S, H = hidden_states.shape + hidden = hidden_states.view(-1, H) # [T, H] + T = hidden.shape[0] + + router_logits = self.router(hidden) # [T, N] + top_values, top_indices = torch.topk(router_logits, self.top_k, dim=1) + + # Scatter selected logits back; fill unselected with -inf → sigmoid → ~0 + router_scores = ( + torch.full_like(router_logits, float("-inf")) + .scatter_(1, top_indices, top_values) + .transpose(0, 1) # [N, T] + ) + router_scores = torch.sigmoid(router_scores.float()).to(hidden.dtype) # [N, T] + + # Dense broadcast: tile all tokens for all experts + routed_in = hidden.repeat(self.num_experts, 1) # [N*T, H] + routed_in = routed_in * router_scores.reshape(-1, 1) # zero non-selected + routed_out = self.experts(routed_in) # [N*T, H] + + # Sum contributions across experts + expert_sum = routed_out.reshape(self.num_experts, T, H).sum(dim=0) # [T, H] + + out = self.shared_expert(hidden) + expert_sum + return out.view(B, S, H) + + +class TestLlama4StyleMoE(DispatchTestCase): + # fmt: off + @parameterized.expand( + [ + # (name, batch, seq, hidden, ffn, shared_ffn, num_experts, top_k, dtype, atol) + ("b1_s32_e4_k2_fp16", 1, 32, 64, 128, 256, 4, 2, torch.float16, 1e-2), + ("b2_s64_e4_k2_fp16", 2, 64, 64, 128, 256, 4, 2, torch.float16, 1e-2), + ("b1_s128_e8_k2_fp16", 1, 128, 64, 128, 256, 8, 2, torch.float16, 1e-2), + # top_k=1 + ("b1_s32_e4_k1_fp16", 1, 32, 64, 128, 256, 4, 1, torch.float16, 1e-2), + # FP32: dense-broadcast accumulation in FP32 has larger rounding; loosen atol + ("b1_s32_e4_k2_fp32", 1, 32, 64, 128, 256, 4, 2, torch.float32, 1e-2), + # Llama4-Scout proxy (16 experts, top-1) + ("llama4_scout_proxy_fp16", 1, 64, 64, 128, 256, 8, 1, torch.float16, 1e-2), + ] + ) + # fmt: on + def test_llama4_style( + self, name, batch, seq, hidden, ffn, shared_ffn, n_exp, top_k, dtype, atol + ): + mod = ( + Llama4StyleMoE(hidden, ffn, shared_ffn, n_exp, top_k) + .eval() + .cuda() + .to(dtype) + ) + x = torch.randn(batch, seq, hidden, dtype=dtype) + self.run_test( + mod, + [x], + rtol=1e-2, + atol=atol, + enable_passes=True, + use_dynamo_tracer=True, + require_full_compilation=True, + ) + + +# --------------------------------------------------------------------------- +# TestDeepSeekV2StyleMoE +# --------------------------------------------------------------------------- + + +def _group_limited_greedy_topk_max( + scores: torch.Tensor, top_k: int, n_group: int, topk_group: int +) -> tuple[torch.Tensor, torch.Tensor]: + """DeepSeek-V2 group-limited greedy: group score = max expert score in group. + + Args: + scores: [T, N_experts] — softmax or sigmoid scores + top_k: number of experts to select per token + n_group: number of expert groups + topk_group: number of groups to select + + Returns: + topk_weight [T, top_k], topk_idx [T, top_k] + """ + T, N = scores.shape + experts_per_group = N // n_group + + # Score each group by its best expert + group_scores = ( + scores.view(T, n_group, experts_per_group).max(dim=-1).values + ) # [T, n_group] + group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[ + 1 + ] # [T, topk_group] + + # Build per-expert mask from selected groups + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(1, group_idx, 1.0) + score_mask = ( + group_mask.unsqueeze(-1).expand(T, n_group, experts_per_group).reshape(T, N) + ) + + masked_scores = scores.masked_fill(~score_mask.bool(), 0.0) + topk_weight, topk_idx = torch.topk(masked_scores, k=top_k, dim=-1, sorted=False) + return topk_weight, topk_idx + + +class DeepSeekV2StyleMoE(nn.Module): + """Group-limited greedy MoE (max-per-group) with shared expert (DeepSeek-V2). + + Routing: softmax(gate) → group_limited_greedy (max per group) → optional renorm. + Shared expert: always-on SwiGLU added to routed output. + """ + + def __init__( + self, + hidden_size: int, + moe_ffn_dim: int, + shared_ffn_dim: int, + num_experts: int, + top_k: int, + n_group: int, + topk_group: int, + norm_topk_prob: bool = False, + routed_scaling_factor: float = 1.0, + ): + super().__init__() + self.num_experts = num_experts + self.top_k = top_k + self.n_group = n_group + self.topk_group = topk_group + self.norm_topk_prob = norm_topk_prob + self.routed_scaling_factor = routed_scaling_factor + self.gate_weight = nn.Parameter(torch.empty(num_experts, hidden_size)) + nn.init.normal_(self.gate_weight, std=0.02) + self.experts = nn.ModuleList( + [SwiGLUExpert(hidden_size, moe_ffn_dim) for _ in range(num_experts)] + ) + self.shared_expert = SwiGLUExpert(hidden_size, shared_ffn_dim) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + B, S, H = hidden_states.shape + hidden = hidden_states.view(-1, H) + + logits = F.linear(hidden.float(), self.gate_weight.float()) + scores = F.softmax(logits, dim=-1) + + topk_weight, topk_idx = _group_limited_greedy_topk_max( + scores, self.top_k, self.n_group, self.topk_group + ) + if self.norm_topk_prob: + topk_weight = topk_weight / (topk_weight.sum(dim=-1, keepdim=True) + 1e-20) + else: + topk_weight = topk_weight * self.routed_scaling_factor + topk_weight = topk_weight.to(hidden.dtype) + + routed_out = _scatter_dispatch( + hidden, self.experts, topk_weight, topk_idx, self.num_experts + ) + out = routed_out + self.shared_expert(hidden) + return out.view(B, S, H) + + +class TestDeepSeekV2StyleMoE(DispatchTestCase): + # fmt: off + @parameterized.expand( + [ + # (name, batch, seq, hidden, moe_ffn, shared_ffn, n_exp, top_k, n_group, topk_group, norm, scale, dtype, atol) + ("b1_s32_e4_k2_g2_tg1_fp16", 1, 32, 64, 128, 256, 4, 2, 2, 1, False, 1.0, torch.float16, 1e-2), + ("b2_s64_e4_k2_g2_tg1_fp16", 2, 64, 64, 128, 256, 4, 2, 2, 1, False, 1.0, torch.float16, 1e-2), + ("b1_s32_e8_k2_g4_tg2_fp16", 1, 32, 64, 128, 256, 8, 2, 4, 2, False, 1.0, torch.float16, 1e-2), + # norm_topk_prob=True + ("b1_s32_e4_k2_g2_tg1_norm_fp16", 1, 32, 64, 128, 256, 4, 2, 2, 1, True, 1.0, torch.float16, 1e-2), + # routed_scaling_factor != 1 + ("b1_s32_e4_k2_scale16_fp16", 1, 32, 64, 128, 256, 4, 2, 2, 1, False, 1.6, torch.float16, 1e-2), + # FP32 + ("b1_s32_e4_k2_g2_tg1_fp32", 1, 32, 64, 128, 256, 4, 2, 2, 1, False, 1.0, torch.float32, 1e-3), + # DeepSeek-V2-Lite proxy: 64 experts → 8 here, top-6 → top-2, 8 groups → 2 + ("deepseekv2_proxy_fp16", 1, 64, 64, 128, 256, 8, 2, 2, 1, True, 1.0, torch.float16, 1e-2), + ] + ) + # fmt: on + def test_deepseekv2_style( + self, + name, + batch, + seq, + hidden, + moe_ffn, + shared_ffn, + n_exp, + top_k, + n_group, + topk_group, + norm, + scale, + dtype, + atol, + ): + mod = ( + DeepSeekV2StyleMoE( + hidden, + moe_ffn, + shared_ffn, + n_exp, + top_k, + n_group, + topk_group, + norm_topk_prob=norm, + routed_scaling_factor=scale, + ) + .eval() + .cuda() + .to(dtype) + ) + x = torch.randn(batch, seq, hidden, dtype=dtype) + self.run_test( + mod, + [x], + rtol=1e-2, + atol=atol, + enable_passes=True, + use_dynamo_tracer=True, + require_full_compilation=True, + ) + + +# --------------------------------------------------------------------------- +# TestDeepSeekV3StyleMoE +# --------------------------------------------------------------------------- + + +def _group_limited_greedy_topk_top2sum( + scores: torch.Tensor, + top_k: int, + n_group: int, + topk_group: int, + correction_bias: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """DeepSeek-V3 / Nemotron group-limited greedy: group score = sum of top-2. + + Args: + scores: [T, N_experts] — raw sigmoid scores (before bias) + top_k: experts to select per token + n_group: number of expert groups + topk_group: groups to select + correction_bias: [N_experts] — per-expert additive bias for selection only + + Returns: + topk_weight [T, top_k], topk_idx [T, top_k] + (weights use raw sigmoid scores, not biased scores) + """ + T, N = scores.shape + experts_per_group = N // n_group + + scores_for_choice = scores + correction_bias.unsqueeze(0) # [T, N] + group_scores = ( + scores_for_choice.view(T, n_group, experts_per_group) + .topk(2, dim=-1)[0] + .sum(dim=-1) + ) # [T, n_group] + + group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1] + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(1, group_idx, 1.0) + score_mask = ( + group_mask.unsqueeze(-1).expand(T, n_group, experts_per_group).reshape(T, N) + ) + + masked_scores = scores_for_choice.masked_fill(~score_mask.bool(), float("-inf")) + topk_idx = torch.topk(masked_scores, k=top_k, dim=-1, sorted=False)[1] + topk_weight = scores.gather(1, topk_idx) # use raw scores (not biased) + return topk_weight, topk_idx + + +class DeepSeekV3StyleMoE(nn.Module): + """Group-limited greedy MoE (top2-sum-per-group, correction bias) with shared expert (DeepSeek-V3/R1). + + Routing: sigmoid(gate) → group_limited_greedy (top-2 sum per group + bias) → optional renorm × scale. + Shared expert: always-on SwiGLU added to routed output. + """ + + def __init__( + self, + hidden_size: int, + moe_ffn_dim: int, + shared_ffn_dim: int, + num_experts: int, + top_k: int, + n_group: int, + topk_group: int, + norm_topk_prob: bool = True, + routed_scaling_factor: float = 1.0, + ): + super().__init__() + self.num_experts = num_experts + self.top_k = top_k + self.n_group = n_group + self.topk_group = topk_group + self.norm_topk_prob = norm_topk_prob + self.routed_scaling_factor = routed_scaling_factor + self.gate_weight = nn.Parameter(torch.empty(num_experts, hidden_size)) + nn.init.normal_(self.gate_weight, std=0.02) + self.e_score_correction_bias = nn.Parameter(torch.zeros(num_experts)) + self.experts = nn.ModuleList( + [SwiGLUExpert(hidden_size, moe_ffn_dim) for _ in range(num_experts)] + ) + self.shared_expert = SwiGLUExpert(hidden_size, shared_ffn_dim) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + B, S, H = hidden_states.shape + hidden = hidden_states.view(-1, H) + + logits = F.linear(hidden.float(), self.gate_weight.float()) + scores = torch.sigmoid(logits) # [T, N] + + topk_weight, topk_idx = _group_limited_greedy_topk_top2sum( + scores, + self.top_k, + self.n_group, + self.topk_group, + self.e_score_correction_bias.float(), + ) + if self.norm_topk_prob: + topk_weight = topk_weight / (topk_weight.sum(dim=-1, keepdim=True) + 1e-20) + topk_weight = (topk_weight * self.routed_scaling_factor).to(hidden.dtype) + + routed_out = _scatter_dispatch( + hidden, self.experts, topk_weight, topk_idx, self.num_experts + ) + out = routed_out + self.shared_expert(hidden) + return out.view(B, S, H) + + +class TestDeepSeekV3StyleMoE(DispatchTestCase): + # fmt: off + @parameterized.expand( + [ + # (name, batch, seq, hidden, moe_ffn, shared_ffn, n_exp, top_k, n_group, topk_group, norm, scale, dtype, atol) + ("b1_s32_e4_k2_g2_tg1_fp16", 1, 32, 64, 128, 256, 4, 2, 2, 1, True, 1.0, torch.float16, 1e-2), + ("b2_s64_e4_k2_g2_tg1_fp16", 2, 64, 64, 128, 256, 4, 2, 2, 1, True, 1.0, torch.float16, 1e-2), + ("b1_s32_e8_k2_g4_tg2_fp16", 1, 32, 64, 128, 256, 8, 2, 4, 2, True, 1.0, torch.float16, 1e-2), + # norm_topk_prob=False + ("b1_s32_e4_k2_nonorm_fp16", 1, 32, 64, 128, 256, 4, 2, 2, 1, False, 1.0, torch.float16, 1e-2), + # routed_scaling_factor != 1 + ("b1_s32_e4_k2_scale25_fp16", 1, 32, 64, 128, 256, 4, 2, 2, 1, True, 2.5, torch.float16, 1e-2), + # FP32 + ("b1_s32_e4_k2_fp32", 1, 32, 64, 128, 256, 4, 2, 2, 1, True, 1.0, torch.float32, 1e-3), + # DeepSeek-R1 proxy: sigmoid + top2-sum-per-group + ("deepseekr1_proxy_fp16", 1, 64, 64, 128, 256, 8, 2, 2, 1, True, 1.0, torch.float16, 1e-2), + ] + ) + # fmt: on + def test_deepseekv3_style( + self, + name, + batch, + seq, + hidden, + moe_ffn, + shared_ffn, + n_exp, + top_k, + n_group, + topk_group, + norm, + scale, + dtype, + atol, + ): + mod = ( + DeepSeekV3StyleMoE( + hidden, + moe_ffn, + shared_ffn, + n_exp, + top_k, + n_group, + topk_group, + norm_topk_prob=norm, + routed_scaling_factor=scale, + ) + .eval() + .cuda() + .to(dtype) + ) + x = torch.randn(batch, seq, hidden, dtype=dtype) + self.run_test( + mod, + [x], + rtol=1e-2, + atol=atol, + enable_passes=True, + use_dynamo_tracer=True, + require_full_compilation=True, + ) + + +# --------------------------------------------------------------------------- +# TestNemotronStyleMoE +# --------------------------------------------------------------------------- + + +class NemotronStyleMoE(nn.Module): + """Group-limited greedy MoE (top2-sum + bias) with plain MLP experts and shared expert (Nemotron-H). + + Key difference from DeepSeekV3: expert MLP is non-gated (up → act → down, no gate_proj). + Routing and dispatch are otherwise identical to DeepSeekV3. + """ + + def __init__( + self, + hidden_size: int, + moe_ffn_dim: int, + shared_ffn_dim: int, + num_experts: int, + top_k: int, + n_group: int, + topk_group: int, + norm_topk_prob: bool = True, + routed_scaling_factor: float = 1.0, + ): + super().__init__() + self.num_experts = num_experts + self.top_k = top_k + self.n_group = n_group + self.topk_group = topk_group + self.norm_topk_prob = norm_topk_prob + self.routed_scaling_factor = routed_scaling_factor + self.gate_weight = nn.Parameter(torch.empty(num_experts, hidden_size)) + nn.init.normal_(self.gate_weight, std=0.02) + self.e_score_correction_bias = nn.Parameter(torch.zeros(num_experts)) + # Plain MLP (no gate_proj) — Nemotron-H's non-gated expert + self.experts = nn.ModuleList( + [PlainMLPExpert(hidden_size, moe_ffn_dim) for _ in range(num_experts)] + ) + self.shared_expert = PlainMLPExpert(hidden_size, shared_ffn_dim) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + B, S, H = hidden_states.shape + hidden = hidden_states.view(-1, H) + + logits = F.linear(hidden.float(), self.gate_weight.float()) + scores = torch.sigmoid(logits) + + topk_weight, topk_idx = _group_limited_greedy_topk_top2sum( + scores, + self.top_k, + self.n_group, + self.topk_group, + self.e_score_correction_bias.float(), + ) + if self.norm_topk_prob: + topk_weight = topk_weight / (topk_weight.sum(dim=-1, keepdim=True) + 1e-20) + topk_weight = (topk_weight * self.routed_scaling_factor).to(hidden.dtype) + + routed_out = _scatter_dispatch( + hidden, self.experts, topk_weight, topk_idx, self.num_experts + ) + out = routed_out + self.shared_expert(hidden) + return out.view(B, S, H) + + +class TestNemotronStyleMoE(DispatchTestCase): + # fmt: off + @parameterized.expand( + [ + # (name, batch, seq, hidden, moe_ffn, shared_ffn, n_exp, top_k, n_group, topk_group, norm, scale, dtype, atol) + ("b1_s32_e4_k2_g2_tg1_fp16", 1, 32, 64, 128, 256, 4, 2, 2, 1, True, 1.0, torch.float16, 1e-2), + ("b2_s64_e4_k2_g2_tg1_fp16", 2, 64, 64, 128, 256, 4, 2, 2, 1, True, 1.0, torch.float16, 1e-2), + ("b1_s128_e8_k2_g4_tg2_fp16", 1, 128, 64, 128, 256, 8, 2, 4, 2, True, 1.0, torch.float16, 1e-2), + # norm_topk_prob=False + ("b1_s32_e4_k2_nonorm_fp16", 1, 32, 64, 128, 256, 4, 2, 2, 1, False, 1.0, torch.float16, 1e-2), + # Non-zero correction bias (Nemotron initializes it to zero but it's learned) + ("b1_s32_e4_k2_bias_fp16", 1, 32, 64, 128, 256, 4, 2, 2, 1, True, 1.0, torch.float16, 1e-2), + # FP32 + ("b1_s32_e4_k2_fp32", 1, 32, 64, 128, 256, 4, 2, 2, 1, True, 1.0, torch.float32, 1e-3), + # Nemotron-H proxy: plain MLP + group-limited greedy + ("nemotron_proxy_fp16", 1, 64, 64, 128, 256, 8, 2, 2, 1, True, 1.0, torch.float16, 1e-2), + ] + ) + # fmt: on + def test_nemotron_style( + self, + name, + batch, + seq, + hidden, + moe_ffn, + shared_ffn, + n_exp, + top_k, + n_group, + topk_group, + norm, + scale, + dtype, + atol, + ): + mod = ( + NemotronStyleMoE( + hidden, + moe_ffn, + shared_ffn, + n_exp, + top_k, + n_group, + topk_group, + norm_topk_prob=norm, + routed_scaling_factor=scale, + ) + .eval() + .cuda() + .to(dtype) + ) + # Non-zero correction bias to exercise the bias path + if "bias" in name: + with torch.no_grad(): + mod.e_score_correction_bias.data = ( + torch.randn(n_exp, device="cuda") * 0.1 + ) + x = torch.randn(batch, seq, hidden, dtype=dtype) + self.run_test( + mod, + [x], + rtol=1e-2, + atol=atol, + enable_passes=True, + use_dynamo_tracer=True, + require_full_compilation=True, + ) + + +if __name__ == "__main__": + run_tests()