
fix: EAGLE-3 training compatibility with multimodal-wrapped targets and large vocabs #535

Open
elad-inferize wants to merge 1 commit into sgl-project:main from elad-inferize:fix/mm-wrapper-eagle3-compat

Conversation

@elad-inferize

Summary

This PR fixes three bugs that block EAGLE-3 offline training (prepare_hidden_states.py + train_eagle3.py) on multimodal-wrapped target models (e.g., Kimi-K2.5) and/or large-vocab targets (vocab > 131072).

Discovered while fine-tuning lightseekorg/kimi-k2.5-eagle3 on 8×H200 via SpecForge offline mode.

Bugs fixed

1. wrap_eagle3_logits_processors_in_module iteration mutation + wrong targeting

  • File: specforge/modeling/target/sglang_backend/utils.py
  • Symptoms: RuntimeError: dictionary changed size during iteration on any model with nested LogitsProcessor submodules (e.g., KimiK25ForConditionalGeneration.language_model.logits_processor). Even if only the iteration bug is fixed, setattr(root, dotted_name, wrapped) silently creates a wrong attribute instead of navigating into the nested module — hidden-state capture produces junk data.
  • Fix: materialize list(module.named_modules()) before mutating, and use module.set_submodule(name, wrapped) for correct dotted-path navigation (minimal sketch below).
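
A minimal sketch of the fixed wrapper, assuming LogitsProcessor and LogitsProcessorForEAGLE3 are importable from the SpecForge/SGLang modules named above; the keyword argument follows the review suggestion further down:

import torch.nn as nn

def wrap_eagle3_logits_processors_in_module(module: nn.Module, return_full_logits: bool) -> None:
    # Materialize the iterator first: replacing submodules while iterating
    # module.named_modules() raises "dictionary changed size during iteration".
    to_wrap = [
        (name, submodule)
        for name, submodule in module.named_modules()
        if isinstance(submodule, LogitsProcessor)
    ]
    for name, submodule in to_wrap:
        # Keyword argument avoids the positional mismatch flagged in the review below.
        wrapped = LogitsProcessorForEAGLE3(submodule, return_logits=return_full_logits)
        # set_submodule() walks dotted paths such as
        # "language_model.logits_processor"; setattr(module, name, wrapped)
        # would silently create a flat attribute on the root module instead.
        module.set_submodule(name, wrapped)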

2. SGLangEagle3TargetModel.set_aux_hidden_states_layers missing MM wrapper delegation

  • File: specforge/modeling/target/eagle3_target_model.py
  • Symptoms: AttributeError: 'KimiK25ForConditionalGeneration' object has no attribute 'set_eagle3_layers_to_capture' because the method is on the text backbone (.language_model), not the outer wrapper.
  • Fix: probe with getattr(model, "language_model", model) before calling (sketch below).
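
A hedged sketch of the delegation; the enclosing class body and the self.model attribute name are assumptions for illustration:

class SGLangEagle3TargetModel:
    def set_aux_hidden_states_layers(self, layers):
        # Multimodal wrappers (e.g. KimiK25ForConditionalGeneration) keep the
        # text backbone under .language_model; plain LLMs expose
        # set_eagle3_layers_to_capture directly, so getattr falls through to
        # the model itself (self.model is a hypothetical attribute name).
        backbone = getattr(self.model, "language_model", self.model)
        backbone.set_eagle3_layers_to_capture(layers)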

3. LogSoftmaxLoss Triton block size ceiling blocks vocab > 131072

  • File: specforge/core/loss.py + specforge/core/eagle3.py
  • Symptoms: RuntimeError: Cannot launch Triton kernel since n = 163840 exceeds the recommended Triton blocksize = 131072 at the first training step.
  • Fix: graceful fallback from the fused LogSoftmaxLoss.apply to the @torch.compile reference _compute_loss when the vocab exceeds the Triton ceiling (sketch below). No performance regression for models that fit.
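
The guard itself is a thin try/except around the fused call, as in the diff hunks quoted in the review thread below:

try:
    loss = LogSoftmaxLoss.apply(logits, target_p, position_mask)
except RuntimeError:
    # Fused Triton kernel has a block-size ceiling (131072); fall back
    # to the @torch.compile reference for large-vocab models
    # (e.g. vocab 163840 for Kimi-K2.5).
    loss = _compute_loss(logits, target_p, position_mask)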

Test plan

  • Verified on Kimi-K2.5 (vocab 163840, multimodal wrapper) with 8×H200, SpecForge offline mode
  • prepare_hidden_states.py runs to completion (1947 samples, all output files validated)
  • train_eagle3.py trains for 244 steps without crashing; loss falls from 0.46 to 0.26
  • No regression on non-wrapped models (Llama-3 path unchanged — getattr falls through, set_submodule works on flat models, LogSoftmaxLoss still used for vocab ≤ 131072)

Commit message: fix: EAGLE-3 training compatibility with multimodal-wrapped targets and large vocabs

Three bugs that block EAGLE-3 offline training (prepare_hidden_states.py +
train_eagle3.py) on multimodal-wrapped target models (e.g., Kimi-K2.5) and/or
large-vocab targets (vocab > 131072).

1. wrap_eagle3_logits_processors_in_module: materialize named_modules() iterator
   before mutation, use set_submodule() for correct dotted-path navigation.

2. SGLangEagle3TargetModel.set_aux_hidden_states_layers: probe for .language_model
   before calling set_eagle3_layers_to_capture, supporting multimodal wrappers.

3. LogSoftmaxLoss Triton block size ceiling: graceful fallback to @torch.compile
   reference _compute_loss when vocab exceeds 131072.

@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces several robustness improvements and bug fixes to the EAGLE-3 implementation. Key changes include adding a fallback mechanism for the LogSoftmaxLoss Triton kernel when vocabulary sizes exceed the block-size limit, updating target model layer capturing to support multimodal wrappers by accessing the underlying language model, and refactoring the SGLang logits processor wrapping to use set_submodule and avoid iterator corruption. Review feedback suggests replacing the generic RuntimeError catch with explicit vocabulary size checks to prevent swallowing unrelated CUDA errors and correcting a positional argument mismatch in the LogitsProcessorForEAGLE3 initialization.

Comment thread specforge/core/eagle3.py
Comment on lines +95 to +100
try:
    loss = LogSoftmaxLoss.apply(logits, target_p, position_mask)
except RuntimeError:
    # Fused Triton kernel has a block-size ceiling (131072); fall back
    # to the @torch.compile reference for large-vocab models.
    loss = _compute_loss(logits, target_p, position_mask)


Severity: medium

Catching a generic RuntimeError to handle the Triton block size limit is suboptimal and potentially risky as it might swallow unrelated errors (e.g., device mismatches or other CUDA errors). Since the limit is a known constant (131072), it is better to perform an explicit check on the vocabulary size (logits.shape[-1]) to decide whether to use the fused kernel or the fallback implementation.

Suggested change
Before:
try:
    loss = LogSoftmaxLoss.apply(logits, target_p, position_mask)
except RuntimeError:
    # Fused Triton kernel has a block-size ceiling (131072); fall back
    # to the @torch.compile reference for large-vocab models.
    loss = _compute_loss(logits, target_p, position_mask)
After:
if logits.shape[-1] <= 131072:
    loss = LogSoftmaxLoss.apply(logits, target_p, position_mask)
else:
    # Fused Triton kernel has a block-size ceiling (131072); fall back
    # to the @torch.compile reference for large-vocab models.
    loss = _compute_loss(logits, target_p, position_mask)

Comment thread specforge/core/eagle3.py
Comment on lines +561 to +566
try:
    loss = LogSoftmaxLoss.apply(logits, target_p, position_mask)
except RuntimeError:
    # Fused Triton kernel has a block-size ceiling (131072); fall
    # back to the @torch.compile reference for large-vocab models.
    loss = _compute_loss(logits, target_p, position_mask)


Severity: medium

As noted in the previous comment, using an explicit check for the vocabulary size is preferred over a try-except block for handling the Triton kernel's block size constraints.

Suggested change
Before:
try:
    loss = LogSoftmaxLoss.apply(logits, target_p, position_mask)
except RuntimeError:
    # Fused Triton kernel has a block-size ceiling (131072); fall
    # back to the @torch.compile reference for large-vocab models.
    loss = _compute_loss(logits, target_p, position_mask)
After:
if logits.shape[-1] <= 131072:
    loss = LogSoftmaxLoss.apply(logits, target_p, position_mask)
else:
    # Fused Triton kernel has a block-size ceiling (131072); fall
    # back to the @torch.compile reference for large-vocab models.
    loss = _compute_loss(logits, target_p, position_mask)

Comment thread specforge/modeling/target/sglang_backend/utils.py
    if isinstance(submodule, LogitsProcessor)
]
for name, submodule in to_wrap:
    wrapped = LogitsProcessorForEAGLE3(submodule, return_full_logits)


Severity: medium

The return_full_logits parameter is being passed as a positional argument to LogitsProcessorForEAGLE3, which maps it to return_last_hidden_states (the second parameter) instead of return_logits (the third parameter). This appears to be a logic error given the parameter names. Using keyword arguments would make this safer and clearer.

Suggested change
Before:
wrapped = LogitsProcessorForEAGLE3(submodule, return_full_logits)
After:
wrapped = LogitsProcessorForEAGLE3(submodule, return_logits=return_full_logits)
