Mixed-precision: per-policy param/buffer dtype cast (preserve fp32 buffers) #8066

Open
sfc-gh-truwase wants to merge 4 commits into master from mixed-precision-dtype
@sfc-gh-truwase

Collaborator

Summary

  • Add data_types.param_dtype and data_types.buffer_dtype (both default None), mirroring FSDP MixedPrecisionPolicy.
  • Replace the blanket module.half() / module.bfloat16() in _configure_distributed_model with a targeted cast: parameters go to param_dtype; floating buffers keep their loaded dtype unless buffer_dtype is explicitly set.
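
The targeted cast described in the second bullet can be sketched roughly as follows. This is a minimal illustration, not the code in this PR: the helper name `cast_params_and_buffers` and the toy `ToyRotary` module are hypothetical, and only standard PyTorch calls are used.

```python
import torch
from torch import nn


class ToyRotary(nn.Module):
    """Hypothetical stand-in for a layer that owns a rotary inv_freq buffer."""

    def __init__(self, dim=8, base=10000.0):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        exponents = torch.arange(0, dim, 2, dtype=torch.float32) / dim
        self.register_buffer("inv_freq", 1.0 / (base ** exponents))


def cast_params_and_buffers(module, param_dtype, buffer_dtype=None):
    """Cast floating-point params to param_dtype.

    Buffers are only touched when buffer_dtype is given; otherwise they
    keep whatever dtype they were loaded with.
    """
    for param in module.parameters():
        if param.is_floating_point():
            param.data = param.data.to(param_dtype)
    if buffer_dtype is not None:
        for buf in module.buffers():
            if buf.is_floating_point():
                buf.data = buf.data.to(buffer_dtype)
    return module


model = cast_params_and_buffers(ToyRotary(), param_dtype=torch.bfloat16)
print(model.proj.weight.dtype, model.inv_freq.dtype)
```

With `buffer_dtype` left unset, the linear weight is cast to bf16 while `inv_freq` stays in fp32; passing `buffer_dtype=torch.bfloat16` reproduces the blanket-cast behavior.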

Motivation

The blanket cast downcasts every floating buffer, including the rotary inv_freq buffer that HF/FSDP2 keep in fp32. On long contexts the bf16 inv_freq loses precision, RoPE angles drift, and logits/grads diverge from the FSDP2 reference. Preserving fp32 buffers by default fixes this; buffer_dtype is the escape hatch to reproduce the legacy behavior.
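
To get a feel for the size of the effect, the sketch below rounds a set of RoPE inverse frequencies to bf16 in pure Python and measures how far the rotation angle moves at a long-context position. The head dimension (128), base (10000), and position (32768) are illustrative assumptions, not values taken from this PR, and `to_bf16` is a hypothetical helper that emulates bf16 rounding for small positive values.

```python
import struct


def to_bf16(x):
    """Round a float to the nearest bfloat16 value (round-to-nearest-even).

    Works by rounding away the low 16 bits of the fp32 bit pattern.
    Intended for small positive values such as inv_freq entries.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # rounding bias, ties to even
    bits &= 0xFFFF0000                   # keep sign, exponent, 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def rope_inv_freq(dim=128, base=10000.0):
    """Inverse frequencies 1 / base^(2i/dim), one per rotated pair."""
    return [1.0 / (base ** (2 * i / dim)) for i in range(dim // 2)]


position = 32768  # assumed long-context token position
worst = max(abs(position * f - position * to_bf16(f)) for f in rope_inv_freq())
print(f"max angle drift at position {position}: {worst:.3f} rad")
```

Under these assumptions the worst-case drift is well over a full rotation for some of the high-frequency components, which is consistent with the divergence described above.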

Behavior

  • param_dtype unset -> derived from the fp16/bf16 enabled flag (legacy param behavior).
  • buffer_dtype unset -> buffers keep their loaded dtype (e.g. fp32 inv_freq).
  • buffer_dtype set -> buffers force-cast (legacy blanket-cast parity).
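
The three rules above can be sketched as a small resolution function over a config dictionary. This is an illustration only: `resolve_dtypes` is a hypothetical name, the string spellings `"bf16"` / `"fp16"` / `"fp32"` are assumed, and the fp32 fallback when neither flag is enabled is an assumption rather than something stated in this PR.

```python
def resolve_dtypes(config):
    """Return (param_dtype, buffer_dtype) as strings.

    A buffer_dtype of None means "leave buffers in their loaded dtype".
    """
    data_types = config.get("data_types", {})

    param_dtype = data_types.get("param_dtype")
    if param_dtype is None:
        # Legacy behavior: derive the param dtype from the enabled flag.
        if config.get("bf16", {}).get("enabled", False):
            param_dtype = "bf16"
        elif config.get("fp16", {}).get("enabled", False):
            param_dtype = "fp16"
        else:
            param_dtype = "fp32"  # assumed fallback

    return param_dtype, data_types.get("buffer_dtype")


# Neither key set: params follow the bf16 flag, buffers are left alone.
print(resolve_dtypes({"bf16": {"enabled": True}}))

# buffer_dtype set explicitly: buffers are force-cast as well.
print(resolve_dtypes({"bf16": {"enabled": True},
                      "data_types": {"buffer_dtype": "bf16"}}))
```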

Test plan

  • param_dtype=bf16, buffer_dtype unset -> params bf16, inv_freq stays fp32.
  • buffer_dtype=bf16 -> buffers downcast (legacy parity).
  • bf16/fp16 run with neither key set behaves as before, except that fp32 buffers are preserved.
  • 8B / 32B ZeRO-3 long-context run -> grad_norm tracks the FSDP2 reference.

Made with Cursor

…ffers)

Add data_types.param_dtype / buffer_dtype mirroring FSDP MixedPrecisionPolicy.
Replace blanket module.half()/bfloat16() with a targeted cast so floating
buffers (e.g. the rotary inv_freq) keep their loaded dtype unless buffer_dtype
is explicitly set, matching HF/FSDP2 and avoiding RoPE precision drift.

Signed-off-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
Co-authored-by: Cursor <cursoragent@cursor.com>

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 9aee8a9b8d


Comment thread deepspeed/runtime/engine.py Outdated
CPU unit tests cover dtype resolution and the targeted cast helper (params cast,
fp32 buffers preserved unless buffer_dtype is set, zero-init param guard). A
bf16 end-to-end test verifies the keys flow through DeepSpeedConfig and that the
fp32 inv_freq buffer survives deepspeed.initialize.

Signed-off-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
@sfc-gh-truwase sfc-gh-truwase requested a review from loadams as a code owner June 15, 2026 16:15
A data_types.param_dtype that disagrees with the fp16/bf16 enabled flag would
cast params to a dtype the optimizer/master-weight/reduction paths (which derive
the model dtype from those flags) do not expect. Validate config-only in
_do_sanity_check so the run fails before any module cast or optimizer setup.

Signed-off-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
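
A config-only check of the kind this commit describes might look like the sketch below. It is not the `_do_sanity_check` code from this PR: the function name `check_param_dtype`, the dtype string spellings, and the error wording are all hypothetical.

```python
def check_param_dtype(config):
    """Fail fast if data_types.param_dtype disagrees with the fp16/bf16 flag.

    Runs on the config alone, before any module cast or optimizer setup.
    """
    requested = config.get("data_types", {}).get("param_dtype")
    if requested is None:
        return  # nothing to validate; dtype is derived from the flags

    if config.get("bf16", {}).get("enabled", False):
        expected = "bf16"
    elif config.get("fp16", {}).get("enabled", False):
        expected = "fp16"
    else:
        expected = "fp32"  # assumed default when neither flag is enabled

    if requested != expected:
        raise ValueError(
            f"data_types.param_dtype={requested!r} conflicts with the "
            f"enabled precision flag (expected {expected!r})"
        )


# Consistent config passes silently.
check_param_dtype({"bf16": {"enabled": True},
                   "data_types": {"param_dtype": "bf16"}})

# Mismatched config is rejected before any casting happens.
try:
    check_param_dtype({"bf16": {"enabled": True},
                       "data_types": {"param_dtype": "fp16"}})
except ValueError as err:
    print(err)
```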

@stas00 stas00 left a comment

Collaborator

probably should add these to the config md in the docs, no?

Comment thread deepspeed/runtime/config.py
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
3 participants