parcefal99/babylm

BabyLM Fine-Tuning Pipeline

Four-task pipeline to benchmark, fine-tune, and re-evaluate BabyLLaMa-100M on grammaticality judgments using BLiMP (the Benchmark of Linguistic Minimal Pairs).

Requirements

```bash
pip install torch transformers nltk tqdm
```

Task 1 — Benchmark the base model

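Evaluates BabyLLaMa-100M on BLiMP using log-probability comparison: for each minimal pair, the model is counted correct when it assigns a higher log-probability to the grammatical sentence than to the ungrammatical one.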

```bash
python task1_benchmark.py \
  --model babylm/babyllama-100m-2024 \
  --data blimp_validation.json \
  --output task1_results.json
```

| Argument | Default | Description |
|---|---|---|
| --model | babylm/babyllama-100m-2024 | Hugging Face model name or local path |
| --data | blimp_validation.json | BLiMP validation set |
| --output | (none) | Save results to JSON (required for Task 4 Option A) |
| --device | auto | cuda or cpu |

Output: a per-paradigm accuracy table printed to stdout and, when --output is set, the results saved to that file (task1_results.json above).
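The log-probability comparison can be sketched as follows. This is a minimal illustration, not the code in task1_benchmark.py: it assumes blimp_validation.json is a list of records with sentence_good, sentence_bad and UID keys (the field names of the original BLiMP release), and the helper names hf_sentence_logprob and blimp_accuracy are hypothetical.

```python
from collections import defaultdict
from typing import Callable, Dict, Iterable


def hf_sentence_logprob(model, tokenizer, sentence: str) -> float:
    """Total log-probability of `sentence` under a Hugging Face causal LM.

    torch is imported lazily so the pure-Python helper below runs without it.
    The first token is only scored if the tokenizer prepends a BOS token.
    """
    import torch

    enc = tokenizer(sentence, return_tensors="pt").to(model.device)
    with torch.no_grad():
        out = model(**enc, labels=enc["input_ids"])
    # out.loss is the mean negative log-likelihood over the n - 1 predicted tokens,
    # so multiplying by that count recovers the summed log-probability.
    n_predicted = enc["input_ids"].shape[1] - 1
    return -out.loss.item() * n_predicted


def blimp_accuracy(pairs: Iterable[dict], score: Callable[[str], float]) -> Dict[str, float]:
    """Per-paradigm accuracy: a pair is correct if score(good) > score(bad).

    Ties count as incorrect.
    """
    correct = defaultdict(int)
    total = defaultdict(int)
    for pair in pairs:
        uid = pair["UID"]
        total[uid] += 1
        if score(pair["sentence_good"]) > score(pair["sentence_bad"]):
            correct[uid] += 1
    return {uid: correct[uid] / total[uid] for uid in total}


# Usage with the real model (downloads weights; not executed here):
# import json
# from transformers import AutoModelForCausalLM, AutoTokenizer
# tok = AutoTokenizer.from_pretrained("babylm/babyllama-100m-2024")
# lm = AutoModelForCausalLM.from_pretrained("babylm/babyllama-100m-2024").eval()
# with open("blimp_validation.json") as f:
#     pairs = json.load(f)
# per_paradigm = blimp_accuracy(pairs, lambda s: hf_sentence_logprob(lm, tok, s))
```

Summed rather than length-normalized log-probability is used here because BLiMP pairs are minimal pairs of similar length; whether task1_benchmark.py normalizes is not stated.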


Task 2 — Generate fine-tuning data

Uses a larger LLM (e.g. Llama-3.1-8B-Instruct) to generate synthetic dialogue pairs for fine-tuning. Two versions of the dataset exist with different generation strategies:

Dataset versions

| File | Style | Phenomena | Sentences |
|---|---|---|---|
| fine_tune_data_simple.json | Simple parent–child dialogues. Parent gently corrects child's grammar errors. Short, child-like sentences. | 25 generic phenomena | 37,109 |
| fine_tune_data_v2.json | Realistic adult dialogues. Two adult speakers, syntactically complex sentences, one corrects the other. Directly targets all 67 BLiMP paradigm UIDs. | 67 BLiMP paradigms | |

fine_tune_data_simple.json was the first generation run. Sentences are simple and child-directed, which caused the model to regress on complex BLiMP paradigms after fine-tuning.

fine_tune_data_v2.json is the improved version. It covers all 67 BLiMP paradigms explicitly, uses adult-level sentences, and instructs the generator to produce syntactically complex minimal pairs — designed to better match BLiMP's evaluation distribution.

Running Task 2

```bash
# Simple version (child-directed, 25 phenomena)
python task2_generate_data.py \
  --model meta-llama/Llama-3.1-8B-Instruct \
  --output fine_tune_data_simple.json \
  --n_pairs 5000 \
  --max_tokens 900000 \
  --device cuda

# Improved version (adult-level, all 67 BLiMP paradigms)
python task2_generate_data.py \
  --model meta-llama/Llama-3.1-8B-Instruct \
  --output fine_tune_data_v2.json \
  --n_pairs 10000 \
  --max_tokens 899000 \
  --device cuda
```

| Argument | Default | Description |
|---|---|---|
| --model | (required) | Hugging Face model used for data generation |
| --output | fine_tune_data.json | Output file (JSON Lines format) |
| --n_pairs | 5000 | Maximum number of dialogue pairs to generate |
| --max_tokens | 900000 | Word-based token budget (must be < 1,000,000) |
| --device | cuda | cuda or cpu |
| --append | off | Append to the existing file instead of overwriting it |

Output: a JSON Lines file (one JSON object per line, despite the .json extension) in which each entry contains the dialogue pair along with sentence_good and sentence_bad fields.
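The word-based --max_tokens budget and the --append flag could be enforced along the lines below. This is a hypothetical sketch, not the logic of task2_generate_data.py: it assumes a "word" is any whitespace-separated string, counts words across every string field of an entry, and, in append mode, counts the words already in the file toward the budget.

```python
import json
import os
from typing import Iterable, List


def count_words(entry: dict) -> int:
    """Word-based token count: whitespace-separated words in every string field."""
    return sum(len(v.split()) for v in entry.values() if isinstance(v, str))


def existing_word_count(path: str) -> int:
    """Words already stored in a JSON Lines file (0 if it does not exist)."""
    if not os.path.exists(path):
        return 0
    with open(path, encoding="utf-8") as f:
        return sum(count_words(json.loads(line)) for line in f if line.strip())


def write_with_budget(path: str, entries: Iterable[dict], max_tokens: int,
                      append: bool = False) -> List[dict]:
    """Write entries as JSON Lines, stopping before the word budget is exceeded."""
    if max_tokens >= 1_000_000:
        raise ValueError("max_tokens must be < 1,000,000")
    # Assumption: in append mode, previously written words count toward the budget.
    used = existing_word_count(path) if append else 0
    written = []
    with open(path, "a" if append else "w", encoding="utf-8") as f:
        for entry in entries:
            n = count_words(entry)
            if used + n > max_tokens:
                break  # the next entry would cross the budget
            f.write(json.dumps(entry, ensure_ascii=False) + "\n")
            used += n
            written.append(entry)
    return written
```

Stopping at the first entry that would overflow (rather than skipping it and trying smaller ones) keeps the output order identical to the generation order.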


Task 3 — Fine-tune BabyLM

Supervised fine-tuning (SFT) of BabyLLaMa-100M on the grammatical sentences from Task 2.
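Since only the grammatical side of each pair is used for training, the Task 2 output has to be reduced to a list of sentences. A minimal sketch, assuming the JSON Lines layout described in Task 2 (one object per line with a sentence_good field); the function name is hypothetical, and the deduplication and whitespace normalization are illustrative choices that task3_finetune.py may not make.

```python
import json
from typing import List


def load_grammatical_sentences(path: str) -> List[str]:
    """Collect the sentence_good field of each JSONL entry, deduplicated in order."""
    seen = set()
    sentences = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # skip blank lines
            entry = json.loads(line)
            # Collapse runs of whitespace so near-duplicates compare equal.
            sentence = " ".join(entry.get("sentence_good", "").split())
            if sentence and sentence not in seen:
                seen.add(sentence)
                sentences.append(sentence)
    return sentences
```

The resulting list could then be tokenized with truncation at --max_length, for example tokenizer(sentences, truncation=True, max_length=128) with a Hugging Face tokenizer.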

```bash
python task3_finetune.py \
  --model babylm/babyllama-100m-2024 \
  --data fine_tune_data.json \
  --output_dir ./finetuned_model \
  --epochs 3 \
  --batch_size 8 \
  --lr 5e-5 \
  --max_length 128
```

| Argument | Default | Description |
|---|---|---|
| --model | babylm/babyllama-100m-2024 | Base model to fine-tune |
| --data | fine_tune_data.json | Training data from Task 2 |
| --output_dir | ./finetuned_model | Directory to save fine-tuned model |
| --epochs | 3 | Number of training epochs |
| --batch_size | 8 | Batch size |
| --lr | 5e-5 | Learning rate |
| --max_length | 128 | Max token length per sentence |
| --device | auto | cuda or cpu |

Output: Fine-tuned model saved to ./finetuned_model/ with per-epoch checkpoints and a full training_log.jsonl recording loss, perplexity, gradient norm, and learning rate at every step.
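For a causal language model, perplexity is conventionally exp of the mean per-token cross-entropy loss, so the loss and perplexity columns of the log carry the same information. The sketch below summarizes training_log.jsonl under the assumption that each line has step and loss keys; those key names and the function are hypothetical, so adjust them to the actual log.

```python
import json
import math
from typing import Dict


def summarize_training_log(path: str) -> Dict[str, float]:
    """Summarize a per-step JSONL training log.

    Assumes each line has at least "step" and "loss" keys (hypothetical names).
    Perplexity is recomputed as exp(loss).
    """
    records = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            if line.strip():
                records.append(json.loads(line))
    if not records:
        raise ValueError("empty training log")
    best = min(records, key=lambda r: r["loss"])
    last = records[-1]
    return {
        "steps": len(records),
        "final_loss": last["loss"],
        "final_perplexity": math.exp(last["loss"]),
        "best_step": best["step"],
        "best_loss": best["loss"],
    }
```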


Task 4 — Re-evaluate and compare

Runs BLiMP evaluation on the fine-tuned model and compares it against the base model.

Option A — reuse Task 1 results (recommended, faster):

```bash
python task4_reevaluate.py \
  --base_results task1_results.json \
  --finetuned_model ./finetuned_model \
  --data blimp_validation.json \
  --output comparison_results.json
```

Option B — re-run base model evaluation from scratch:

```bash
python task4_reevaluate.py \
  --base_model babylm/babyllama-100m-2024 \
  --finetuned_model ./finetuned_model \
  --data blimp_validation.json \
  --output comparison_results.json
```

| Argument | Default | Description |
|---|---|---|
| --finetuned_model | (required) | Path to fine-tuned model (from Task 3) |
| --base_results | (one of the two required) | Pre-computed Task 1 JSON results (Option A) |
| --base_model | (one of the two required) | Base model name to re-evaluate (Option B) |
| --data | blimp_validation.json | BLiMP validation set |
| --output | comparison_results.json | Full comparison results JSON |
| --device | auto | cuda or cpu |

Provide either --base_results or --base_model.
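The "one of the two required" rule for --base_results and --base_model maps naturally onto a required mutually exclusive group in Python's argparse. The sketch below declares a parser with the flags and defaults from the table; whether task4_reevaluate.py is written this way is an assumption, and build_parser is a hypothetical name.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(
        description="Compare base and fine-tuned models on BLiMP")
    parser.add_argument("--finetuned_model", required=True,
                        help="Path to fine-tuned model (from Task 3)")
    # Exactly one source for the base-model scores must be supplied.
    base = parser.add_mutually_exclusive_group(required=True)
    base.add_argument("--base_results", help="Pre-computed Task 1 JSON results")
    base.add_argument("--base_model", help="Base model name to re-evaluate")
    parser.add_argument("--data", default="blimp_validation.json",
                        help="BLiMP validation set")
    parser.add_argument("--output", default="comparison_results.json",
                        help="Full comparison results JSON")
    parser.add_argument("--device", default="auto", choices=["auto", "cuda", "cpu"])
    return parser
```

With this setup argparse itself rejects a call that passes neither flag or both, exiting with a usage error before any model is loaded.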

Output: Per-paradigm comparison table with deltas, top gains/losses, and overall conclusion.
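The per-paradigm deltas and top gains/losses can be computed from two mappings of paradigm UID to accuracy, as in the sketch below. The mapping shape and the function name compare_paradigms are hypothetical; task1_results.json and comparison_results.json may be structured differently.

```python
from typing import Dict, List, Tuple

Ranked = List[Tuple[str, float]]


def compare_paradigms(base: Dict[str, float], finetuned: Dict[str, float],
                      top_k: int = 5) -> Tuple[Dict[str, float], Ranked, Ranked]:
    """Return per-paradigm deltas (fine-tuned minus base), top gains and top losses.

    Only paradigms present in both mappings are compared; unchanged paradigms
    appear in the deltas but in neither ranked list.
    """
    shared = base.keys() & finetuned.keys()
    deltas = {uid: finetuned[uid] - base[uid] for uid in sorted(shared)}
    ranked = sorted(deltas.items(), key=lambda kv: kv[1], reverse=True)
    gains = [kv for kv in ranked if kv[1] > 0][:top_k]
    losses = [kv for kv in reversed(ranked) if kv[1] < 0][:top_k]
    return deltas, gains, losses
```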


Full pipeline (quick reference)

```bash
# Task 1 — baseline
python task1_benchmark.py --output task1_results.json

# Task 2 — generate data (use v2 for better coverage)
python task2_generate_data.py --model meta-llama/Llama-3.1-8B-Instruct --output fine_tune_data_v2.json --n_pairs 10000 --max_tokens 899000

# Task 3 — fine-tune
python task3_finetune.py --data fine_tune_data_v2.json --output_dir ./finetuned_model_v2

# Task 4 — compare
python task4_reevaluate.py --base_results task1_results.json --finetuned_model ./finetuned_model_v2
```

File overview

| File | Description |
|---|---|
| task1_benchmark.py | BLiMP evaluation on base model |
| task2_generate_data.py | Dialogue data generation via parent LLM |
| task3_finetune.py | SFT fine-tuning of BabyLLaMa-100M |
| task4_reevaluate.py | Post-fine-tuning evaluation and comparison |
| blimp_validation.json | BLiMP evaluation dataset |
| fine_tune_data_simple.json | Simple child-directed dialogues, 25 phenomena (v1 dataset) |
| fine_tune_data_v2.json | Adult-level complex dialogues, all 67 BLiMP paradigms (v2 dataset) |
| task1_results.json | Base model results (output of Task 1) |
| comparison_results.json | Final comparison report (output of Task 4) |
| finetuned_model/ | Fine-tuned model trained on v1 simple data |
| finetuned_model_v2/ | Fine-tuned model trained on v2 complex data (output of Task 3 v2) |
