[RFE]: Graph-capturable NCCL EP API for CUDA Graph / Piecewise CUDA Graph integration #2104

@nkumaraws

Description

Goal: Enable NCCL EP dispatch/combine operations to work inside CUDA graphs (or at piecewise graph split points), eliminating the host-side C API calls that break cudaStreamBeginCapture.

Who benefits: All NCCL EP users on vLLM, SGLang, and other inference frameworks that depend on CUDA graphs for decode throughput. This is the single largest performance bottleneck for NCCL EP in production MoE inference.

Architecture/infrastructure: Any NCCL EP deployment. Validated gap on 2× p5en.48xlarge (16× H200) over EFA, but applies equally to InfiniBand deployments.

How it improves workflows: CUDA graphs are the single largest optimization opportunity — estimated 1.5-3× throughput improvement and a prerequisite for TBO/DBO overlap. Without them, every decode step pays full kernel launch overhead across 48+ MoE layers. The compound performance gap (no CUDA graphs + no DBO/TBO) is estimated at 4-6×.

Priority: High — this is the number 1 performance bottleneck for NCCL EP in production inference.


Summary

NCCL EP's Low-Latency (LL) mode cannot be used inside CUDA graphs because the API requires host-side C function calls (ncclEpCreateHandle, ncclEpDispatchTensorHandles, ncclEpCombineTensorHandles, ncclEpComplete, ncclEpHandleDestroy, ncclEpFreeDescriptors) during each forward pass. These calls break cudaStreamBeginCapture and prevent both full CUDA graph capture and piecewise CUDA graph capture in vLLM and SGLang.

We request either a graph-capturable API variant or documentation of a recommended workaround pattern.

Impact

Current Performance (eager mode, no CUDA graphs)

| Model | vLLM tok/s | SGLang tok/s | TPGS |
|---|---|---|---|
| Qwen3-30B | 2,061 | 670 | 128.8 (vLLM) |
| Qwen3-235B | 1,180 | 217 | 73.8 (vLLM) |
| DeepSeek-R1 | 1,085 | 121 | 67.8 (vLLM) |

Published SOTA (with CUDA graphs + DBO + EPLB)

| System | TPGS | Gap to Ours |
|---|---|---|
| vLLM Wide-EP (H200 + IB, DeepEP) | 2,200 | 17× |
| SGLang + DeepEP (H100 + IB) | 2,785 | 22× |

Gap Decomposition

| Factor | Estimated Impact | Addressable? |
|---|---|---|
| No CUDA graphs | 2-3× | Requires this RFE |
| No DBO/TBO | 1.5-2× | Depends on CUDA graphs |
| No EPLB | 1.1-1.3× | Independent |
| EP=16 vs EP=72+ | 1.3-1.5× | Hardware-limited |
| EFA vs IB latency | 1.1-1.2× | Intrinsic |

CUDA graphs are the single largest factor AND a prerequisite for DBO/TBO overlap optimizations. Without them, the compound gap is ~4-6×.

Root Cause: Host-Side API Design

Operations That Break CUDA Graph Capture

Each MoE layer's forward pass requires these host-side calls:

```python
# Pseudocode for the per-layer call sequence (ctypes/FFI into the C API):

# 1. Create handle (host-side allocation + NCCL registration)
handle = ncclEpCreateHandle(comm, num_experts, hidden, ...)

# 2. Dispatch (host-side call that enqueues GPU work)
ncclEpDispatchTensorHandles(handle, input_tensor, topk_ids, ...)

# 3. [Expert compute happens on GPU - this part IS graph-capturable]

# 4. Combine (host-side call that enqueues GPU work)
ncclEpCombineTensorHandles(handle, expert_output, ...)

# 5. Complete (host-side sync)
ncclEpComplete(handle)

# 6. Cleanup (host-side)
ncclEpHandleDestroy(handle)
ncclEpFreeDescriptors(handle)
```

Steps 1, 2, 4, 5, 6 are all host-side C function calls via ctypes/FFI. cudaStreamBeginCapture requires all operations to be GPU-side stream operations. Any host-side call during capture causes either a capture failure or undefined behavior.

Additional Host Sync in Dispatch

Beyond the API calls themselves, the dispatch path requires reading recv_expert_count on the host to determine per-expert token counts for the subsequent GEMM. This involves:

  • torch.cuda.synchronize() — device-wide barrier
  • cudaMemcpy D2H — synchronous copy of recv_counter

This is inherent to the "Standard" activation format where the CPU needs expert counts to construct the 3D-to-2D gather indices.
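To make the dependency concrete, here is a minimal Python sketch (function and variable names are hypothetical, not the actual NCCL EP code) of the CPU-side index construction that forces recv_expert_count onto the host:

```python
# Minimal sketch (hypothetical names): why the Standard activation format
# forces a D2H copy. The 3D recv buffer is [num_experts, capacity, hidden];
# the CPU needs per-expert counts to build the flat 2D gather indices the
# subsequent GEMM consumes.

def build_gather_indices(recv_expert_count, max_tokens_per_expert):
    """recv_expert_count must live on the host -- hence the sync + D2H copy."""
    indices = []
    for expert, count in enumerate(recv_expert_count):
        base = expert * max_tokens_per_expert
        # Only the first `count` slots of each expert's row hold valid tokens.
        indices.extend(range(base, base + count))
    return indices

# Example: 3 experts, capacity 4, counts read back from the device.
counts = [2, 0, 3]
idx = build_gather_indices(counts, max_tokens_per_expert=4)
# Rows 0-1 come from expert 0, rows 8-10 from expert 2; expert 1 is empty.
```

Because the loop bounds depend on the counts' values, this construction cannot happen on the host without first synchronizing and copying the counter back.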

How DeepEP Solves This

DeepEP LL mode is fully CUDA-graph-capturable. The key design differences:

| Aspect | DeepEP LL | NCCL EP LL |
|---|---|---|
| Dispatch/Combine | Pure GPU kernels (NVSHMEM put/signal/wait) | Host-side C API |
| Recv count query | No host sync — fixed max_dispatch_tokens_per_rank upper bound | synchronize() + memcpy D2H |
| Handle creation | Not needed per forward pass — routing encoded in kernel args | ncclEpCreateHandle() per forward pass |
| Buffer lifecycle | Pre-allocated to max capacity, "restore on replay" | Dynamic per handle |
| Graph capture | Supported (documented) | Not supported |

DeepEP's key insight: accept worst-case buffer sizing and never read expert counts on the host. The recv_expert_count tensor stays on device and is consumed directly by masked GEMM kernels that skip empty expert slots.
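A toy Python model of that pattern (illustrative only; in reality this is a GPU kernel reading the on-device count tensor): the buffer is sized to the worst case, and the compute loop itself masks out empty slots, so no index list is ever built on the host:

```python
# Toy model of the device-side pattern: worst-case-sized buffers plus a
# masked expert loop. Plain Python stands in for the masked GEMM kernel
# that consumes recv_expert_count directly on device.

def masked_expert_compute(recv_buffer, recv_expert_count, weight):
    """recv_buffer: [num_experts][capacity] token values (padded).
    Slots past each expert's count hold garbage and are never touched."""
    out = [[0.0] * len(row) for row in recv_buffer]
    for e, count in enumerate(recv_expert_count):
        if count == 0:          # skip empty expert slots entirely
            continue
        for t in range(count):  # only the valid prefix is computed
            out[e][t] = recv_buffer[e][t] * weight
    return out

buf = [[1.0, 2.0, 9e9], [9e9] * 3, [3.0, 9e9, 9e9]]  # 9e9 = padding garbage
res = masked_expert_compute(buf, recv_expert_count=[2, 0, 1], weight=2.0)
```

The trade-off is explicit: worst-case memory for the recv buffer in exchange for a fully device-side, capture-safe data flow.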

Proposed Solutions

Option A: CUDA Kernel API Variant (Recommended)

Expose dispatch and combine as CUDA kernels (or cudaLaunchKernel-compatible entry points) that can be captured by CUDA graphs:

```c
// New API — graph-capturable variants
ncclResult_t ncclEpDispatchAsync(
    ncclEpHandle_t handle,
    void* input, void* output,
    int* topk_ids, float* topk_weights,
    int num_tokens, int max_tokens_per_rank,
    cudaStream_t stream);

ncclResult_t ncclEpCombineAsync(
    ncclEpHandle_t handle,
    void* input, void* output,
    int* topk_ids, float* topk_weights,
    int num_tokens, int max_tokens_per_rank,
    cudaStream_t stream);
```

Handle creation and destruction remain host-side but are moved outside the graph-captured region (called once during warmup, reused across replays).
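The intended lifecycle could look like the following sketch, with Python stubs standing in for the proposed (not yet existing) entry points:

```python
# Sketch of the proposed lifecycle: the handle is created once at warmup,
# and only the stream-ordered dispatch/combine calls land inside the
# captured region. All functions here are stand-ins for the proposed API.

calls_inside_capture = []

def ncclEpCreateHandle():                  # host-side: warmup only
    return {"id": 0}

def ncclEpDispatchAsync(handle, stream):   # graph-capturable
    calls_inside_capture.append("dispatch")

def ncclEpCombineAsync(handle, stream):    # graph-capturable
    calls_inside_capture.append("combine")

# --- warmup (outside graph capture) ---
handle = ncclEpCreateHandle()

# --- captured region / replays: only stream-ordered work, handle reused ---
for _ in range(3):
    ncclEpDispatchAsync(handle, stream=None)
    ncclEpCombineAsync(handle, stream=None)
```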

Option B: Pre-Allocated Handle Pool

Allow creating a pool of handles during initialization that can be reused across forward passes without per-step host allocation:

```c
// During init (outside graph capture)
ncclEpHandlePool_t pool;
ncclEpCreateHandlePool(&pool, comm, max_handles, max_tokens, ...);

// During forward (inside graph capture)
ncclEpHandle_t h = ncclEpHandlePoolGet(pool, step_index);
ncclEpDispatchFromPool(h, ...);  // GPU-only
ncclEpCombineFromPool(h, ...);   // GPU-only
// No explicit destroy needed — pool manages lifecycle
```
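As a rough illustration of the intended pool semantics (hypothetical; Python in place of the proposed C API):

```python
# Hypothetical handle pool: all allocation happens at init time; get()
# inside the captured region is a pure lookup with no host-side allocation,
# NCCL registration, or free.

class EpHandlePool:
    def __init__(self, max_handles):
        # Pre-allocate every handle up front (outside graph capture).
        self._handles = [{"slot": i} for i in range(max_handles)]

    def get(self, step_index):
        # Constant-time lookup; capture-safe because it performs no
        # allocation -- handles are recycled modulo the pool size.
        return self._handles[step_index % len(self._handles)]

pool = EpHandlePool(max_handles=4)
h0 = pool.get(0)
h4 = pool.get(4)   # wraps around: reuses slot 0, no new allocation
```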

Option C: Piecewise Graph Integration Guide

If a fully graph-capturable API is not feasible, provide an official guide for piecewise CUDA graph integration where:

  • Non-EP model pieces (attention, norms, residuals) are captured as graphs
  • EP dispatch/combine runs eagerly between graph pieces
  • Handle creation is hoisted to a per-batch setup phase

This is what both vLLM's and SGLang's piecewise CUDA graph runners attempt to do. The challenge is that the frameworks' torch.compile tracing also fails on the ctypes FFI calls, even when they are placed at split points. A torch.library.custom_op wrapper that is properly annotated for torch.compile would solve this.

Option D: Document the recv_expert_count Host Sync Workaround

Even with Options A-C, frameworks need guidance on avoiding the recv_expert_count host sync. The recommended pattern (from DeepEP):

  1. Pre-allocate output buffer to max_tokens_per_rank * num_ranks * hidden
  2. Keep recv_expert_count on device (never D2H copy)
  3. Pass recv_expert_count directly to the expert GEMM kernel
  4. GEMM kernel skips computation for expert slots with count=0
  5. Only sync for debug/correctness validation

Environment

  • 2× p5en.48xlarge, 16× NVIDIA H200, AWS EFA 3.2 Tbps/node
  • NCCL from master branch (post-v2.29.3-1) with EP + GIN Device API
  • NCCL EP LL mode with GIN Proxy (Type 2)
  • vLLM 0.18.0 + SGLang 0.5.10
  • CUDA 12.9.1
