fix: EAGLE-3 training compatibility with multimodal-wrapped targets and large vocabs #535
Conversation
Code Review
This pull request introduces several robustness improvements and bug fixes to the EAGLE-3 implementation. Key changes include adding a fallback mechanism for the `LogSoftmaxLoss` Triton kernel when vocabulary sizes exceed the block-size limit, updating target model layer capturing to support multimodal wrappers by accessing the underlying language model, and refactoring the SGLang logits processor wrapping to use `set_submodule` and avoid iterator corruption. Review feedback suggests replacing the generic `RuntimeError` catch with explicit vocabulary size checks to prevent swallowing unrelated CUDA errors, and correcting a positional argument mismatch in the `LogitsProcessorForEAGLE3` initialization.
```python
try:
    loss = LogSoftmaxLoss.apply(logits, target_p, position_mask)
except RuntimeError:
    # Fused Triton kernel has a block-size ceiling (131072); fall back
    # to the @torch.compile reference for large-vocab models.
    loss = _compute_loss(logits, target_p, position_mask)
```
Catching a generic `RuntimeError` to handle the Triton block size limit is suboptimal and potentially risky, as it might swallow unrelated errors (e.g., device mismatches or other CUDA errors). Since the limit is a known constant (131072), it is better to perform an explicit check on the vocabulary size (`logits.shape[-1]`) to decide whether to use the fused kernel or the fallback implementation.
Suggested change:

```diff
-try:
-    loss = LogSoftmaxLoss.apply(logits, target_p, position_mask)
-except RuntimeError:
-    # Fused Triton kernel has a block-size ceiling (131072); fall back
-    # to the @torch.compile reference for large-vocab models.
-    loss = _compute_loss(logits, target_p, position_mask)
+if logits.shape[-1] <= 131072:
+    loss = LogSoftmaxLoss.apply(logits, target_p, position_mask)
+else:
+    # Fused Triton kernel has a block-size ceiling (131072); fall back
+    # to the @torch.compile reference for large-vocab models.
+    loss = _compute_loss(logits, target_p, position_mask)
```
```python
try:
    loss = LogSoftmaxLoss.apply(logits, target_p, position_mask)
except RuntimeError:
    # Fused Triton kernel has a block-size ceiling (131072); fall
    # back to the @torch.compile reference for large-vocab models.
    loss = _compute_loss(logits, target_p, position_mask)
```
As noted in the previous comment, using an explicit check for the vocabulary size is preferred over a try-except block for handling the Triton kernel's block size constraints.
Suggested change:

```diff
-try:
-    loss = LogSoftmaxLoss.apply(logits, target_p, position_mask)
-except RuntimeError:
-    # Fused Triton kernel has a block-size ceiling (131072); fall
-    # back to the @torch.compile reference for large-vocab models.
-    loss = _compute_loss(logits, target_p, position_mask)
+if logits.shape[-1] <= 131072:
+    loss = LogSoftmaxLoss.apply(logits, target_p, position_mask)
+else:
+    # Fused Triton kernel has a block-size ceiling (131072); fall
+    # back to the @torch.compile reference for large-vocab models.
+    loss = _compute_loss(logits, target_p, position_mask)
```
```python
        if isinstance(submodule, LogitsProcessor)
    ]
    for name, submodule in to_wrap:
        wrapped = LogitsProcessorForEAGLE3(submodule, return_full_logits)
```
The `return_full_logits` parameter is being passed as a positional argument to `LogitsProcessorForEAGLE3`, which maps it to `return_last_hidden_states` (the second parameter) instead of `return_logits` (the third parameter). This appears to be a logic error given the parameter names. Using keyword arguments would make this safer and clearer.
Suggested change:

```diff
-wrapped = LogitsProcessorForEAGLE3(submodule, return_full_logits)
+wrapped = LogitsProcessorForEAGLE3(submodule, return_logits=return_full_logits)
```
Summary
Fixes three bugs that block EAGLE-3 offline training (`prepare_hidden_states.py` + `train_eagle3.py`) on multimodal-wrapped target models (e.g., Kimi-K2.5) and/or large-vocab targets (vocab > 131072). Discovered while fine-tuning `lightseekorg/kimi-k2.5-eagle3` on 8×H200 via SpecForge offline mode.

Bugs fixed
1. `wrap_eagle3_logits_processors_in_module`: iteration mutation + wrong targeting
   - File: `specforge/modeling/target/sglang_backend/utils.py`
   - Symptom: `RuntimeError: dictionary changed size during iteration` on any model with nested `LogitsProcessor` submodules (e.g., `KimiK25ForConditionalGeneration.language_model.logits_processor`). Even if only the iteration bug is fixed, `setattr(root, dotted_name, wrapped)` silently creates a wrong top-level attribute instead of navigating into the nested module, so hidden-state capture produces junk data.
   - Fix: materialize `list(module.named_modules())` before mutating, and use `module.set_submodule(name, wrapped)` for correct dotted-path navigation (sketch below).
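   A minimal sketch of the fixed helper, reconstructed from the description above; the import paths and the keyword-argument form of the `return_full_logits` handling (per the review comment) are assumptions, not the merged code:

   ```python
   from torch import nn
   from sglang.srt.layers.logits_processor import LogitsProcessor  # assumed import path


   def wrap_eagle3_logits_processors_in_module(module: nn.Module, return_full_logits: bool) -> None:
       # Materialize the iterator first: swapping submodules while iterating
       # named_modules() raises "dictionary changed size during iteration".
       to_wrap = [
           (name, submodule)
           for name, submodule in module.named_modules()
           if isinstance(submodule, LogitsProcessor)
       ]
       for name, submodule in to_wrap:
           # LogitsProcessorForEAGLE3 is assumed to be defined alongside this
           # helper in specforge's SGLang backend utils.
           wrapped = LogitsProcessorForEAGLE3(submodule, return_logits=return_full_logits)
           # set_submodule resolves dotted paths such as
           # "language_model.logits_processor"; setattr(module, name, wrapped)
           # would just create a bogus top-level attribute and leave the
           # nested processor unwrapped.
           module.set_submodule(name, wrapped)
   ```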
2. `SGLangEagle3TargetModel.set_aux_hidden_states_layers`: missing multimodal-wrapper delegation
   - File: `specforge/modeling/target/eagle3_target_model.py`
   - Symptom: `AttributeError: 'KimiK25ForConditionalGeneration' object has no attribute 'set_eagle3_layers_to_capture'`, because the method lives on the text backbone (`.language_model`), not the outer wrapper.
   - Fix: probe with `getattr(model, "language_model", model)` before calling (sketch below).
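   A sketch of the delegation, assuming the wrapper stores the loaded model on `self.model`; the attribute and parameter names are illustrative:

   ```python
   def set_aux_hidden_states_layers(self, layers):
       # Multimodal wrappers such as KimiK25ForConditionalGeneration keep the
       # text backbone under .language_model; text-only models expose
       # set_eagle3_layers_to_capture directly, so fall back to the model itself.
       backbone = getattr(self.model, "language_model", self.model)
       backbone.set_eagle3_layers_to_capture(layers)
   ```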
3. `LogSoftmaxLoss` Triton block-size ceiling blocks vocab > 131072
   - Files: `specforge/core/loss.py` + `specforge/core/eagle3.py`
   - Symptom: `RuntimeError: Cannot launch Triton kernel since n = 163840 exceeds the recommended Triton blocksize = 131072` at the first training step.
   - Fix: fall back from `LogSoftmaxLoss.apply` to the `@torch.compile` reference `_compute_loss` when the vocab exceeds the Triton ceiling. No performance regression for models that fit.
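   For context, a `@torch.compile` reference loss of this shape typically looks like the sketch below; this is a hypothetical reconstruction, not SpecForge's actual `_compute_loss`:

   ```python
   import torch


   @torch.compile
   def _compute_loss(logits, target_p, position_mask):
       # Soft cross-entropy between the target distribution and the draft
       # logits, with invalid positions zeroed by the mask. Pure PyTorch ops
       # have no Triton block-size ceiling, so vocab > 131072 works.
       logp = torch.log_softmax(logits.float(), dim=-1)
       per_token = -(target_p * logp).sum(dim=-1)
       return (per_token * position_mask).sum() / position_mask.sum().clamp(min=1)
   ```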
Test plan

- `prepare_hidden_states.py` runs to completion (1947 samples, all output files validated)
- `train_eagle3.py` trains for 244 steps without crash; loss descends 0.46 → 0.26
- Regression checks: `getattr` falls through on text-only models, `set_submodule` works on flat models, `LogSoftmaxLoss` is still used for vocab ≤ 131072