fix: WanAttentionBlock gradient checkpointing + FSDP2 dtype mismatch#354
Open
Jack47 wants to merge 1 commit intoWan-Video:mainfrom
Open
fix: WanAttentionBlock gradient checkpointing + FSDP2 dtype mismatch#354Jack47 wants to merge 1 commit intoWan-Video:mainfrom
Jack47 wants to merge 1 commit intoWan-Video:mainfrom
Conversation
… dtype mismatch
Replace torch.amp.autocast('cuda', dtype=torch.float32) context managers
inside WanAttentionBlock.forward() with explicit .float() upcasts and
.to(dtype=compute_dtype) downcasts.
The autocast approach is incompatible with gradient checkpointing under
FSDP2 mixed precision: during recompute, the outer bf16 autocast (provided
by FSDP2) is not replayed, causing saved tensors (bf16) to diverge from
recomputed tensors (fp32), triggering CheckpointError.
The explicit cast approach is autocast-agnostic and produces numerically
identical results (bitwise identical in fp32, <1e-2 tolerance in bf16).
Also adds gradient checkpointing support to WanModel via a _gc_autocast()
context_fn that infers the autocast dtype from x.dtype rather than
hardcoding bf16.
Fixes Wan-Video#353
d054eb3 to
218f206
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
CheckpointErrorwhen usingtorch.utils.checkpointwith FSDP2 mixed precision onWanAttentionBlocktorch.amp.autocast('cuda', dtype=torch.float32)inside the block with explicit.float()+.to(dtype=compute_dtype), making it autocast-agnosticWanModelvia_gc_autocast()context_fnProblem
Under FSDP2 with
MixedPrecisionPolicy(param_dtype=bf16), gradient checkpointing recompute runs without the outer bf16 autocast that FSDP2 provides during forward. The inner fp32 autocast inWanAttentionBlock.forward()then produces fp32 tensors during recompute vs bf16 during forward, triggering:Fix
Replace autocast context managers with explicit dtype casts:
The fix is numerically equivalent: bitwise identical in fp32, <1e-2 tolerance under bf16 autocast.
Impact
Enables gradient checkpointing for Wan2.2 training, reducing peak memory by 63-74%:
Usage
Fixes #353