Add doc_packed_attn: AutoParallel-shardable document-packed varlen attention#508
Open
fmassa wants to merge 1 commit into
Open
Add doc_packed_attn: AutoParallel-shardable document-packed varlen attention#508fmassa wants to merge 1 commit into
fmassa wants to merge 1 commit into
Conversation
…tention
Summary
Adds `doc_packed_attn` — a document-packed variable-length attention op
whose tensor shapes survive AutoParallel sharding without ragged-per-rank
shape gymnastics. Wraps `torch.nn.attention.varlen.varlen_attn` with a
`[B, S, H, D]` / `[B, MAX_DOCS+1]` API so the leading dim is a clean
`Shard(0)` for DP and the heads dim is a clean `Shard(2)` for TP.
The packed-cu_seq_q convention: each batch element's row stores its real
document boundaries followed by repeated `S` sentinels, so trailing entries
become zero-length docs that the flash kernel skips. This keeps the per-rank
shape uniform across all ranks even when batches carry different document
counts, avoiding the ragged shape that DTensor placements can't represent.
Review order
1. `autoparallel/ops.py` — two custom ops (`doc_packed_attn_op`,
`doc_packed_attn_backward_op`), the autograd glue, and the Python helper
`doc_packed_attn` that end users call. The backward saves
`(out, lse, rng_state)` from the forward and dispatches directly to
`torch_attn::_varlen_attn_backward` — no forward recompute.
2. `autoparallel/shardings/propagation_rules.py` — two strategies built with
`expand_to_full_mesh_op_strategy`, one per op. Each enumerates Replicate /
Shard-batch / Shard-heads per mesh dim. Both ops see the same
`[B, S, H, D]` layout so the strategies are near-mirrors.
3. `tests/test_ops.py::TestDocPackedAttn` — 15 parity tests covering B=1,
B>1 uniform, B>1 mixed-doc-count (zero-length padding), causal /
full / sliding-window, plus a profile-based assertion that the backward
doesn't recompute the forward.
4. `tests/test_ops.py::TestDocPackedAttnShardingStrategy` — 3 AP integration
tests confirming the solver picks Shard(0) on a 1D dp mesh and
(Shard(0), Shard(2)) on a 2D dp×tp mesh.
Phase 1 scope
- Self-attention only (`cu_seq_k == cu_seq_q`, `max_q == max_k == S`).
- No context parallelism (S dim is never sharded).
- `enable_gqa` is exposed via the kernel's native broadcast — no manual
`repeat_interleave`.
- `max_q` is always passed as `S` (the trivial upper bound) to keep the
graph sync-free; the kernel uses it only for workspace sizing.
- [ ] `pytest tests/test_ops.py -v` — 36 passed.
- [ ] `pytest tests/test_api.py tests/test_dtensor.py tests/test_correctness.py` —
31 passed (no regressions).
- [ ] Manual smoke: `AutoParallel` traces a model that calls
`doc_packed_attn`, optimizer picks DP+TP placements, `apply_placement`
succeeds end-to-end.
Authored with Claude.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Adds
doc_packed_attn— a document-packed variable-length attention op whose tensor shapes survive AutoParallel sharding without ragged-per-rank shape gymnastics. Wrapstorch.nn.attention.varlen.varlen_attnwith a[B, S, H, D]/[B, MAX_DOCS+1]API so the leading dim is a cleanShard(0)for DP and the heads dim is a cleanShard(2)for TP.The packed-cu_seq_q convention: each batch element's row stores its real document boundaries followed by repeated
Ssentinels, so trailing entries become zero-length docs that the flash kernel skips. This keeps the per-rank shape uniform across all ranks even when batches carry different document counts, avoiding the ragged shape that DTensor placements can't represent.Review order
autoparallel/ops.py— two custom ops (doc_packed_attn_op,doc_packed_attn_backward_op), the autograd glue, and the Python helperdoc_packed_attnthat end users call. The backward saves(out, lse, rng_state)from the forward and dispatches directly totorch_attn::_varlen_attn_backward— no forward recompute.autoparallel/shardings/propagation_rules.py— two strategies built withexpand_to_full_mesh_op_strategy, one per op. Each enumerates Replicate / Shard-batch / Shard-heads per mesh dim. Both ops see the same[B, S, H, D]layout so the strategies are near-mirrors.tests/test_ops.py::TestDocPackedAttn— 15 parity tests covering B=1, B>1 uniform, B>1 mixed-doc-count (zero-length padding), causal / full / sliding-window, plus a profile-based assertion that the backward doesn't recompute the forward.tests/test_ops.py::TestDocPackedAttnShardingStrategy— 3 AP integration tests confirming the solver picks Shard(0) on a 1D dp mesh and (Shard(0), Shard(2)) on a 2D dp×tp mesh.Phase 1 scope
cu_seq_k == cu_seq_q,max_q == max_k == S).enable_gqais exposed via the kernel's native broadcast — no manualrepeat_interleave.max_qis always passed asS(the trivial upper bound) to keep the graph sync-free; the kernel uses it only for workspace sizing.Test plan
pytest tests/test_ops.py -v— 36 passed.pytest tests/test_api.py tests/test_dtensor.py tests/test_correctness.py— 31 passed (no regressions).AutoParalleltraces a model that callsdoc_packed_attn, optimizer picks DP+TP placements,apply_placementsucceeds end-to-end.Authored with Claude.