Skip to content

evo2 SAE recipe: streaming extract.py + env-overridable sweep runner#1583

Draft
polinabinder1 wants to merge 8 commits into
NVIDIA-BioNeMo:mainfrom
polinabinder1:evo2-sae-streaming-extract
Draft

evo2 SAE recipe: streaming extract.py + env-overridable sweep runner#1583
polinabinder1 wants to merge 8 commits into
NVIDIA-BioNeMo:mainfrom
polinabinder1:evo2-sae-streaming-extract

Conversation

@polinabinder1
Copy link
Copy Markdown
Collaborator

Read this before reviewing

This PR is stacked on top of #1579 (evo2-sae-recipe). Both branches live on the same fork, so GitHub can't show a base-on-fork comparison — the diff below includes #1579's commits at the bottom. The new work here is the top three commits only:

Commit Scope
a5b1db7 evo2 sae recipe: streaming-extract pipeline, drop pt_to_parquet shim 1b.sh rewrite, pt_to_parquet.py deletion, README update
9b7856a evo2 sae: streaming extractor + prok+euk FASTA composer extract.py (+206) and compose_prokeuk_fasta.py (+142), both new
89ff40c evo2_megatron predict: skip empty batches on DP shard boundary 8-line bug fix in predict.py

Net diff for the new work: 6 files, +439 / −107. Once #1579 merges into main, this PR's diff auto-cleans to just these three commits.

What changes

The recipe's extraction pipeline collapses from two steps to one. Before:

predict_evo2  --embedding-layer N  (FASTA in, .pt out)
        |
pt_to_parquet shim  (.pt -> ActivationStore parquet shards)
        |
train.py

After:

extract.py  (FASTA in, ActivationStore parquet shards out)
        |
train.py

extract.py monkey-patches predict_evo2's writer so activations stream straight into the SAE ActivationStore parquet format during inference. No .pt intermediate, no shim. This is the pipeline used for the current 100M / 500M layer-22 prok+euk training runs — the old .pt path was never used by any real run, so pt_to_parquet.py is deleted.

Files in the new work

  • scripts/extract.py (new, +206) — streaming activation extractor. Re-uses bionemo.evo2.run.predict for all the heavy machinery (Megatron model load, DP/CP/TP/PP, FASTA dataloader, inference loop) but swaps the per-batch .pt writer for an in-process ActivationStore. Handles multi-rank merge with a file-based wait (dist.barrier() no-ops after predict.main() tears down the process group).
  • scripts/compose_prokeuk_fasta.py (new, +142) — composes a prokaryotic + eukaryotic FASTA mix from OpenGenome2 shards with a hard token budget. Excludes metagenome sources; matches the training mix used in 100M / 500M runs.
  • scripts/pt_to_parquet.py (deleted) — obsoleted by extract.py.
  • scripts/1b.sh (rewritten) — 3-step pipeline (convert -> extract -> train) instead of 4. Env-overridable hyperparams (RUN_TAG, LAYER, MAX_TOKENS, MICRO_BATCH, DEVICES, EXPANSION_FACTOR, TOP_K, AUXK, AUXK_COEF, DEAD_TOKENS_THRESHOLD, N_EPOCHS, LR) so the same script drives a multi-config sweep. TRAIN_ONLY=1 skips extraction against a cached parquet. WANDB_API_KEY gates wandb.
  • evo2_megatron/.../predict.py (+8) — skip empty batches on DP shard boundaries. Fixes a hang where the last DP rank could receive 0 sequences and stall predict_evo2.
  • README.md — updated pipeline diagram and quick-start examples for the new flow.

How to use

# Default 1B prok+euk run on layer 12, 4 GPU
bash scripts/1b.sh

# Layer 22, custom FASTA, tagged paths, more epochs
LAYER=22 \
  FASTA=/data/interp/evo2/OpenGenome2/fasta/prokeuk_25M.fasta \
  RUN_TAG=25M_prokeuk \
  N_EPOCHS=4 \
  bash scripts/1b.sh

# Re-train against an existing parquet dir without re-extracting
TRAIN_ONLY=1 \
  PARQUET_DIR=/data/interp/evo2/activations/evo2_1b_base_layer22_parquet_25M_prokeuk \
  AUXK=2048 RUN_TAG=auxk2048 \
  bash scripts/1b.sh

Compose a custom mix first if needed:

python scripts/compose_prokeuk_fasta.py \
  --out /data/interp/evo2/OpenGenome2/fasta/prokeuk_25M.fasta \
  --target-tokens 25_000_000

Test plan

  • bash scripts/1b.sh end-to-end on a small FASTA: convert step skips if MBridge ckpt exists, extract writes parquet shards + metadata.json, train completes and writes checkpoint_final.pt.
  • compose_prokeuk_fasta.py --target-tokens 1_000_000 writes a FASTA close to the target (±genome-record granularity) and excludes metagenome sources.
  • TRAIN_ONLY=1 against an existing parquet skips chunk/convert/extract and runs train directly.
  • predict_evo2 on a sequence count not divisible by DP world size completes instead of hanging.
  • All env overrides (LAYER, RUN_TAG, AUXK, N_EPOCHS, etc.) surface in the printed step banners and the underlying CLI flags.

🤖 Generated with Claude Code

polinabinder1 and others added 8 commits May 21, 2026 00:42
torch 2.6 changed the default of `weights_only` to True. The Savanna
checkpoint pickle includes numpy globals (`numpy.core.multiarray._reconstruct`),
which the safer loader rejects. The converter then exits 0 with no output
written and the error gets buried in stderr — silent failure.

The Savanna repos under arcinstitute/* are trusted sources, so load with
weights_only=False.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Mirrors the existing esm2 / codonfm SAE recipes. Pipeline:

  chunk -> convert (Savanna->MBridge) -> predict_evo2 -> pt_to_parquet -> train

Differences from esm2/codonfm are forced by Evo2 specifics:
  - Hyena/Megatron-Core model, no HF AutoModel path => reuses the
    existing `predict_evo2` CLI for inference instead of writing
    a custom extract.py
  - `pt_to_parquet.py` shim bridges predict_evo2's .pt output to
    the universal `sae.activation_store` parquet contract
  - `chunk_fasta.py` preprocessor keeps inputs within the model's
    trained context length (8192 bp for 1B); Hyena fftconv OOMs
    on long sequences even at micro-batch=1
  - `train.py` is the same as codonfm's, copied verbatim per
    bionemo-recipes' KISS-over-DRY convention

Validated end-to-end on 100 organelle sequences (Evo2 1B layer 12):
loss 0.67 -> 0.045, FVU 0.90 -> 0.10, var_exp 0.10 -> 0.90, 2m14s wall.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The recipe currently has no model-specific Python module — the extractor
is upstream (`predict_evo2`) and the two scripts are simple CLIs in
scripts/. Drop the empty package and adjust pyproject.toml so setuptools
doesn't try to discover anything. Will reintroduce when there's actual
library code to put there (eval, dashboard, dataloaders).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
torchrun --nproc_per_node N can hand a rank an empty batch when the
last micro-batch falls past the shard boundary. _padding_collate_fn
then crashed in max() with "iterable argument is empty".

Return None from the collate when batch is empty and skip the loop
iteration in predict(). Required for predict_evo2 to run reliably
under DP > 1.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
scripts/extract.py
  Codonfm-style streaming activation extractor. Reuses predict_evo2's
  Megatron model/DP/dataloader machinery by monkey-patching its
  _write_predictions_batch, then streams pad-stripped layer-N activations
  directly into an ActivationStore (parquet shards) inside the inference
  loop — skipping the .pt intermediate that pt_to_parquet had to walk.
  --max-tokens caps each rank's budget. File-based rank wait + merge
  (not dist.barrier — predict.main tears down the process group before
  the writer hook returns, so the barrier silently no-ops and rank 0
  races ahead; observed orphaned 18M tokens before this was fixed).
  Saves ~30 min and ~7 TB scratch per 25M-token run vs the old pipeline.

scripts/compose_prokeuk_fasta.py
  Builds a balanced prokaryotic + eukaryotic mixed FASTA from
  OpenGenome2 subsets (metagenomes + eukaryotic_genic_windows). Truncates
  metagenome contigs to --metagenome-window bp each (default 50k) — they
  average ~1.1 Mbp, so a handful of full contigs would dominate the mix.
  Emits unique seq_{i} headers so predict_evo2's dup-id check passes.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
extract.py replaces the predict_evo2 -> .pt -> pt_to_parquet path with a
single streaming step that writes ActivationStore parquet shards directly
during inference. Delete the now-unused shim, rewrite 1b.sh as a 3-step
pipeline (convert -> extract -> train), and update the README accordingly.

1b.sh:
- collapse predict_evo2 + pt_to_parquet into a single 'STEP 2: extract'
  that calls torchrun extract.py
- expose RUN_TAG, PARQUET_DIR/OUTPUT_DIR, MAX_TOKENS, MICRO_BATCH, DEVICES,
  and SAE training hyperparams (EXPANSION_FACTOR, TOP_K, AUXK, AUXK_COEF,
  DEAD_TOKENS_THRESHOLD, N_EPOCHS, LR) as env overrides so the same script
  drives a multi-config sweep
- TRAIN_ONLY=1 skips chunk/convert/extract against a cached parquet
- WANDB_API_KEY gates wandb logging; WANDB_PROJECT/WANDB_RUN_NAME override

README: pipeline diagram + quick-start examples for the new env-overridable
flow; remove all references to .pt intermediates and pt_to_parquet.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
topk.py: aux-loss target was `x - recon + pre_bias`, which simplifies to
`x - decoder(codes)` -- norm dominated by ||pre_bias|| (~449 on evo2 L22)
rather than the actual reconstruction error (~8). The denominator
(`target_var = residual.pow(2).mean(-1)`) was inflated by the same factor,
so the aux gradient was scaled by roughly (||pre_bias|| / ||error||)^2 ~ 3000x
below the canonical formulation. Fix to `residual = x - recon`, matching
the OpenAI/Gao TopK formulation. Numerically verified on the 500M L22
checkpoint: residual (a) ||x - recon|| = 8.0 vs buggy (b) ||x - recon + pre_bias|| = 449.7.

1b.sh: default DEAD_TOKENS_THRESHOLD to 10_000_000, matching the train.py
default and codonfm convention (Gao et al.). Prior 500_000 default flagged
~70% of latents as 'dead' even when they were firing once per ~half-million
tokens, vs codonfm's 0.003% under the canonical threshold. Still overridable
via env.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented May 29, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented May 29, 2026

Important

Review skipped

Auto reviews are disabled on this repository. Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: c96d3048-8b9b-479a-81fb-4b1cb4731915

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

Use the checkbox below for a quick retry:

  • 🔍 Trigger review
✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Comment @coderabbitai help to get the list of available commands and usage tips.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant