Skip to content

Add doc_packed_attn: AutoParallel-shardable document-packed varlen attention#508

Open
fmassa wants to merge 1 commit into
mainfrom
fmassa/packed_doc_attention
Open

Add doc_packed_attn: AutoParallel-shardable document-packed varlen attention#508
fmassa wants to merge 1 commit into
mainfrom
fmassa/packed_doc_attention

Conversation

@fmassa

@fmassa fmassa commented Jun 17, 2026

Copy link
Copy Markdown
Contributor

Summary

Adds doc_packed_attn — a document-packed variable-length attention op whose tensor shapes survive AutoParallel sharding without ragged-per-rank shape gymnastics. Wraps torch.nn.attention.varlen.varlen_attn with a [B, S, H, D] / [B, MAX_DOCS+1] API so the leading dim is a clean Shard(0) for DP and the heads dim is a clean Shard(2) for TP.

The packed-cu_seq_q convention: each batch element's row stores its real document boundaries followed by repeated S sentinels, so trailing entries become zero-length docs that the flash kernel skips. This keeps the per-rank shape uniform across all ranks even when batches carry different document counts, avoiding the ragged shape that DTensor placements can't represent.

Review order

  1. autoparallel/ops.py — two custom ops (doc_packed_attn_op, doc_packed_attn_backward_op), the autograd glue, and the Python helper doc_packed_attn that end users call. The backward saves (out, lse, rng_state) from the forward and dispatches directly to torch_attn::_varlen_attn_backward — no forward recompute.
  2. autoparallel/shardings/propagation_rules.py — two strategies built with expand_to_full_mesh_op_strategy, one per op. Each enumerates Replicate / Shard-batch / Shard-heads per mesh dim. Both ops see the same [B, S, H, D] layout so the strategies are near-mirrors.
  3. tests/test_ops.py::TestDocPackedAttn — 15 parity tests covering B=1, B>1 uniform, B>1 mixed-doc-count (zero-length padding), causal / full / sliding-window, plus a profile-based assertion that the backward doesn't recompute the forward.
  4. tests/test_ops.py::TestDocPackedAttnShardingStrategy — 3 AP integration tests confirming the solver picks Shard(0) on a 1D dp mesh and (Shard(0), Shard(2)) on a 2D dp×tp mesh.

Phase 1 scope

  • Self-attention only (cu_seq_k == cu_seq_q, max_q == max_k == S).
  • No context parallelism (S dim is never sharded).
  • enable_gqa is exposed via the kernel's native broadcast — no manual repeat_interleave.
  • max_q is always passed as S (the trivial upper bound) to keep the graph sync-free; the kernel uses it only for workspace sizing.

Test plan

  • pytest tests/test_ops.py -v — 36 passed.
  • pytest tests/test_api.py tests/test_dtensor.py tests/test_correctness.py — 31 passed (no regressions).
  • Manual smoke: AutoParallel traces a model that calls doc_packed_attn, optimizer picks DP+TP placements, apply_placement succeeds end-to-end.

Authored with Claude.

…tention

Summary

Adds `doc_packed_attn` — a document-packed variable-length attention op
whose tensor shapes survive AutoParallel sharding without ragged-per-rank
shape gymnastics. Wraps `torch.nn.attention.varlen.varlen_attn` with a
`[B, S, H, D]` / `[B, MAX_DOCS+1]` API so the leading dim is a clean
`Shard(0)` for DP and the heads dim is a clean `Shard(2)` for TP.

The packed-cu_seq_q convention: each batch element's row stores its real
document boundaries followed by repeated `S` sentinels, so trailing entries
become zero-length docs that the flash kernel skips. This keeps the per-rank
shape uniform across all ranks even when batches carry different document
counts, avoiding the ragged shape that DTensor placements can't represent.

Review order

1. `autoparallel/ops.py` — two custom ops (`doc_packed_attn_op`,
   `doc_packed_attn_backward_op`), the autograd glue, and the Python helper
   `doc_packed_attn` that end users call. The backward saves
   `(out, lse, rng_state)` from the forward and dispatches directly to
   `torch_attn::_varlen_attn_backward` — no forward recompute.
2. `autoparallel/shardings/propagation_rules.py` — two strategies built with
   `expand_to_full_mesh_op_strategy`, one per op. Each enumerates Replicate /
   Shard-batch / Shard-heads per mesh dim. Both ops see the same
   `[B, S, H, D]` layout so the strategies are near-mirrors.
3. `tests/test_ops.py::TestDocPackedAttn` — 15 parity tests covering B=1,
   B>1 uniform, B>1 mixed-doc-count (zero-length padding), causal /
   full / sliding-window, plus a profile-based assertion that the backward
   doesn't recompute the forward.
4. `tests/test_ops.py::TestDocPackedAttnShardingStrategy` — 3 AP integration
   tests confirming the solver picks Shard(0) on a 1D dp mesh and
   (Shard(0), Shard(2)) on a 2D dp×tp mesh.

Phase 1 scope

- Self-attention only (`cu_seq_k == cu_seq_q`, `max_q == max_k == S`).
- No context parallelism (S dim is never sharded).
- `enable_gqa` is exposed via the kernel's native broadcast — no manual
  `repeat_interleave`.
- `max_q` is always passed as `S` (the trivial upper bound) to keep the
  graph sync-free; the kernel uses it only for workspace sizing.

- [ ] `pytest tests/test_ops.py -v` — 36 passed.
- [ ] `pytest tests/test_api.py tests/test_dtensor.py tests/test_correctness.py` —
      31 passed (no regressions).
- [ ] Manual smoke: `AutoParallel` traces a model that calls
      `doc_packed_attn`, optimizer picks DP+TP placements, `apply_placement`
      succeeds end-to-end.

Authored with Claude.
@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Meta Open Source bot. label Jun 17, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant