
[PyTorch][CP] Add THD format support for AllGather-based Context Parallelism#2829

Open
sudhakarsingh27 wants to merge 12 commits into NVIDIA:main from sudhakarsingh27:cp_thd_swa_with_ag

Conversation

@sudhakarsingh27
Collaborator

@sudhakarsingh27 sudhakarsingh27 commented Apr 3, 2026

Description

Add THD (variable-length sequence) format support to AttnFuncWithCPAndKVAllGather. Previously, AllGather-based CP only supported fixed-length formats (bshd/sbhd). THD format packs variable-length sequences into a single [t, h, d] tensor tracked by cu_seqlens, which is needed for workloads with heterogeneous sequence lengths.

The key challenge is that AllGather CP splits Q across 2 steps (one per local chunk), but THD tensors can't be naively sliced like fixed-length formats. This PR uses an offset-based approach: the full Q tensor is passed to the cuDNN kernel each step, with per-step cu_seqlens_q_padded values directing the kernel to read the correct chunk. This avoids tensor slicing entirely and leverages cuDNN's back-padding convention (valid tokens at the beginning of each padded allocation).
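The offset arithmetic can be sketched as follows. This is a plain-Python illustration, not the TE implementation: `chunk_offsets_for_step` is a hypothetical name, and a two-chunks-per-rank layout with valid tokens at the front of each padded allocation (the back-padding convention) is assumed.

```python
def chunk_offsets_for_step(cu_seqlens_q_padded, step):
    """Per-step offsets into the full packed [t, h, d] Q tensor.

    Step 0 reads the first half of each per-sequence padded allocation,
    step 1 the second half; no tensor slicing is needed because the
    offsets point the kernel at the right chunk in place.
    """
    spans = []
    for b in range(len(cu_seqlens_q_padded) - 1):
        start = cu_seqlens_q_padded[b]
        half = (cu_seqlens_q_padded[b + 1] - start) // 2
        chunk_start = start + step * half
        spans.append((chunk_start, chunk_start + half))
    return spans

cu = [0, 8, 20]                         # two sequences, padded allocations of 8 and 12
print(chunk_offsets_for_step(cu, 0))    # [(0, 4), (8, 14)]
print(chunk_offsets_for_step(cu, 1))    # [(4, 8), (14, 20)]
```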

Fixes # (issue)

Type of change

  • New feature (non-breaking change which adds functionality)

Changes

Please list the changes introduced in this PR:

  • Offset-based Q chunking: Per-step cu_seqlens_q_padded selects which chunk the kernel reads from the full Q tensor, instead of slicing Q per step
  • Per-step KV cu_seqlens: Computes visible KV token counts per step for causal masking (chunks 0..chunk_id) and non-causal (all tokens)
  • THD reorder reuse: Reuses the existing reorder_seq_chunks_*_thd helpers (originally for A2A) to reorder all-gathered KV into contiguous per-sequence order
  • max_logit masking fix: Handles non-zero-starting cu_seqlens_q_padded in the valid-token mask (step 1's padded offsets don't start at 0)
  • Test gates: Enables THD+all_gather for FusedAttention tests; skips FlashAttention (no THD padding support)
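The per-step KV visibility rule in the bullets above can be sketched as follows (hypothetical helper name; a uniform per-sequence padded chunk size is assumed for simplicity):

```python
def visible_kv_seqlens(actual_seqlens_kv, padded_chunk_size_kv, chunk_id, causal):
    """Visible KV token count per sequence for one CP step."""
    if not causal:
        return list(actual_seqlens_kv)            # non-causal: all KV tokens visible
    visible_padded = padded_chunk_size_kv * (chunk_id + 1)
    # Clamp: the padded bound can exceed a sequence's real length when the
    # length is not a multiple of 2 * cp_size.
    return [min(n, visible_padded) for n in actual_seqlens_kv]

print(visible_kv_seqlens([10, 7], 4, 1, True))    # [8, 7]
print(visible_kv_seqlens([10, 7], 4, 1, False))   # [10, 7]
```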

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

… cu_seqlens

- Use per-step cu_seqlens_q_padded to select Q chunks instead of tensor slicing
- Use padded cu_seqlens_kv for K/V reordering (ensures divisibility)
- Add cu_seqlens_kv and cu_seqlens_kv_padded to AllGather function signature
- Compute per-step Q and KV cu_seqlens correctly from actual seqlens
- Support non-causal attention (all KV visible)
- Zero-initialize out/dq for THD to avoid garbage in padding regions
- Save per-step cu_seqlens in ctx for backward (avoid recomputation)

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Remove skip gates that blocked THD format with all_gather CP comm type.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…seqlens_q_padded

The interleaved valid mask computation assumed cu_seqlens_q_padded starts
at 0. With the CP offset-based approach, cu_seqlens_q_padded can start at
a non-zero offset, causing a size mismatch. Use absolute positions from
cu_seqlens_q_padded to build the valid mask instead.
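The fix described above can be sketched as follows (`valid_token_mask` is an illustrative name and the shapes are assumed; positions are taken relative to the first padded offset, which is non-zero for step 1):

```python
def valid_token_mask(cu_seqlens_q_padded, actual_seqlens, total_q):
    """Mark valid (non-padding) token positions in a step-local tensor."""
    valid = [False] * total_q
    base = cu_seqlens_q_padded[0]        # may be non-zero under the CP offset scheme
    for b, n_valid in enumerate(actual_seqlens):
        start = cu_seqlens_q_padded[b] - base
        valid[start : start + n_valid] = [True] * n_valid
    return valid

mask = valid_token_mask([16, 24, 36], [6, 9], 20)
print(sum(mask))   # 15 valid tokens out of 20
```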

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
if qkv_format == "thd":
    # [cp*t, h, d] -> reorder to contiguous per-sequence order -> [t_full, h, d]
    chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device)
    k_ag = reorder_seq_chunks_after_a2a_before_attn_thd(
Collaborator Author

@sudhakarsingh27 sudhakarsingh27 Apr 3, 2026


The reorder_seq_chunks_after_a2a_before_attn_thd helper and its related counterpart are no longer "a2a"-specific; rename them to something like dualchunk_to_contiguous_order_thd and contiguous_to_dualchunk_order_thd.

@sudhakarsingh27 sudhakarsingh27 changed the title Cp thd swa with ag [PyTorch][CP] Add THD format support for AllGather-based Context Parallelism Apr 13, 2026
@sudhakarsingh27 sudhakarsingh27 marked this pull request as ready for review April 13, 2026 21:53
@greptile-apps
Contributor

greptile-apps bot commented Apr 13, 2026

Greptile Summary

This PR adds THD (variable-length sequence) format support to AttnFuncWithCPAndKVAllGather using an offset-based chunking strategy: instead of slicing Q per step, the full Q tensor is passed to the cuDNN kernel with per-step cu_seqlens_q_padded offsets directing each kernel call to the correct chunk. It also wires cu_seqlens_kv and cu_seqlens_kv_padded through the forward/backward signatures (removing the old args.pop() workaround) and fixes max_logit masking to handle non-zero-starting padded offsets.

  • P1 (backward dK/dV): The new THD backward uses dk.add_(dk_per_step[i-1]) and dv.add_(dv_per_step[i-1]), relying on the cuDNN kernel to zero non-valid positions. The A2A path for THD uses tex.thd_grad_correction specifically to handle valid-vs-padding boundaries — the inconsistency raises doubt about whether a plain add_() is safe here.
  • P2 (performance): The max-logit masking in fused_attn.py was refactored to an element-wise Python loop calling .item() on GPU tensors, introducing 2×B CUDA synchronizations per forward pass where the previous approach was fully vectorized.

Confidence Score: 4/5

Mergeable after the dK/dV backward correctness assumption is confirmed; plain add_() diverges from A2A path's explicit thd_grad_correction.

One P1 concern around THD backward dK/dV accumulation relying on undocumented kernel zeroing behavior — if wrong, gradients are silently corrupted. Remaining findings are P2.

context_parallel.py (backward dK/dV around line 3436) and fused_attn.py (max_logit loop around line 686)

Important Files Changed

  • transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py: Core THD+AllGather CP implementation; backward dK/dV accumulation diverges from the A2A path's tex.thd_grad_correction — correctness depends on unverified cuDNN behavior; also has a debug comment and a dead variable.
  • transformer_engine/pytorch/cpp_extensions/fused_attn.py: Fixes max_logit masking for non-zero cu_seqlens_q_padded offsets correctly, but replaces vectorized logic with a per-element loop introducing 2×B GPU synchronizations.
  • tests/pytorch/attention/test_attention_with_cp.py: Removes the THD+all_gather skip for FusedAttention tests and updates the FlashAttention skip message; a clean gating change.

Comments Outside Diff (1)

  1. transformer_engine/pytorch/cpp_extensions/fused_attn.py, line 686-692 (link)

    P2 Per-element .item() calls introduce 2×B CUDA synchronizations

    The rewrite replaces a fully-vectorized mask construction with a Python loop calling .item() on GPU tensors twice per sequence. Each .item() forces a host–device sync. A drop-in fix reducing syncs to 2:

    tq = max_tensor.shape[0]
    valid = torch.zeros(tq, dtype=torch.bool, device=max_tensor.device)
    starts = cu_seqlens_q_padded[:-1].tolist()   # single sync
    n_valids = actual_seqlens.tolist()            # single sync
    for b_idx in range(b):
        valid[starts[b_idx] : starts[b_idx] + int(n_valids[b_idx])] = True

Reviews (1): Last reviewed commit: "Merge branch 'cp_thd_swa_with_ag' of git..."

Comment on lines +3436 to +3440
# dK/dV: add full tensor (kernel zeros non-valid positions)
if i > 1:
    flash_attn_streams[i - 1].wait_event(dkv_update_done)
dk.add_(dk_per_step[i - 1])
dv.add_(dv_per_step[i - 1])
Contributor


P1 THD backward dK/dV relies on unverified cuDNN zeroing behavior

The comment says "kernel zeros non-valid positions", but this assumption is not documented in the cuDNN/TE spec. The A2A backward for THD uses tex.thd_grad_correction(dk, dk_, cu_seqlens_kv_padded, ...) specifically to handle the valid/padding boundary — a plain add_() was not considered sufficient there. If fused_attn_bwd leaves positions beyond cu_seqlens_kv_per_step[i] uninitialised in its output, both steps contribute garbage at non-overlapping KV ranges, which propagates through reduce_scatter_along_first_dim into the final dK/dV.

Before merging, either confirm (and document) that NVTE_F16_arbitrary_seqlen zeros non-valid dK/dV entries, or add explicit zeroing/use tex.thd_grad_correction if applicable to the contiguous-KV layout.
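One defensive variant of the explicit-zeroing option can be sketched as follows. This is a hypothetical helper, not the actual TE code, with tensors modeled as per-token lists for brevity: padding positions beyond each step's visible KV range are masked out before accumulating, so the result does not depend on whether the kernel initializes them.

```python
def accumulate_dkv(dk, dk_step, cu_seqlens_kv_step, cu_seqlens_kv_padded):
    """Accumulate one step's dK (or dV) into dk, zeroing non-valid positions.

    cu_seqlens_kv_step gives cumulative valid-token counts for this step;
    cu_seqlens_kv_padded gives where each sequence's allocation starts.
    """
    t = len(dk_step)
    valid = [False] * t
    for b in range(len(cu_seqlens_kv_step) - 1):
        start = cu_seqlens_kv_padded[b]
        n = cu_seqlens_kv_step[b + 1] - cu_seqlens_kv_step[b]
        valid[start : start + n] = [True] * n
    return [d + (g if v else 0.0) for d, g, v in zip(dk, dk_step, valid)]

# Two sequences with 3 and 2 visible tokens in padded allocations of 4 each:
# whatever the kernel left at the padded slots (indices 3, 6, 7) is masked out.
dk = accumulate_dkv([0.0] * 8, [1.0] * 8, [0, 3, 5], [0, 4, 8])
print(dk)   # [1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0]
```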

Comment on lines +3007 to +3011
# [AG+THD] Is this needed?
visible_actual = [
    torch.minimum(actual_seqlens_kv, visible_padded_split)
    for visible_padded_split in visible_padded
]
Contributor


P2 Unresolved development comment left in production code

# [AG+THD] Is this needed? reads like an open question from a debug session. The torch.minimum clamp is required: for sequences whose length is not a multiple of 2 * cp_size, padded_chunk_sizes_kv * (chunk_id + 1) can exceed actual_seqlens_kv[b], causing cu_seqlens_kv passed to the kernel to count padding as valid tokens. The comment should be resolved or removed.
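A minimal numeric illustration of why the clamp is required (assumed values):

```python
actual_seqlen_kv = 10            # real length, not a multiple of 2 * cp_size
cp_size = 2
num_chunks = 2 * cp_size
padded_chunk_size_kv = -(-actual_seqlen_kv // num_chunks)   # ceil(10 / 4) = 3
chunk_id = num_chunks - 1                                   # last step sees all chunks
visible_padded = padded_chunk_size_kv * (chunk_id + 1)      # 3 * 4 = 12 > 10
visible_actual = min(actual_seqlen_kv, visible_padded)      # clamp back to 10
print(visible_padded, visible_actual)   # 12 10
```

Without the `min`, the kernel would be told 12 KV tokens are valid and would count 2 padding tokens as real ones.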

if ctx.qkv_format == "thd":
    cu_seqlens_kv_padded = ctx.cu_seqlens_kv_padded
    thd_cu_seqlens_q_per_step = ctx.thd_cu_seqlens_q_per_step
    cu_seqlens_q_padded_rank = cu_seqlens_q_padded * 2
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Dead variable in backward pass

cu_seqlens_q_padded_rank is computed here but never read in the backward. The padded offsets are loaded from ctx.thd_cu_seqlens_q_padded_per_step a few lines later. This line can be removed.
