
fix: WanAttentionBlock gradient checkpointing + FSDP2 dtype mismatch #354

Open

Jack47 wants to merge 1 commit into Wan-Video:main from Jack47:fix/gc-fsdp2-dtype-mismatch

Conversation


Jack47 commented on Apr 21, 2026

Summary

  • Fix CheckpointError when using torch.utils.checkpoint with FSDP2 mixed precision on WanAttentionBlock
  • Replace torch.amp.autocast('cuda', dtype=torch.float32) inside the block with explicit .float() + .to(dtype=compute_dtype), making it autocast-agnostic
  • Add gradient checkpointing support to WanModel via _gc_autocast() context_fn

Problem

Under FSDP2 with MixedPrecisionPolicy(param_dtype=bf16), gradient checkpointing recompute runs without the outer bf16 autocast that FSDP2 provides during forward. The inner fp32 autocast in WanAttentionBlock.forward() then produces fp32 tensors during recompute vs bf16 during forward, triggering:

torch.utils.checkpoint.CheckpointError: Checkpoint state_dict has mismatched tensor dtypes
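For context, a minimal sketch of the kind of setup that hits this (assuming PyTorch >= 2.6, where fully_shard and MixedPrecisionPolicy are exported from torch.distributed.fsdp; on older 2.x they live under torch.distributed._composable.fsdp). The wiring below is illustrative, not this repo's actual training script:

```python
# Illustrative repro sketch; assumes a torchrun launch with an initialized process group.
import torch
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy

model = WanModel(...)                      # as in the Usage section below
model.gradient_checkpointing = True        # each block goes through torch.utils.checkpoint

mp = MixedPrecisionPolicy(param_dtype=torch.bfloat16)
for block in model.blocks:                 # shard per WanAttentionBlock
    fully_shard(block, mp_policy=mp)
fully_shard(model, mp_policy=mp)

# Forward: FSDP2 runs the blocks in bf16, so the activations saved by checkpointing are bf16.
# Backward recompute: that outer bf16 context is not replayed, the block's inner fp32 autocast
# takes over, and the recomputed tensors come out fp32 -> CheckpointError.
```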

Fix

Replace autocast context managers with explicit dtype casts:

```python
# Before (autocast-dependent, breaks under GC recompute):
with torch.amp.autocast('cuda', dtype=torch.float32):
    e = (self.modulation.unsqueeze(0) + e).chunk(6, dim=2)

# After (autocast-agnostic, works with GC):
compute_dtype = x.dtype
e = (self.modulation.float().unsqueeze(0) + e).chunk(6, dim=2)
# ...downstream fp32 results are later cast back with .to(dtype=compute_dtype)
```

The fix is numerically equivalent: bitwise identical in fp32, <1e-2 tolerance under bf16 autocast.

Impact

Enables gradient checkpointing for Wan2.2 training, reducing peak memory by 63-74%:

| Config (5B)       | GC off | GC on   | Reduction |
|-------------------|--------|---------|-----------|
| 37f, mbs=2        | 109 GB | 40.7 GB | 63%       |
| 121f (5s), mbs=2  | OOM    | 51.1 GB |           |
| 121f (5s), mbs=4  | OOM    | 69.8 GB |           |

Usage

```python
model = WanModel(...)
model.gradient_checkpointing = True  # enable GC
```
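For reference, a minimal sketch of how the checkpointed block call with the `_gc_autocast()` context_fn could be wired inside `WanModel.forward` (the helper name comes from this PR; argument names and surrounding code are illustrative):

```python
import torch
from torch.utils.checkpoint import checkpoint

def _gc_autocast(dtype):
    # context_fn for torch.utils.checkpoint (use_reentrant=False only): a callable
    # returning a pair of context managers, applied to the original forward and to
    # the recompute respectively. Re-entering the same autocast in both keeps saved
    # and recomputed dtypes consistent; dtype is inferred from the activations
    # rather than hardcoding bf16.
    def make_ctx():
        return (torch.amp.autocast('cuda', dtype=dtype),
                torch.amp.autocast('cuda', dtype=dtype))
    return make_ctx

# Inside WanModel.forward; block_args stands in for e, seq_lens, grid_sizes, freqs, ...
for block in self.blocks:
    if self.gradient_checkpointing and self.training:
        x = checkpoint(block, x, *block_args,
                       use_reentrant=False, context_fn=_gc_autocast(x.dtype))
    else:
        x = block(x, *block_args)
```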

Fixes #353

Commit: fix: WanAttentionBlock gradient checkpointing + FSDP2 dtype mismatch

Replace torch.amp.autocast('cuda', dtype=torch.float32) context managers
inside WanAttentionBlock.forward() with explicit .float() upcasts and
.to(dtype=compute_dtype) downcasts.

The autocast approach is incompatible with gradient checkpointing under
FSDP2 mixed precision: during recompute, the outer bf16 autocast (provided
by FSDP2) is not replayed, causing saved tensors (bf16) to diverge from
recomputed tensors (fp32), triggering CheckpointError.

The explicit cast approach is autocast-agnostic and produces numerically
identical results (bitwise identical in fp32, <1e-2 tolerance in bf16).

Also adds gradient checkpointing support to WanModel via a _gc_autocast()
context_fn that infers the autocast dtype from x.dtype rather than
hardcoding bf16.

Fixes Wan-Video#353
Jack47 force-pushed the fix/gc-fsdp2-dtype-mismatch branch from d054eb3 to 218f206 on April 21, 2026 07:33

ivanpoh11 left a comment


hh



Development

Successfully merging this pull request may close these issues.

WanAttentionBlock: gradient checkpointing + FSDP2 mixed precision causes CheckpointError due to dtype mismatch
