Skip to content

Add Gemma 4 31B to supervised_finetuning recipes#326

Open
dgallitelli wants to merge 1 commit into
aws-samples:mainfrom
dgallitelli:add-gemma-4-31b-recipe
Open

Add Gemma 4 31B to supervised_finetuning recipes#326
dgallitelli wants to merge 1 commit into
aws-samples:mainfrom
dgallitelli:add-gemma-4-31b-recipe

Conversation

@dgallitelli

Copy link
Copy Markdown
Contributor

Add Gemma 4 31B to supervised_finetuning recipes

Summary

Adds finetune--google--gemma-4-31B-it.ipynb to 0_model_customization_recipes/supervised_finetuning/. The notebook trains Google's Gemma 4 31B with QLoRA on a single L40S GPU (ml.g6e.2xlarge, 48 GB) using the existing shared sagemaker_code/sft.py driver — no shared file is modified in this PR.

Gemma 4 introduces architectural details that the shared scaffolding doesn't yet handle out of the box, so the notebook applies a small, scoped, idempotent set of patches at job-submit time. Every patch is gated on "gemma-4" in model_id so sibling recipes are unaffected, and the final cell of the notebook restores the originals from .bak.

Validation

End-to-end SageMaker training job (google--gemma-4-31B-it-smoke-20260522211145, us-east-1):

  • Status: Completed
  • Instance: ml.g6e.2xlarge (1× L40S 48 GB)
  • Image: pytorch-training:2.7.1-gpu-py312 (DLC, unchanged)
  • Wall clock: ~17 min (image pull + pip install + 50 training steps + adapter save)
  • Billable: 1039 sec (~$0.65)
  • Loss: 5.235 → 2.737 (50 steps, 5 epochs on a 100-row slice, ~48 % drop)
  • Artifact: s3://...output/model.tar.gz (40.6 MB) → peft_adapter/ containing
    adapter_config.json, adapter_model.safetensors (49 MB), tokenizer, chat template

13 prior iterations (~$3.50 total) surfaced and resolved 7 distinct failure modes, all documented inline in the notebook's troubleshooting cell.

What the notebook patches

All patches are applied via %%writefile / Python edit cells with .bak backup, gated on Gemma 4, and reverted in the final cell:

Target Why Mitigation
sagemaker_code/requirements.txt Gemma 4 needs transformers >= 5.5.0 (shared file pins 4.57.0) Bumps transformers / peft / accelerate / trl / datasets / liger-kernel in lockstep
sagemaker_code/sft.py (1) TRL 1.x renamed ModelConfig.torch_dtype → dtype Read either field via getattr(.., "dtype") or getattr(.., "torch_dtype")
sagemaker_code/sft.py (2) Gemma4ClippableLinear wraps every linear in the model; PEFT 0.19's LoRA dispatcher rejects all targets (a) peft_config.exclude_modules = ["vision_tower", "audio_tower"] keeps PEFT off the multimodal towers; (b) on the language stack (use_clipped_linears=False) the wrapper is a no-op — replace each with its inner nn.Linear so the existing bnb-4bit dispatcher matches
sagemaker_code/sft.py (3) Gemma 4's forward pass requires mm_token_type_ids even for text-only training; SFTTrainer's padding-free path rejects custom data_collator Subclass SFTTrainer and override _prepare_inputs to inject mm_token_type_ids = zeros_like(input_ids) per batch — Trainer._prepare_inputs is the canonical extension point and runs after the default collator
sagemaker_code/utils/merge_adapter_weights.py Gemma 4 31B fp16 is ~62 GB and won't fit on the single-GPU instance the recipe targets Skip the merge step for Gemma 4. Bare LoRA adapter is the deliverable — vLLM (--enable-lora --lora-modules), DJL Serving / LMI (option.enable_lora=true), and Bedrock all consume it directly

Each patch in sft.py is < 25 lines, anchored on exact pre-existing code blocks with assert old in src so a future drift in sft.py will surface immediately rather than silently corrupting the file.

Should these patches be upstreamed?

Yes, eventually — no for this PR. Reasoning:

  • The transformers >= 5.5.0 bump cascades across peft, trl, accelerate, datasets, and liger-kernel, and TRL 1.4 renames a ModelConfig field that other recipes pass through their YAML. Bumping requirements.txt for everyone is a maintainer call and would re-validate every sibling notebook.
  • The mm_token_type_ids injection and Gemma4ClippableLinear unwrap are Gemma-4-specific. They could live in a EXCEPTION_MODEL_LIST-style branch in sft.py similar to the existing Qwen2-Audio / Qwen3-VL handling, but that's a structural change worth a separate PR.

This PR keeps the notebook self-contained so reviewers can merge it without touching shared code or other recipes. The patches are also packaged as standalone unified diffs (sft.py.patch, merge_adapter_weights.py.patch) for reference if a future PR wants to upstream them.

Files changed

  • 0_model_customization_recipes/supervised_finetuning/finetune--google--gemma-4-31B-it.ipynb (new)

Files NOT changed in this PR (but patched at job-submit time by the notebook)

  • sagemaker_code/sft.py
  • sagemaker_code/utils/merge_adapter_weights.py
  • sagemaker_code/requirements.txt
  • sagemaker_code/hf_recipes/google/gemma-4-31B-it--vanilla-peft-qlora.yaml (created at runtime)

Adds finetune--google--gemma-4-31B-it.ipynb under
0_model_customization_recipes/supervised_finetuning/ for QLoRA fine-tuning
of Google Gemma 4 31B on a single L40S GPU (ml.g6e.2xlarge).

The notebook applies a small, scoped, idempotent set of patches to
sagemaker_code/{sft.py, utils/merge_adapter_weights.py, requirements.txt}
at job-submit time so the shared training scaffolding handles Gemma 4 31B.
Every patch is gated on "gemma-4" in model_id (sibling recipes unaffected),
saves a .bak of the original, and is reverted by the final cell.

Validated end-to-end with two SageMaker training jobs in us-east-1:
- Status: Completed
- Loss: 12.32 -> 2.74 over 50 steps (78% drop)
- Wall: ~13-17 min on ml.g6e.2xlarge
- Adapter: standard PEFT 0.19 LoRA, 49 MB safetensors, vLLM/LMI-ready
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant