[PyTorch][CP] Add THD format support for AllGather-based Context Parallelism#2829
sudhakarsingh27 wants to merge 12 commits into NVIDIA:main from
Conversation
… cu_seqlens
- Use per-step cu_seqlens_q_padded to select Q chunks instead of tensor slicing
- Use padded cu_seqlens_kv for K/V reordering (ensures divisibility)
- Add cu_seqlens_kv and cu_seqlens_kv_padded to the AllGather function signature
- Compute per-step Q and KV cu_seqlens correctly from actual seqlens
- Support non-causal attention (all KV visible)
- Zero-initialize out/dq for THD to avoid garbage in padding regions
- Save per-step cu_seqlens in ctx for backward (avoids recomputation)

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Remove skip gates that blocked THD format with the all_gather CP comm type.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…seqlens_q_padded

The interleaved valid mask computation assumed cu_seqlens_q_padded starts at 0. With the CP offset-based approach, cu_seqlens_q_padded can start at a non-zero offset, causing a size mismatch. Use absolute positions from cu_seqlens_q_padded to build the valid mask instead.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
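A minimal plain-Python sketch of the absolute-position approach described in this commit message (build_valid_mask and its arguments are illustrative names, not TE's actual API):

```python
# Build the valid-token mask from absolute positions in cu_seqlens_q_padded.
# The old code assumed cu_seqlens_q_padded[0] == 0; with the offset-based
# approach, step 1's per-step array starts at a non-zero absolute offset.
def build_valid_mask(cu_seqlens_q, cu_seqlens_q_padded, total_padded):
    mask = [False] * total_padded
    for b in range(len(cu_seqlens_q) - 1):
        seqlen = cu_seqlens_q[b + 1] - cu_seqlens_q[b]  # actual valid length
        start = cu_seqlens_q_padded[b]                  # absolute, may be > 0
        for t in range(start, start + seqlen):
            mask[t] = True
    return mask

# Step 1 reads the second chunk of a sequence padded to 8 tokens:
# 2 valid tokens back-padded into the allocation starting at offset 4.
mask = build_valid_mask([0, 2], [4, 8], 8)
# -> [False, False, False, False, True, True, False, False]
```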
for more information, see https://pre-commit.ci
if qkv_format == "thd":
    # [cp*t, h, d] -> reorder to contiguous per-sequence order -> [t_full, h, d]
    chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device)
    k_ag = reorder_seq_chunks_after_a2a_before_attn_thd(
This reorder_seq_chunks_after_a2a_before_attn_thd and the other related method are no longer "a2a"-specific; rename them to something like dualchunk_to_contiguous_order_thd and contiguous_to_dualchunk_order_thd.
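For context, the permutation these helpers compute can be sketched as follows (a hypothetical plain-Python reconstruction of the idea, not TE's actual implementation; the function name matches the suggested rename):

```python
# With causal load balancing, rank r holds sequence chunks r and
# (2*cp_size - 1 - r), so the all-gathered tensor stores chunks in the
# "dual-chunk" order [0, 2cp-1, 1, 2cp-2, ...]. To restore contiguous
# sequence order we need the position of each chunk in that layout.
def dualchunk_to_contiguous_order(cp_size):
    gathered = []  # chunk ids in gathered (rank) order
    for r in range(cp_size):
        gathered.extend([r, 2 * cp_size - 1 - r])
    # index of chunk c within the gathered layout, for c = 0..2*cp_size-1
    return [gathered.index(c) for c in range(2 * cp_size)]

order = dualchunk_to_contiguous_order(2)
# gathered layout is [0, 3, 1, 2]; contiguous order reads indices [0, 2, 3, 1]
```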
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…formerEngine into cp_thd_swa_with_ag
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Greptile Summary: This PR adds THD (variable-length sequence) format support to AttnFuncWithCPAndKVAllGather.
Confidence Score: 4/5

Mergeable after the dK/dV backward correctness assumption is confirmed; a plain add_() diverges from the A2A path's explicit thd_grad_correction. One P1 concern: the THD backward dK/dV accumulation relies on undocumented kernel zeroing behavior; if that assumption is wrong, gradients are silently corrupted. Remaining findings are P2.

Important files changed: context_parallel.py (backward dK/dV around line 3436) and fused_attn.py (max_logit loop around line 686).
# dK/dV: add full tensor (kernel zeros non-valid positions)
if i > 1:
    flash_attn_streams[i - 1].wait_event(dkv_update_done)
dk.add_(dk_per_step[i - 1])
dv.add_(dv_per_step[i - 1])
THD backward dK/dV relies on unverified cuDNN zeroing behavior
The comment says "kernel zeros non-valid positions", but this assumption is not documented in the cuDNN/TE spec. The A2A backward for THD uses tex.thd_grad_correction(dk, dk_, cu_seqlens_kv_padded, ...) specifically to handle the valid/padding boundary — a plain add_() was not considered sufficient there. If fused_attn_bwd leaves positions beyond cu_seqlens_kv_per_step[i] uninitialised in its output, both steps contribute garbage at non-overlapping KV ranges, which propagates through reduce_scatter_along_first_dim into the final dK/dV.
Before merging, either confirm (and document) that NVTE_F16_arbitrary_seqlen zeros non-valid dK/dV entries, or add explicit zeroing/use tex.thd_grad_correction if applicable to the contiguous-KV layout.
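If explicit zeroing turns out to be necessary, a minimal sketch of the idea follows (plain Python lists standing in for [t, h, d] tensors; zero_padding_regions is a hypothetical helper, not TE's API or its thd_grad_correction kernel):

```python
# Clear the positions between each sequence's valid end and its padded end
# before accumulating a per-step gradient, so that back-padded regions can
# never contribute garbage to the accumulated dK/dV.
def zero_padding_regions(grad, cu_seqlens, cu_seqlens_padded):
    for b in range(len(cu_seqlens) - 1):
        seqlen = cu_seqlens[b + 1] - cu_seqlens[b]     # actual valid length
        valid_end = cu_seqlens_padded[b] + seqlen      # back-padding: valid tokens first
        for t in range(valid_end, cu_seqlens_padded[b + 1]):
            grad[t] = 0.0
    return grad

dk_step = [1.0] * 8            # two sequences, each padded to 4 tokens
cu_seqlens = [0, 3, 5]         # actual lengths 3 and 2
cu_seqlens_padded = [0, 4, 8]  # padded lengths 4 and 4
zero_padding_regions(dk_step, cu_seqlens, cu_seqlens_padded)
# dk_step -> [1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0]
```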
# [AG+THD] Is this needed?
visible_actual = [
    torch.minimum(actual_seqlens_kv, visible_padded_split)
    for visible_padded_split in visible_padded
]
Unresolved development comment left in production code
# [AG+THD] Is this needed? reads like an open question from a debug session. The torch.minimum clamp is required: for sequences whose length is not a multiple of 2 * cp_size, padded_chunk_sizes_kv * (chunk_id + 1) can exceed actual_seqlens_kv[b], causing cu_seqlens_kv passed to the kernel to count padding as valid tokens. The comment should be resolved or removed.
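A scalar sketch of why the clamp matters (plain Python min standing in for torch.minimum; the numbers are illustrative, not taken from the PR):

```python
# A sequence of actual length 10 is padded to 12 tokens and split into
# 2 * cp_size = 4 chunks of padded size 3. The "visible" KV boundary after
# chunk_id + 1 chunks is padded_chunk_size * (chunk_id + 1), which overshoots
# the actual length on the last chunk unless it is clamped.
actual_seqlen_kv = 10
padded_chunk_size_kv = 3

visible = []
for chunk_id in range(4):
    visible_padded = padded_chunk_size_kv * (chunk_id + 1)  # 3, 6, 9, 12
    visible.append(min(actual_seqlen_kv, visible_padded))   # 3, 6, 9, 10
# Without the clamp, the final step would pass 12 to the kernel,
# counting 2 padding tokens as valid KV.
```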
if ctx.qkv_format == "thd":
    cu_seqlens_kv_padded = ctx.cu_seqlens_kv_padded
    thd_cu_seqlens_q_per_step = ctx.thd_cu_seqlens_q_per_step
    cu_seqlens_q_padded_rank = cu_seqlens_q_padded * 2
Description
Add THD (variable-length sequence) format support to AttnFuncWithCPAndKVAllGather. Previously, AllGather-based CP only supported fixed-length formats (bshd/sbhd). THD format packs variable-length sequences into a single [t, h, d] tensor tracked by cu_seqlens, which is needed for workloads with heterogeneous sequence lengths.

The key challenge is that AllGather CP splits Q across 2 steps (one per local chunk), but THD tensors can't be naively sliced like fixed-length formats. This PR uses an offset-based approach: the full Q tensor is passed to the cuDNN kernel each step, with per-step cu_seqlens_q_padded values directing the kernel to read the correct chunk. This avoids tensor slicing entirely and leverages cuDNN's back-padding convention (valid tokens at the beginning of each padded allocation).

Fixes # (issue)
Type of change
Changes

Please list the changes introduced in this PR:

- Per-step cu_seqlens_q_padded selects which chunk the kernel reads from the full Q tensor, instead of slicing Q per step
- Reuse the reorder_seq_chunks_*_thd helpers (originally for A2A) to reorder all-gathered KV into contiguous per-sequence order
- Use absolute positions from cu_seqlens_q_padded in the valid-token mask (step 1's padded offsets don't start at 0)

Checklist: