From bb5494033a6f6c72beee7de55c93669093dcce1c Mon Sep 17 00:00:00 2001 From: Fausto Milletari Date: Tue, 30 Jun 2026 16:45:22 +0000 Subject: [PATCH 1/3] Sync internal codebase: bug fixes and binder design improvements Sync from internal codebase implementing this package. Substantive changes across binder design tutorial, MSA handling, molecular complex (adds pde, num_tokens fields), sdk/api, structure input builder, and widget prompt selectors (fix entry deletion via range_string instead of label.tag). Co-Authored-By: Claude Opus 4.8 (1M context) --- README.md | 33 +- cookbook/tutorials/binder_design.ipynb | 555 ++++------------------- cookbook/tutorials/binder_design.py | 282 +++++++----- cookbook/tutorials/embed.ipynb | 6 +- cookbook/tutorials/esm3_generate.ipynb | 26 +- cookbook/tutorials/esmfold2.ipynb | 66 +++ esm/models/vqvae.py | 2 +- esm/sdk/api.py | 46 +- esm/utils/msa/msa.py | 12 +- esm/utils/msa/msa_test.py | 4 +- esm/utils/structure/input_builder.py | 2 + esm/utils/structure/molecular_complex.py | 3 + esm/widgets/utils/prompting.py | 3 +- esm/widgets/utils/protein_import.py | 1 - 14 files changed, 414 insertions(+), 627 deletions(-) diff --git a/README.md b/README.md index 7b6ddb3e..b0e4f4f5 100644 --- a/README.md +++ b/README.md @@ -54,16 +54,16 @@ Codebase, model weights, and model variants for ESMC are available through [Hugg There are two primary ways of running the ESM models: through the [**Biohub Platform**](https://biohub.ai/) or locally with Hugging Face. The Biohub Platform enables users to easily run inference with ESM models with minimal setup. Users interested in customizing or fine-tuning ESM models can use the models from Hugging Face. 
-### Running ESMC Locally - +### Running ESMC Through Hugging Face + -Install `esm` from GitHub (a PyPI release is coming soon): +First, install `esm` from GitHub (a PyPI release is coming soon): ``` pip install esm@git+https://github.com/Biohub/esm.git@main ``` -The following code demonstrates how to run ESMC locally +Then use the following code to run ESMC using the Transformers library via Hugging Face: ```python import torch @@ -144,6 +144,16 @@ The sparse autoencoder used in the Atlas and analyzed in the paper, `ESMC-6B-sae Codebase, model weights, and model variants for ESMC SAEs are available through [Hugging Face](https://huggingface.co/collections/biohub/esmc-saes-for-hidden-states-all-layers). +### Running SAEs Through Hugging Face + +First, install `esm` from GitHub (a PyPI release is coming soon): + +``` +pip install esm@git+https://github.com/Biohub/esm.git@main +``` + +Then use the following code to set up an ESMC SAE using the Transformers library via Hugging Face: + ```python import torch from transformers import AutoModel, AutoTokenizer @@ -170,19 +180,30 @@ output["sae_outputs"]["layer60"] # sparse.coo tensor print(output["sae_outputs"]["layer60"].shape) ``` +### Running SAEs Through The Biohub Platform -For tutorials on how to use ESMC SAEs, see our [tutorials](https://github.com/Biohub/esm/tree/main/cookbook/tutorials). +For a tutorial on using SAEs on the Biohub Platform, see [here](https://github.com/Biohub/esm/blob/main/cookbook/tutorials/esmc_sae_feature_interpretation.ipynb). ## ESMFold2 + [ESMFold2](https://huggingface.co/biohub/ESMFold2) is a state-of-the-art protein structure prediction model that combines ESMC (6B parameter) language model embeddings with a diffusion-based structure prediction architecture. The model predicts high-resolution, all-atom 3D protein structures directly from amino acid sequences, with optional multiple sequence alignment (MSA) input for enhanced accuracy on challenging targets. 
ESMFold2 achieves state-of-the-art performance matching or exceeding AlphaFold3 across diverse evaluation datasets, while offering improved computational efficiency through optimized diffusion sampling and architectural innovations. Codebase, model weights, and model variants for ESMFold2 are available through [Hugging Face](https://huggingface.co/biohub/ESMFold2) -### Running ESMFold2 Locally +### Running ESMFold2 Through Hugging Face + + +First, install `esm` from GitHub (a PyPI release is coming soon): + +``` +pip install esm@git+https://github.com/Biohub/esm.git@main +``` + +Then use the following code to run ESMFold2 locally using the Transformers library via Hugging Face: ```python from esm.models.esmfold2 import ( diff --git a/cookbook/tutorials/binder_design.ipynb b/cookbook/tutorials/binder_design.ipynb index be4ddd1c..8985cd5d 100644 --- a/cookbook/tutorials/binder_design.ipynb +++ b/cookbook/tutorials/binder_design.ipynb @@ -117,7 +117,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "id": "37c03b59", "metadata": {}, "outputs": [], @@ -141,7 +141,7 @@ "\n", "`modal.Cls.from_name(...)` grabs a handle to the design app you just deployed, without rerunning anything on Modal yet. Instantiating it gives you `app`, which is what you'll call `app.design.spawn(...)` on to launch design jobs.\n", "\n", - "`use_scaling_critics=False` is the default. Setting it to `True` adds extra critic models from the paper that improve selection at the cost of more compute per job." + "`use_scaling_critics=True` is the default. Setting it to `False` skips the additional critic models from the paper, reducing compute per job at the cost of selection quality." ] }, { @@ -153,8 +153,8 @@ "source": [ "ESMFold2Design = modal.Cls.from_name(\"esmfold2-design\", \"ESMFold2DesignModal\")\n", "# Set 'use_scaling_critics' to evaluate with the additional critics.\n", - "# Off by default. 
But cells below were populated with them enabled.\n", - "app = ESMFold2Design(use_scaling_critics=False)" + "# On by default.\n", + "app = ESMFold2Design(use_scaling_critics=True)" ] }, { @@ -170,7 +170,7 @@ "\n", "**Option 1** uses a built-in target and binder scaffold. Available targets: `ctla4`, `egfr`, `pdgfrb`, `pd-l1`, `cd45`. Available binder types: `minibinder`, `trastuzumab_framework_vhvl` (an antibody scaffold). Easiest if your target is one of the built-ins.\n", "\n", - "**Option 2** takes your own target sequence and binder scaffold. In the binder scaffold, `#` means \"design this position\" and any amino acid letter means \"keep this position fixed.\" For example:\n", + "**Option 2** takes your own target sequence and binder scaffold. Pass a display `target_name` and `binder_name` (these do not need to be preset keys) along with the sequences. In the binder scaffold, `#` means \"design this position\" and any amino acid letter means \"keep this position fixed.\" For example:\n", "- `\"#\" * 60` designs a fully free 60-residue minibinder\n", "- A trastuzumab-style antibody scaffold (shown below) fixes the framework regions and lets the model design the CDR loops \n", "\n", @@ -181,24 +181,13 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "826c88d1", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'https://modal.com/id/fc-01KT1ZA3NQ0JTF2B4HNCR159NM'" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "# ---- Option 1: Use presets. 
----\n", - "# Relies on the registry in modal_binder_design.py::{TARGET_SEQUENCES,BINDER_PROMPT_FACTORIES}, which can be modified.\n", + "# Relies on the registry in binder_design.py::{TARGET_SEQUENCES,BINDER_PROMPT_FACTORIES}, which can be modified.\n", "future = app.design.spawn(target_name=\"ctla4\", binder_name=\"minibinder\")\n", "future.get_dashboard_url() # A clickable link to Modal dashboard" ] @@ -216,7 +205,9 @@ "# A sample of 'trastuzumab_framework_vhvl' template. From binder_design.py::BINDER_PROMPT_FACTORIES.\n", "trastuzumab_framework_vhvl = \"EVQLVESGGGLVQPGGSLRLSCAAS#######YIHWVRQAPGKGLEWVARI#####TRYADSVKGRFTISADTSKNTAYLQMNSLRAEDTAVYYCSR###########WGQGTLVTVSSGGGSGGGSGGGSGGGSDIQMTQSPSSLSASVGDRVTITC###########WYQQKPGKAPKLLIY#######GVPSRFSGSRSGTDFTLTISSLQPEDFATYYC#########FGQGTKVEIK\"\n", "future2 = app.design.spawn(\n", + " target_name=\"pd-l1-custom\",\n", " target_sequence=pdl1_sequence,\n", + " binder_name=\"trastuzumab-custom\",\n", " binder_sequence=trastuzumab_framework_vhvl,\n", " is_antibody=True,\n", ")\n", @@ -244,6 +235,7 @@ "source": [ "# ---- Load result ----\n", "best_sequences, trajectory, critic_results = future.get()\n", + "# best_sequences, trajectory, critic_results = future2.get()\n", "print(\"Best sequences: \", best_sequences)\n", "df = pd.DataFrame(critic_results)\n", "df.drop(columns=[\"logits\", \"complex\"])" @@ -251,79 +243,10 @@ }, { "cell_type": "code", - "execution_count": 54, + "execution_count": null, "id": "d80597fa", "metadata": {}, - "outputs": [ - { - "data": { - "application/3dmoljs_load.v0": "
\n

3Dmol.js failed to load for some reason. Please check your browser console for error messages.

\n
\n", - "text/html": [ - "
\n", - "

3Dmol.js failed to load for some reason. Please check your browser console for error messages.

\n", - "
\n", - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 54, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "# ---- Visualize ----\n", "protein_complex = (\n", @@ -332,14 +255,30 @@ "(\n", " py3Dmol.view(width=600, height=600)\n", " .addModel(protein_complex.to_pdb_string(), \"pdb\")\n", - " .setStyle({\"chain\": \"A\"}, {\"cartoon\": {\"color\": \"green\"}}) # pyright: ignore\n", - " .setStyle({\"chain\": \"B\"}, {\"cartoon\": {\"color\": \"cyan\"}}) # pyright: ignore\n", - " .addStyle( # pyright: ignore\n", - " {\"not\": {\"atom\": [\"N\", \"C\", \"O\"]}},\n", - " {\"stick\": {\"colorscheme\": \"default\", \"radius\": 0.2}},\n", + " .setStyle({\"chain\": \"A\"}, {\"cartoon\": {\"color\": \"green\"}})\n", + " .setStyle(\n", + " {\"chain\": \"B\"},\n", + " {\n", + " \"cartoon\": {\n", + " \"colorscheme\": {\"prop\": \"b\", \"gradient\": \"rwb\", \"min\": 60, \"max\": 100}\n", + " }\n", + " },\n", + " )\n", + " .addStyle( # B factor coloring for binder\n", + " {\"and\": [{\"chain\": \"B\"}, {\"not\": {\"atom\": [\"N\", \"C\", \"O\"]}}]},\n", + " {\n", + " \"stick\": {\n", + " \"colorscheme\": {\"prop\": \"b\", \"gradient\": \"rwb\", \"min\": 60, \"max\": 100},\n", + " \"radius\": 0.2,\n", + " }\n", + " },\n", " )\n", - " .center() # pyright: ignore\n", - " .zoomTo() # pyright: ignore\n", + " .addStyle( # Target colored green\n", + " {\"and\": [{\"chain\": \"A\"}, {\"not\": {\"atom\": [\"N\", \"C\", \"O\"]}}]},\n", + " {\"stick\": {\"color\": \"green\", \"radius\": 0.2}},\n", + " )\n", + " .center()\n", + " .zoomTo()\n", ")" ] }, @@ -351,7 +290,7 @@ "## 3. Run a sweep for real designs\n", "For real candidates worth ordering, sweep across many seeds (and optionally multiple targets, binder types, or lengths) and select the best.\n", "\n", - "Edit `line_sweeps` below to define your campaign. 
Each key is a sweep axis; the notebook runs one job per combination of values. The default below sweeps 128 seeds across two binder types against PD-L1. \n", + "Edit `targets`, `binders`, and the other sweep axes below. Each target/binder is a `(name, sequence)` tuple: pass `None` for sequence to use a built-in preset (name is the preset key); otherwise name is your display name and sequence is the amino-acid string. The default below sweeps 128 seeds across two binder types against PD-L1. \n", "\n", "**Before you click Run on the Launch cell, check the printed shape of the dataframe and confirm it's the number of jobs you intended.** " ] @@ -361,120 +300,36 @@ "execution_count": null, "id": "ac02bbaa", "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
target_nametarget_sequencebinder_namebinder_sequenceuse_scaling_criticsseedbatch_size
0pd-l1NoneminibinderNoneFalse01
1pd-l1NoneminibinderNoneFalse11
\n", - "
" - ], - "text/plain": [ - " target_name target_sequence binder_name binder_sequence \\\n", - "0 pd-l1 None minibinder None \n", - "1 pd-l1 None minibinder None \n", - "\n", - " use_scaling_critics seed batch_size \n", - "0 False 0 1 \n", - "1 False 1 1 " - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "(256, 7)" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "# ---- Config ----\n", "save_dir = Path(\"sweep\")\n", "save_dir.mkdir(exist_ok=True)\n", "\n", - "# Sweep settings - each key-value pair is an axis of a grid sweep.\n", + "targets = [(\"pd-l1\", None)]\n", + "binders = [(\"minibinder\", None), (\"trastuzumab_framework_vhvl\", None)]\n", + "\n", "line_sweeps = dict(\n", - " target_name=[\"pd-l1\"],\n", - " target_sequence=[None],\n", - " binder_name=[\"minibinder\", \"trastuzumab_framework_vhvl\"], # two modalities\n", - " binder_sequence=[None],\n", - " use_scaling_critics=[False],\n", + " target=targets,\n", + " binder=binders,\n", + " use_scaling_critics=[True],\n", " seed=list(range(128)),\n", " batch_size=[1],\n", ")\n", - "df = pd.DataFrame(product(*line_sweeps.values()), columns=list(line_sweeps.keys()))\n", + "df = pd.DataFrame(product(*line_sweeps.values()), columns=line_sweeps.keys())\n", + "df[\"target_name\"], df[\"target_sequence\"] = zip(*df[\"target\"], strict=True)\n", + "df[\"binder_name\"], df[\"binder_sequence\"] = zip(*df[\"binder\"], strict=True)\n", + "df = df.drop(columns=[\"target\", \"binder\"])\n", "display(df.head(2))\n", "df.shape" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "id": "9b768813", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Spawned 256 jobs. 
It is safe to close the notebook.The next cell will resume from call_id's, saved by Modal for up to 7 days.\n" - ] - } - ], + "outputs": [], "source": [ "# ---- Launch ----\n", "df[\"call_id\"] = [\n", @@ -517,58 +372,10 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": null, "id": "b9a637f0", "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "700b75ea416b483890f292b204d4d0fb", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/256 [00:00\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
designed_sequenceiptm_scoreiptm_proxy_scoreselection_score
target_namebinder_name
pd-l1minibinder70AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...0.9679820.9443970.956190
13AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...0.9644700.9262160.945343
78AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...0.9619200.9181860.940053
75AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...0.9575160.9160830.936799
45AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...0.9646650.9084660.936566
..................
trastuzumab_framework_vhvl2AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...0.9218840.7256630.823773
107AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...0.9206200.7263910.823506
46AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...0.8790640.7673230.823194
68AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...0.9211960.7241390.822668
50AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...0.8944140.7472550.820834
\n", - "

168 rows × 4 columns

\n", - "" - ], - "text/plain": [ - " designed_sequence \\\n", - "target_name binder_name \n", - "pd-l1 minibinder 70 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... \n", - " 13 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... \n", - " 78 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... \n", - " 75 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... \n", - " 45 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... \n", - "... ... \n", - " trastuzumab_framework_vhvl 2 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... \n", - " 107 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... \n", - " 46 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... \n", - " 68 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... \n", - " 50 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... \n", - "\n", - " iptm_score iptm_proxy_score \\\n", - "target_name binder_name \n", - "pd-l1 minibinder 70 0.967982 0.944397 \n", - " 13 0.964470 0.926216 \n", - " 78 0.961920 0.918186 \n", - " 75 0.957516 0.916083 \n", - " 45 0.964665 0.908466 \n", - "... ... ... \n", - " trastuzumab_framework_vhvl 2 0.921884 0.725663 \n", - " 107 0.920620 0.726391 \n", - " 46 0.879064 0.767323 \n", - " 68 0.921196 0.724139 \n", - " 50 0.894414 0.747255 \n", - "\n", - " selection_score \n", - "target_name binder_name \n", - "pd-l1 minibinder 70 0.956190 \n", - " 13 0.945343 \n", - " 78 0.940053 \n", - " 75 0.936799 \n", - " 45 0.936566 \n", - "... ... 
\n", - " trastuzumab_framework_vhvl 2 0.823773 \n", - " 107 0.823506 \n", - " 46 0.823194 \n", - " 68 0.822668 \n", - " 50 0.820834 \n", - "\n", - "[168 rows x 4 columns]" - ] - }, - "execution_count": 49, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "# ---- Select ----\n", "\n", @@ -889,7 +478,6 @@ " scores[\"selection_score\"] = 0.5 * scores.iptm_score.fillna(\n", " 0\n", " ) + 0.5 * scores.iptm_proxy_score.fillna(0)\n", - "\n", " return scores.nlargest(min(len(scores), 84), \"selection_score\")\n", "\n", "\n", @@ -900,6 +488,33 @@ "df_select" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "5a131c33", + "metadata": {}, + "outputs": [], + "source": [ + "# ---- Write selected structures ----\n", + "complexes = (\n", + " df_result[df_result.critic_name.eq(\"ESMFold2-Experimental-Cutoff2025\")]\n", + " .drop_duplicates(\"designed_sequence\")\n", + " .set_index(\"designed_sequence\")[\"complex\"]\n", + ")\n", + "pdb_dir = save_dir / \"selected_structures\"\n", + "pdb_dir.mkdir(exist_ok=True)\n", + "selected = df_select.reset_index().sort_values(\n", + " [\"target_name\", \"binder_name\", \"selection_score\"], ascending=[True, True, False]\n", + ")\n", + "for rank, row in enumerate(selected.itertuples(), start=1):\n", + " pdb_path = (\n", + " pdb_dir\n", + " / f\"{rank:04d}_{row.target_name}_{row.binder_name}_score{row.selection_score:.3f}.pdb\"\n", + " )\n", + " pdb_path.write_text(complexes[row.designed_sequence].to_pdb_string())\n", + "print(f\"Wrote {len(selected)} PDBs to {pdb_dir.resolve()}\")" + ] + }, { "cell_type": "code", "execution_count": null, @@ -979,7 +594,7 @@ ], "metadata": { "kernelspec": { - "display_name": "default", + "display_name": "Python 3", "language": "python", "name": "python3" }, diff --git a/cookbook/tutorials/binder_design.py b/cookbook/tutorials/binder_design.py index 3fd97049..b6e424f2 100644 --- a/cookbook/tutorials/binder_design.py +++ 
b/cookbook/tutorials/binder_design.py @@ -17,8 +17,9 @@ import os import random import string +import time from dataclasses import dataclass -from functools import cache, partial +from functools import cache from typing import Any import biotite.structure @@ -27,15 +28,7 @@ import torch import torch.nn.functional as F import torch.optim as optim -from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( - CheckpointImpl, - apply_activation_checkpointing, - checkpoint_wrapper, -) from transformers.models.esmc.modeling_esmc import ESMCForMaskedLM -from transformers.models.esmc.modeling_esmc import ( - UnifiedTransformerBlock as TransformerBlock, -) from transformers.models.esmc.tokenization_esmc import ESMCTokenizer from transformers.models.esmfold2.modeling_esmfold2_common import ( CUE_AVAILABLE, @@ -64,6 +57,7 @@ PROTEIN_3TO1, RES_TYPE_TO_CCD, ) +from esm.utils.structure.mmcif_parsing import PLDDT_B_FACTOR_SCALE from esm.utils.structure.protein_chain import ProteinChain from esm.utils.structure.protein_complex import ProteinComplex @@ -72,6 +66,9 @@ logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) +TrajectoryStep = dict[str, torch.Tensor | float] +Trajectory = dict[int, TrajectoryStep] + # ---- Constants ---- @@ -97,13 +94,15 @@ LEARNING_RATE = 0.1 TEMPERATURE_MIN = 1e-2 ESMC_MASK_FRACTION = 0.15 -CHECKPOINT_LM = False -COMPILE = False +LM_LOSS_BATCH_SIZE = 128 +LM_MASK_PASSES = 4 +BYTES_PER_GIB = 1024**3 +COMPILE = True # NOTE - This significantly reduces VRAM usage. -# On config (target_name=cd45", binder_name="trastuzumab_framework_vhvl, batch_size=1) +# On config (target_name="cd45", binder_name="trastuzumab_framework_vhvl", batch_size=1) # this reduces VRAM from 51GB -> 27GB. And enables increasing batch size up to 6. # We are testing this setting in silico, and may change the default to True, in the future. 
-REUSE_ESMC = False +REUSE_ESMC = True # ---- Prompts ---- @@ -384,6 +383,20 @@ def _entropy_to_confidence(mean_entropy: float) -> float: return float(max(0.0, min(1.0, 1.0 - mean_entropy / math.log(51)))) +def _is_antibody_sequence(binder_sequence: str) -> bool: + """Return True if ANARCI recognizes binder_sequence as antibody variable domain(s).""" + from abnumber.common import _anarci_align + + sequence = binder_sequence.replace(MUTABLE_TOKEN, "A") + result = _anarci_align( + sequences=[sequence], scheme="chothia", allowed_species=None + )[0] + if not result: + return False + valid_chain_types = {"H", "K", "L"} + return all(chain_type in valid_chain_types for _, chain_type, *_ in result) + + def _cdr_indices(binder_sequence: str) -> list[int]: """0-based binder indices for all Chothia CDRs.""" from abnumber import Chain @@ -646,6 +659,9 @@ def build_complex( inputs: dict[str, torch.Tensor], output: dict[str, Any] ) -> ProteinComplex: """Build ProteinComplex from model output.""" + plddt_per_atom = output.get("plddt_per_atom") + if plddt_per_atom is not None: + plddt_per_atom = plddt_per_atom[0].cpu().numpy() * PLDDT_B_FACTOR_SCALE atom_arr = to_atom_array( coords=output["sample_atom_coords"][0].cpu().numpy(), atom_to_token=inputs["atom_to_token"][0].cpu().numpy(), @@ -656,9 +672,13 @@ def build_complex( ref_atom_name_chars=inputs["ref_atom_name_chars"][0].cpu().numpy(), ref_element=inputs["ref_element"][0].cpu().numpy(), atom_attention_mask=inputs["atom_attention_mask"][0].cpu().numpy(), + plddt_per_atom=plddt_per_atom, ) return ProteinComplex.from_chains( - [ProteinChain.from_atomarray(a) for a in biotite.structure.chain_iter(atom_arr)] + [ + ProteinChain.from_atomarray(a, is_predicted=True) + for a in biotite.structure.chain_iter(atom_arr) + ] ) @@ -714,8 +734,8 @@ def compute_esmc_pseudoperplexity_nll( device=device, ) tokenizer = ESMCTokenizer() - input_ids[:, 0, tokenizer.cls_token_id] = 1 - input_ids[:, -1, tokenizer.eos_token_id] = 1 + input_ids[:, 0, 
tokenizer.cls_token_id] = 1 # pyright: ignore + input_ids[:, -1, tokenizer.eos_token_id] = 1 # pyright: ignore input_ids[:, 1:-1, 4:24] = input_esm.to(model_dtype) if score_mask.ndim == 1: @@ -731,7 +751,8 @@ def compute_esmc_pseudoperplexity_nll( mask_token[esmc_model.config.mask_token_id] = 1 esmc = esmc_model.esmc - losses = [] + all_masked_sequences = [] + all_pass_masks = [] for batch_idx in range(binder_design.size(0)): position_indices = score_mask[batch_idx].nonzero(as_tuple=False).flatten() num_positions = int(position_indices.numel()) @@ -755,28 +776,35 @@ def compute_esmc_pseudoperplexity_nll( mask_rows, mask_cols = pass_masks.nonzero(as_tuple=True) masked_sequences[mask_rows, mask_cols + 1] = mask_token - target_weights = target_esm[batch_idx] - masked_nlls = [] - for start in range(0, n_passes, batch_size): - stop = min(start + batch_size, n_passes) - chunk = masked_sequences[start:stop] - with torch.autocast( - device_type="cuda", dtype=torch.bfloat16, enabled=device.type == "cuda" - ): - hidden, *_ = esmc.transformer( - chunk @ esmc.embed.weight.to(chunk.dtype), - sequence_id=None, - layers_to_collect=[], - output_attentions=False, - ) - logits = esmc_model.lm_head(hidden) - log_probs = logits.log_softmax(dim=-1)[:, 1:-1, 4:24] - nlls = -(log_probs * target_weights.to(log_probs.dtype).unsqueeze(0)).sum( - dim=-1 + all_masked_sequences.append(masked_sequences) + all_pass_masks.append(pass_masks) + + logit_chunks = [] + masked_sequence_rows = torch.cat(all_masked_sequences, dim=0) + for start in range(0, masked_sequence_rows.size(0), batch_size): + chunk = masked_sequence_rows[start : start + batch_size] + with torch.autocast( + device_type="cuda", dtype=torch.bfloat16, enabled=device.type == "cuda" + ): + hidden, *_ = esmc.transformer( + chunk @ esmc.embed.weight.to(chunk.dtype), + sequence_id=None, + layers_to_collect=[], + output_attentions=False, ) - masked_nlls.append(nlls[pass_masks[start:stop]]) + logit_chunks.append(esmc_model.lm_head(hidden)) 
+ logits = torch.cat(logit_chunks, dim=0) - losses.append(torch.cat(masked_nlls, dim=0).mean()) + losses = [] + for batch_idx, pass_masks in enumerate(all_pass_masks): + start = batch_idx * n_passes + stop = start + n_passes + target_weights = target_esm[batch_idx] + log_probs = logits[start:stop].log_softmax(dim=-1)[:, 1:-1, 4:24] + nlls = -(log_probs * target_weights.to(log_probs.dtype).unsqueeze(0)).sum( + dim=-1 + ) + losses.append(nlls[pass_masks].mean()) return torch.stack(losses, dim=0) @@ -801,14 +829,14 @@ def design_binder( inversion_models: dict[str, ESMFold2ExperimentalModel], hf_critic_models: dict[str, ESMFold2ExperimentalModel], esmc_model: ESMCForMaskedLM, - target_name: str | None, + target_name: str, target_sequence: str | None, - binder_name: str | None, + binder_name: str, binder_sequence: str | None, is_antibody: bool | None, seed: int, batch_size: int = 1, -) -> tuple[list[str], dict[int, dict[str, torch.Tensor]], list[dict]]: +) -> tuple[list[str], Trajectory, list[dict]]: """ Algorithm 11 Gradient-Guided Binder Sequence Optimization. @@ -820,28 +848,25 @@ def design_binder( ``distogram_binding_confidence`` / ``cdr_distogram_binding_confidence`` come from the distogram in all cases. """ - # Vet inputs - assert (target_name is None) ^ ( - target_sequence is None - ), "Provide either target name or sequence." - assert (binder_name is None) ^ ( - binder_sequence is None - ), "Provide either binder name or sequence." - # Setup device = "cuda" - if target_name is not None: + if target_name in TARGET_SEQUENCES: + if target_sequence is not None: + raise ValueError( + f"{target_name!r} is a preset target; omit target_sequence." + ) target_sequence = TARGET_SEQUENCES[target_name] - else: - assert target_sequence is not None + elif target_sequence is None: + raise ValueError( + f"{target_name!r} is not a preset target; provide target_sequence." 
+ ) target_one_hot = sequence_to_one_hot(target_sequence, device=device) - if binder_name is None: - assert binder_sequence is not None - # If no binder_name and is_antibody is not specified, assume False. - if is_antibody is None: - is_antibody = False - else: + if binder_name in BINDER_PROMPT_FACTORIES: + if binder_sequence is not None: + raise ValueError( + f"{binder_name!r} is a preset binder; omit binder_sequence." + ) binder_prompt_factor = BINDER_PROMPT_FACTORIES[binder_name] if is_antibody is not None: assert ( @@ -849,6 +874,12 @@ def design_binder( ), "Conflict in is_antibody settings." is_antibody = binder_prompt_factor.is_antibody binder_sequence = binder_prompt_factor.sample(seed=seed) + elif binder_sequence is None: + raise ValueError( + f"{binder_name!r} is not a preset binder; provide binder_sequence." + ) + elif is_antibody is None: + is_antibody = _is_antibody_sequence(binder_sequence) binder_length = len(binder_sequence) @@ -864,8 +895,8 @@ def design_binder( ) gradient_mask = build_gradient_mask(binder_sequence, batch_size=batch_size) - # step -> {loss_name: [B] tensor on CPU} - trajectory: dict[int, dict[str, torch.Tensor]] = {} + # step -> {loss_name: [B] tensor on CPU, metric_name: float} + trajectory: Trajectory = {} global_step = 0 def run_step( @@ -875,6 +906,8 @@ def run_step( calculate_confidence: bool, ) -> tuple[torch.Tensor, list[str], list[float] | None]: nonlocal global_step + torch.cuda.reset_peak_memory_stats() + start = time.time() optimizer.zero_grad() random.seed(seed + global_step) @@ -907,8 +940,8 @@ def run_step( esmc_model=esmc_model, binder_design=design, score_mask=score_mask, - batch_size=4, - n_passes=4, + batch_size=LM_LOSS_BATCH_SIZE, + n_passes=LM_MASK_PASSES, ) plm_grad = torch.autograd.grad(plm_loss.mean(), logits)[0] @@ -922,12 +955,20 @@ def run_step( optimizer.step() step = global_step - step_losses = {k: v.detach().cpu() for k, v in losses.items()} + step_losses: TrajectoryStep = {k: v.detach().cpu() for k, v 
in losses.items()} step_losses["plm_loss"] = plm_loss.detach().cpu() step_losses["total_loss"] = (structure_loss + plm_loss).detach().cpu() + step_losses["time"] = time.time() - start + step_losses["peak_allocated_gib"] = ( + torch.cuda.max_memory_allocated() / BYTES_PER_GIB + ) + step_losses["peak_reserved_gib"] = ( + torch.cuda.max_memory_reserved() / BYTES_PER_GIB + ) trajectory[step] = step_losses loss_str = " ".join( - f"{k}={v.mean().item():.4f}" for k, v in step_losses.items() + f"{k}={v.mean().item() if torch.is_tensor(v) else v:.4f}" + for k, v in step_losses.items() ) if step % LOG_INTERVAL == 0: logger.info(f" step {step:3d} | {loss_str} T={temperature:.4f}") @@ -959,6 +1000,8 @@ def run_step( # Score critic_results: list[dict] = [] target_length = len(target_sequence.replace("|", "")) + final_total_loss = trajectory[global_step - 1]["total_loss"] + assert isinstance(final_total_loss, torch.Tensor) for batch_idx in range(batch_size): best_seq = best_sequences[batch_idx] binder_seq = best_seq.split("|")[-1] @@ -991,9 +1034,7 @@ def run_step( "batch_idx": batch_idx, "designed_sequence": best_seq, "complex": pred_complex, - "final_loss": trajectory[global_step - 1]["total_loss"][ - batch_idx - ].item(), + "final_loss": final_total_loss[batch_idx].item(), "iptm": iptm, "logits": logits[batch_idx].detach().cpu(), **iptm_proxy_scores, @@ -1007,9 +1048,7 @@ def run_step( "is_antibody": is_antibody, "batch_idx": batch_idx, "designed_sequence": best_sequences[batch_idx], - "final_loss": trajectory[global_step - 1]["total_loss"][ - batch_idx - ].item(), + "final_loss": final_total_loss[batch_idx].item(), "logits": logits[batch_idx].detach().cpu(), } ) @@ -1019,23 +1058,23 @@ def run_step( # ---- Model Loading ---- -_ESMC = None +_ESMC_CACHE: dict[str, torch.nn.Module] = {} def _load_hf_model( critic_name: str, lm_dropout: float, cache_esmc: bool, device: str ) -> Any: - """Loads ESMFold2 from huggingface. 
Will cache ESMC-6B among + """Loads ESMFold2 from huggingface. Will cache ESMC by checkpoint ID among all non-scaling checkpoints, to save on VRAM and load time.""" - global _ESMC repo_id = f"biohub/{critic_name}" model = ESMFold2ExperimentalModel.from_pretrained(repo_id, load_esmc=not cache_esmc) if cache_esmc: - if _ESMC is None: - model.load_esmc(model.config.esmc_id) - _ESMC = model._esmc - else: - model._esmc = _ESMC + esmc_id = model.config.esmc_id + if esmc_id not in _ESMC_CACHE: + model.load_esmc(esmc_id) + assert model._esmc is not None + _ESMC_CACHE[esmc_id] = model._esmc + model._esmc = _ESMC_CACHE[esmc_id] model.configure_lm_dropout(lm_dropout, force_lm_dropout_during_inference=True) model.set_kernel_backend("cuequivariance" if CUE_AVAILABLE else None) return model.to(device=device).eval().requires_grad_(False) @@ -1046,7 +1085,7 @@ def _apply_torch_compile(model: torch.nn.Module) -> None: torch._dynamo.config.cache_size_limit = 512 torch._dynamo.config.accumulated_cache_size_limit = 512 - compile_targets = (ESMFold2MSAEncoder, PairUpdateBlock, TransformerBlock) + compile_targets = (ESMFold2MSAEncoder, PairUpdateBlock) def _maybe_compile_module(module: torch.nn.Module) -> None: if not isinstance(module, compile_targets): @@ -1102,32 +1141,27 @@ def load(self, use_scaling_critics: bool): self.lm_name, torch_dtype=torch.float32 ) if REUSE_ESMC: + reusable_esmc_model = self.inversion_models["ESMFold2-Experimental-Fast"] + assert reusable_esmc_model.config.esmc_id == self.lm_name, ( + f"Cannot reuse ESMC trunk from {reusable_esmc_model.config.esmc_id!r} " + f"with LM head from {self.lm_name!r}." 
+ ) + assert reusable_esmc_model._esmc is not None del self.esmc_model.esmc torch.cuda.empty_cache() - self.esmc_model.esmc = self.inversion_models[ - "ESMFold2-Experimental-Fast" - ]._esmc + self.esmc_model.esmc = reusable_esmc_model._esmc self.esmc_model = self.esmc_model.cuda().eval().requires_grad_(False) - if CHECKPOINT_LM: - apply_activation_checkpointing( - self.esmc_model, - checkpoint_wrapper_fn=partial( - checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT - ), - check_fn=lambda module: isinstance(module, TransformerBlock), - ) - def design( self, - target_name: str | None = None, + target_name: str, + binder_name: str, target_sequence: str | None = None, - binder_name: str | None = None, binder_sequence: str | None = None, is_antibody: bool | None = None, seed: int = 0, batch_size: int = 1, - ) -> tuple[list[str], dict[int, dict[str, torch.Tensor]], list[dict]]: + ) -> tuple[list[str], Trajectory, list[dict]]: return design_binder( self.inversion_models, self.hf_critic_models, @@ -1150,10 +1184,50 @@ def get_base_image(): modal.Image.micromamba(python_version="3.12") .run_commands("apt update && apt install -y git build-essential") .micromamba_install( - "anarci>=2020.04.03", "hmmer=3.4", channels=["conda-forge", "bioconda"] + "anarci>=2020.04.03", + "hmmer=3.4", + "cuda-version=12.8", + "cuda-libraries-dev=12.8", + "cuda-nvcc=12.8", + "cmake", + "ninja", + channels=["conda-forge", "bioconda"], + ) + .pip_install( + "torch==2.8.0", + "triton==3.4.0", + index_url="https://download.pytorch.org/whl/cu128", + ) + .pip_install( + "flash-attn==2.8.3", + "transformer-engine[core-cu12,pytorch]==2.13.0", + "xformers==0.0.32.post1", + extra_options="--no-build-isolation", + env={ + "CPATH": ( + "/opt/conda/lib/python3.12/site-packages/nvidia/cudnn/include:" + "/opt/conda/lib/python3.12/site-packages/nvidia/nccl/include:" + "/opt/conda/lib/python3.12/site-packages/nvidia/nvtx/include" + ), + "LIBRARY_PATH": ( + 
"/opt/conda/lib/python3.12/site-packages/nvidia/cudnn/lib:" + "/opt/conda/lib/python3.12/site-packages/nvidia/nccl/lib:" + "/opt/conda/lib/python3.12/site-packages/nvidia/nvtx/lib" + ), + "MAX_JOBS": "8", + "NVTE_FRAMEWORK": "pytorch", + }, + ) + .pip_install( + "abnumber", "esm@git+https://github.com/Biohub/esm.git@main", "modal" + ) + .env( + { + "HF_HOME": "/models", + "HF_XET_HIGH_PERFORMANCE": "1", + "XFORMERS_IGNORE_FLASH_VERSION_CHECK": "1", + } ) - .pip_install("abnumber", "esm@git+https://github.com/Biohub/esm.git@main") - .env({"HF_HOME": "/models", "HF_XET_HIGH_PERFORMANCE": "1"}) ) @@ -1166,15 +1240,17 @@ def get_base_image(): ) -# If use_scaling_checkpoints is True, `memory` should be increased to 60 * 1024. -@app.cls(gpu="H100", timeout=60 * 60, cpu=16, memory=10 * 1024) +# NOTE - Currently the memory usage is quite high, inflating costs. +# In an update coming soon, scaling critics will be loaded on demand +# to avoid needing this large amount of RAM. +@app.cls(gpu="H100", timeout=60 * 60, cpu=16, memory=80 * 1024) class ESMFold2DesignModal(ESMFold2Design): """Modal entrypoint. Hero critics are HF experimental exports with confidence heads. Set ``use_scaling_critics=True`` to also load the 15-checkpoint scaling-experiment ensemble (distogram binding confidence only). 
""" - use_scaling_critics: bool = modal.parameter(default=False) + use_scaling_critics: bool = modal.parameter(default=True) @modal.enter() def load(self): @@ -1187,11 +1263,11 @@ def design(self, *args, **kws): @app.local_entrypoint() def main( - target_name: str | None = None, + target_name: str, + binder_name: str, target_sequence: str | None = None, - binder_name: str | None = None, binder_sequence: str | None = None, - use_scaling_critics: bool = False, + use_scaling_critics: bool = True, is_antibody: bool | None = None, local: bool = False, seed: int = 0, @@ -1242,5 +1318,5 @@ def main( seed=0, batch_size=1, local=True, - use_scaling_critics=False, + use_scaling_critics=True, ) diff --git a/cookbook/tutorials/embed.ipynb b/cookbook/tutorials/embed.ipynb index 38678117..6f0693e7 100644 --- a/cookbook/tutorials/embed.ipynb +++ b/cookbook/tutorials/embed.ipynb @@ -308,9 +308,9 @@ ], "metadata": { "kernelspec": { - "display_name": "default", + "display_name": "Pixi (ESM)", "language": "python", - "name": "python3" + "name": "pixi-esm" }, "language_info": { "codemirror_mode": { @@ -322,7 +322,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.9" + "version": "3.12.12" } }, "nbformat": 4, diff --git a/cookbook/tutorials/esm3_generate.ipynb b/cookbook/tutorials/esm3_generate.ipynb index 52b68f31..d75ccd52 100644 --- a/cookbook/tutorials/esm3_generate.ipynb +++ b/cookbook/tutorials/esm3_generate.ipynb @@ -347,7 +347,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Now, we can try another generation task with ESM3. We'll use the secondary structure track, along with the sequence track, to shorten a helix-coil-helix region (residues 39-111) in a protein structure (colored in blue below)\n" + "Now, we can try another generation task with ESM3. 
We'll use the secondary structure track, along with the sequence track, to shorten a helix-coil-helix region (residues 48-120) in a protein structure (colored in blue below)\n" ] }, { @@ -360,16 +360,19 @@ "view = py3Dmol.view(width=500, height=500)\n", "view.addModel(helix_shortening_chain.to_pdb_string(), \"pdb\")\n", "view.setStyle({\"cartoon\": {\"color\": \"lightgrey\"}})\n", - "helix_region = np.arange(38, 111) # zero-indexed\n", + "helix_region = np.arange(47, 120) # zero-indexed\n", "view.addStyle(\n", " {\"resi\": (helix_region + 1).tolist()}, {\"cartoon\": {\"color\": \"lightblue\"}}\n", ")\n", "view.zoomTo()\n", "view.show()\n", - "helix_shortening_ss8 = \"CCCSHHHHHHHHHHHTTCHHHHHHHHHHHHHTCSSCCCCHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHTTCHHHHHHHHHHHHHHHHHHHHHHHHHHHHIIIIIGGGCCSHHHHHHHHHHHHHHHHHHHHHCCHHHHHHHHHHHHHHHHHHHHHHHHHSCTTCHHHHHHHHHHHHHIIIIICCHHHHHHHHHHHHHHHHTTCTTCCSSHHHHHHHHHHHHHHHHHHHC\"\n", + "helix_shortening_ss8 = \"CCCCCCCCCCCSHHHHHHHHHHHTTCHHHHHHHHHHHHHTCSSCCCCHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHTTCHHHHHHHHHHHHHHHHHHHHHHHHHHHHIIIIIGGGCCSHHHHHHHHHHHHHHHHHHHHHCCHHHHHHHHHHHHHHHHHHHHHHHHHSCTTCHHHHHHHHHHHHHIIIIICCHHHHHHHHHHHHHHHHTTCTTCCSSHHHHHHHHHHHHHHHHHHHCCCCCCCCCCCCCCCCCC\"\n", "print(\n", " \"Secondary structure of protein: (H: Alpha Helix, E: Beta Strand, C: Coil) \\n\\t\",\n", " helix_shortening_ss8,\n", + ")\n", + "print(\n", + " f\"SS8 length: {len(helix_shortening_ss8)}, Sequence length: {len(helix_shortening_chain.sequence)}\"\n", ")" ] }, @@ -583,7 +586,9 @@ " protein_prompt,\n", " GenerationConfig(\n", " track=\"sequence\",\n", - " num_steps=protein_prompt.sequence.count(\"_\") // 2,\n", + " num_steps=min(\n", + " protein_prompt.sequence.count(\"_\") // 2, 100\n", + " ), # Cap at 100 (API limit)\n", " temperature=0.5,\n", " ),\n", " )\n", @@ -647,20 +652,13 @@ "view.zoomTo()\n", "view.show()" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { "kernelspec": { - 
"display_name": "Python 3 (ipykernel)", + "display_name": "Python (pixi)", "language": "python", - "name": "python3" + "name": "pixi" }, "language_info": { "codemirror_mode": { @@ -672,7 +670,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.15" + "version": "3.12.12" } }, "nbformat": 4, diff --git a/cookbook/tutorials/esmfold2.ipynb b/cookbook/tutorials/esmfold2.ipynb index a7b4d5a3..c1a06144 100644 --- a/cookbook/tutorials/esmfold2.ipynb +++ b/cookbook/tutorials/esmfold2.ipynb @@ -2029,6 +2029,72 @@ "\n", "Deeper MSAs (more sequences) provide more evolutionary information but increase computation time." ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# **Example 4: Folding Complexes with Paired MSAs**\n", + "\n", + "For multi-chain complexes — antibody–antigen being the canonical case — ESMFold2 does more than concatenate per-chain MSAs: it **pairs MSA rows across chains by taxonomy**, so the model can read co-evolutionary signal *between* chains. This inter-chain pairing is a large part of what drives ESMFold2's antibody–antigen accuracy.\n", + "\n", + "You supply **one MSA per chain** (each its own a3m) and pass each chain as its own `ProteinInput`. Pairing is automatic:\n", + "\n", + "- Rows whose FASTA header carries a `key=` token are **paired** across chains that share the same id (one paired row per shared organism), placed at the top of the stacked MSA.\n", + "- Every remaining hit is kept as an **unpaired** row (stacked densely, gap-filled in the other chains' columns), so single-chain evolutionary signal is preserved too.\n", + "\n", + "So you get *both* paired and unpaired MSA from the same per-chain inputs — you don't build the block layout yourself, you just annotate taxonomy in the headers.\n", + "\n", + "> **Header requirement:** pairing keys off `key=` in each sequence's header (the query/first row needs no key). Bring per-chain a3m files annotated with `key=` — e.g. 
from a UniProt/ColabFold search, rewriting each header's `OX=` to `key=`. A chain whose hits have no `key=` still folds; its rows are simply treated as unpaired.\n", + "\n", + "An a3m header for a chain then looks like:\n", + "\n", + "```\n", + ">query\n", + "EVQLVESGGGLVQPGGSLRLSCAAS...\n", + ">UniRef100_A0A... key=9606 # human homolog → pairs with key=9606 rows in other chains\n", + "QVQLVQSGAEVKKPGAS...\n", + ">UniRef100_B1B... key=10090 # mouse homolog\n", + "...\n", + "```\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from esm.utils.msa import MSA\n", + "from esm.utils.structure.input_builder import ProteinInput, StructurePredictionInput\n", + "\n", + "# Example antibody–antigen complex: antibody heavy + light chains and the antigen.\n", + "heavy_seq = \"EVQLVESGGGLVQPGGSLRLSCAASGFNIKDTYIHWVRQAPGKGLEWVARIYPTNGYTRYADSVKGRFTISADTSKNTAYLQMNSLRAEDTAVYYCSRWGGDGFYAMDYWGQGTLVTVSS\"\n", + "light_seq = \"DIQMTQSPSSLSASVGDRVTITCRASQDVNTAVAWYQQKPGKAPKLLIYSASFLYSGVPSRFSGSRSGTDFTLTISSLQPEDFATYYCQQHYTTPPTFGQGTKVEIK\"\n", + "antigen_seq = \"MKQLEDKVEELLSKNYHLENEVARLKKLVGER\"\n", + "\n", + "# One MSA per chain. 
Headers must carry `key=` for cross-chain pairing.\n", + "heavy_msa = MSA.from_a3m(path=\"antibody_heavy.a3m\", remove_insertions=True)\n", + "light_msa = MSA.from_a3m(path=\"antibody_light.a3m\", remove_insertions=True)\n", + "antigen_msa = MSA.from_a3m(path=\"antigen.a3m\", remove_insertions=True)\n", + "\n", + "# Pass each chain as its own ProteinInput; pairing happens automatically.\n", + "fold_result = client.fold_all_atom(\n", + " StructurePredictionInput(\n", + " sequences=[\n", + " ProteinInput(id=\"H\", sequence=heavy_seq, msa=heavy_msa),\n", + " ProteinInput(id=\"L\", sequence=light_seq, msa=light_msa),\n", + " ProteinInput(id=\"A\", sequence=antigen_seq, msa=antigen_msa),\n", + " ]\n", + " )\n", + ")\n", + "\n", + "print(f\"pTM: {fold_result.ptm:.3f}, ipTM: {fold_result.iptm:.3f}\")\n", + "print(f\"Average pLDDT: {fold_result.plddt.mean().item():.1f}\")\n", + "\n", + "with open(\"antibody_antigen_paired_msa.cif\", \"w\") as f:\n", + " f.write(fold_result.complex.to_mmcif())" + ] } ], "metadata": { diff --git a/esm/models/vqvae.py b/esm/models/vqvae.py index 8070286f..a2725b08 100644 --- a/esm/models/vqvae.py +++ b/esm/models/vqvae.py @@ -280,7 +280,7 @@ def find_knn_edges( (coords.shape[0], coords.shape[1]), device=coords.device ).long() - with torch.no_grad(), torch.cuda.amp.autocast(enabled=False): + with torch.no_grad(), torch.amp.autocast("cuda", enabled=False): ca = coords[..., 1, :] edges, edge_mask = knn_graph( ca, coord_mask, padding_mask, sequence_id, no_knn=knn diff --git a/esm/sdk/api.py b/esm/sdk/api.py index 86777646..79f2969e 100644 --- a/esm/sdk/api.py +++ b/esm/sdk/api.py @@ -498,26 +498,29 @@ class ForwardTrackData: @define class LogitsConfig: """ - sequence (bool): Return sequence logits. - structure (bool): Return structure logits. - secondary_structure (bool): Return secondary structure logits. - sasa (bool): Return sasa logits. - function (bool): Return function logits. - residue_annotations (bool): Return residue annotations logits. 
- return_embeddings (bool): Whether embeddings should be returned. - return_hidden_states (bool): Whether to return per-residue hidden states. With - ith_hidden_layer=-1, returns all layers as a tensor of shape [n_layers + 1, - B, L, D]. With ith_hidden_layer!= -1, returns the selected layer as a tensor - of shape [1, B, L, D]. - return_mean_embedding (bool): Whether mean embeddings should be returned. - return_mean_hidden_states (bool): Whether hidden states mean-pooled along the - sequence length (L) dimension should be returned. Returns a tensor of shape - [B, n_layers + 1, D]. - ith_hidden_layer (int): Valid values for ith_hidden_layer are 0 to - max_ith_hidden_layer (inclusive), where index 0 is the embedding layer. -1 - returns all layers, but is not supported for ESMC 6B or any ESM3 model. Here - is the max_ith_hidden_layer for each ESMC and ESM3 model (except ESM3 - Large). + sequence (bool): (ESM3, ESMC) Return sequence logits. + structure (bool): (not supported on Forge/Biohub Platform) Return structure + logits. + secondary_structure (bool): (not supported on Forge/Biohub Platform) Return + secondary structure logits. + sasa (bool): (not supported on Forge/Biohub Platform) Return sasa logits. + function (bool): (not supported on Forge/Biohub Platform) Return function + logits. + residue_annotations (bool): (ESM3) Return residue annotations logits. + return_embeddings (bool): (ESM3, ESMC) Whether embeddings should be returned. + return_hidden_states (bool): (ESMC, ESM3 except Large) Whether to return per- + residue hidden states. With ith_hidden_layer=-1, returns all layers as a + tensor of shape [n_layers + 1, B, L, D]. With ith_hidden_layer!= -1, returns + the selected layer as a tensor of shape [1, B, L, D]. ESM3 requires + ith_hidden_layer to be specified. + return_mean_embedding (bool): (ESM3, ESMC) Whether mean embeddings should be + returned. 
+ return_mean_hidden_states (bool): (ESMC, ESM3 except Large) Whether hidden + states mean-pooled along the sequence length (L) dimension should be + returned. Returns a tensor of shape [B, n_layers + 1, D]. + ith_hidden_layer (int): (ESM3, ESMC) Valid values are 0 to max_ith_hidden_layer + (inclusive), where index 0 is the embedding layer. -1 returns all layers, + but is not supported for ESMC 6B or any ESM3 model. | Model Name | max_ith_hidden_layer | |-------------------------------|--------------------------------| | esmc-300-2024-12 | 30 | @@ -527,7 +530,8 @@ class LogitsConfig: | esm3-small-2024-08 | 48 | | esm3-medium-2024-03 | 96 | | esm3-medium-2024-08 | 96 | - sae_config (SAEConfig | None): SAE config. Only applies to ESMC models. + sae_config (SAEConfig | None): (ESMC) SAE config for requesting sparse + autoencoder features. """ # Logits. diff --git a/esm/utils/msa/msa.py b/esm/utils/msa/msa.py index ebd81dfa..0fe6c432 100644 --- a/esm/utils/msa/msa.py +++ b/esm/utils/msa/msa.py @@ -76,7 +76,7 @@ def to_fast_msa(self) -> FastMSA: def from_a3m( cls, path: PathOrBuffer, - remove_insertions: bool = True, + remove_insertions: bool = False, max_sequences: int | None = None, ) -> MSA: entries = [] @@ -180,21 +180,25 @@ def state_dict(self, json_serializable: bool = False) -> dict[str, Any]: :meth:`from_a3m`) alongside the sequences, so the feature survives even when the default ``remove_insertions`` strips the lowercase insertions out of the sequences. With ``json_serializable=True`` the array is returned as a list. - Headers are not serialized. 
""" dct: dict[str, Any] = {"sequences": self.sequences} if self.deletions is not None: dct["deletions"] = ( self.deletions.tolist() if json_serializable else self.deletions ) + if any(self.headers): + dct["headers"] = self.headers return dct @classmethod def from_state_dict(cls, dct: dict[str, Any]) -> MSA: - """Inverse of :meth:`state_dict`; sequences are taken verbatim.""" + """Inverse of :meth:`state_dict`; sequences and (when present) headers are taken + verbatim.""" deletions = dct.get("deletions") + sequences = dct["sequences"] + headers = dct.get("headers") or ["" for _ in sequences] return cls( - entries=[FastaEntry("", seq) for seq in dct["sequences"]], + entries=[FastaEntry(h, seq) for h, seq in zip(headers, sequences)], deletions=None if deletions is None else np.asarray(deletions, dtype=np.float32), diff --git a/esm/utils/msa/msa_test.py b/esm/utils/msa/msa_test.py index 31816093..0eda8265 100644 --- a/esm/utils/msa/msa_test.py +++ b/esm/utils/msa/msa_test.py @@ -40,7 +40,7 @@ def _a3m_msa(tmp_path) -> MSA: """Build the shared `_A3M` fixture as an MSA (insertion-stripped, deletions set).""" p = tmp_path / "m.a3m" _write_a3m(p, gz=False) - return MSA.from_a3m(str(p)) + return MSA.from_a3m(str(p), remove_insertions=True) def test_a3m_deletion_counts_vectorized(): @@ -63,7 +63,7 @@ def test_from_a3m_records_deletions(tmp_path): def test_from_a3m_gz(tmp_path): p = tmp_path / "m.a3m.gz" _write_a3m(p, gz=True) - msa = MSA.from_a3m(str(p)) + msa = MSA.from_a3m(str(p), remove_insertions=True) assert msa.deletions is not None np.testing.assert_array_equal(msa.deletions, _EXPECTED_DELETIONS) diff --git a/esm/utils/structure/input_builder.py b/esm/utils/structure/input_builder.py index 0a7d7113..ad5aef4e 100644 --- a/esm/utils/structure/input_builder.py +++ b/esm/utils/structure/input_builder.py @@ -33,6 +33,7 @@ class RNAInput: id: str | list[str] sequence: str modifications: list[Modification] | None = None + msa: MSAInput = None @dataclass @@ -197,6 +198,7 @@ 
def _msa(chain: dict[str, Any]) -> MSAInput: id=chain["id"], sequence=chain["sequence"], modifications=_mods(chain), + msa=_msa(chain), ) ) elif t == "dna": diff --git a/esm/utils/structure/molecular_complex.py b/esm/utils/structure/molecular_complex.py index fe007954..2e856b33 100644 --- a/esm/utils/structure/molecular_complex.py +++ b/esm/utils/structure/molecular_complex.py @@ -38,6 +38,7 @@ class MolecularComplexResult: ptm: float | None = None iptm: float | None = None pae: torch.Tensor | None = None + pde: torch.Tensor | None = None distogram: torch.Tensor | None = None pair_chains_iptm: torch.Tensor | None = None output_embedding_sequence: torch.Tensor | None = None @@ -45,6 +46,8 @@ class MolecularComplexResult: residue_index: torch.Tensor | None = None entity_id: torch.Tensor | None = None sae_features: np.ndarray | None = None # [L, n_features] + # Atom-expanded token count L + num_tokens: int | None = None @dataclass diff --git a/esm/widgets/utils/prompting.py b/esm/widgets/utils/prompting.py index 199e7ddc..3a142c12 100644 --- a/esm/widgets/utils/prompting.py +++ b/esm/widgets/utils/prompting.py @@ -280,14 +280,13 @@ def add_entry_to_ui(self, range_string): f"{range_string}" ) ) - entry_label.tag = range_string entry_container = widgets.HBox([entry_button, entry_label]) def delete_entry(b): self.entries_box.children = [ w for w in self.entries_box.children if w != entry_container ] - self.delete_prompt(entry_label.tag) + self.delete_prompt(range_string) self.redraw() for callback in self.delete_callbacks: callback() diff --git a/esm/widgets/utils/protein_import.py b/esm/widgets/utils/protein_import.py index 0d6f6632..9209d652 100644 --- a/esm/widgets/utils/protein_import.py +++ b/esm/widgets/utils/protein_import.py @@ -114,7 +114,6 @@ def add_pdb_id(self, pdb_id: str, chain_id: str): def add_entry_to_ui(self, protein_id: str): entry_button = widgets.Button(description="Remove") entry_label = widgets.Label(value=protein_id) - entry_label.tag = protein_id 
entry_container = widgets.HBox([entry_button, entry_label]) def delete_entry(b): From 2f521b7f61a2a4e2c86d3420781e5fbbcabfd513 Mon Sep 17 00:00:00 2001 From: Fausto Milletari Date: Tue, 30 Jun 2026 17:05:27 +0000 Subject: [PATCH 2/3] Fix test_oss_esmc_client to use esmc_client() factory test_oss_esmc_client called client() with an ESMC model, but client() only accepts esm3 models and raises ValueError. Use the dedicated esmc_client() factory instead. Co-Authored-By: Claude Opus 4.8 (1M context) --- tests/oss_pytests/test_oss_client.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/oss_pytests/test_oss_client.py b/tests/oss_pytests/test_oss_client.py index 746bf301..d534328a 100644 --- a/tests/oss_pytests/test_oss_client.py +++ b/tests/oss_pytests/test_oss_client.py @@ -3,7 +3,7 @@ import pytest import torch -from esm.sdk import client # pyright: ignore +from esm.sdk import client, esmc_client # pyright: ignore from esm.sdk.api import ( # pyright: ignore ESMProtein, ESMProteinTensor, @@ -58,19 +58,19 @@ def test_oss_esmc_client(): sequence = "MALWMRLLPLLALLALAVPDPAAA" model = "esmc-300m-2024-12" - esmc_client = client(model=model, url=URL, token=API_TOKEN) + esmc = esmc_client(model=model, url=URL, token=API_TOKEN) protein = ESMProtein(sequence) - encoded_protein = esmc_client.encode(input=protein) + encoded_protein = esmc.encode(input=protein) assert isinstance(encoded_protein, ESMProteinTensor) - decoded_protein = esmc_client.decode(input=encoded_protein) + decoded_protein = esmc.decode(input=encoded_protein) assert isinstance(decoded_protein, ESMProtein) logits_config = LogitsConfig( sequence=True, return_embeddings=True, return_hidden_states=True ) - result = esmc_client.logits(input=encoded_protein, config=logits_config) + result = esmc.logits(input=encoded_protein, config=logits_config) assert isinstance(result, LogitsOutput) assert result.logits is not None assert isinstance(result.logits.sequence, torch.Tensor) From 
bba9e61cf33b9d318cb53c3a178d1752fd5be1df Mon Sep 17 00:00:00 2001 From: Fausto Milletari Date: Tue, 30 Jun 2026 17:14:31 +0000 Subject: [PATCH 3/3] Revert "Fix test_oss_esmc_client to use esmc_client() factory" This reverts commit 2f521b7f61a2a4e2c86d3420781e5fbbcabfd513. --- tests/oss_pytests/test_oss_client.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/oss_pytests/test_oss_client.py b/tests/oss_pytests/test_oss_client.py index d534328a..746bf301 100644 --- a/tests/oss_pytests/test_oss_client.py +++ b/tests/oss_pytests/test_oss_client.py @@ -3,7 +3,7 @@ import pytest import torch -from esm.sdk import client, esmc_client # pyright: ignore +from esm.sdk import client # pyright: ignore from esm.sdk.api import ( # pyright: ignore ESMProtein, ESMProteinTensor, @@ -58,19 +58,19 @@ def test_oss_esmc_client(): sequence = "MALWMRLLPLLALLALAVPDPAAA" model = "esmc-300m-2024-12" - esmc = esmc_client(model=model, url=URL, token=API_TOKEN) + esmc_client = client(model=model, url=URL, token=API_TOKEN) protein = ESMProtein(sequence) - encoded_protein = esmc.encode(input=protein) + encoded_protein = esmc_client.encode(input=protein) assert isinstance(encoded_protein, ESMProteinTensor) - decoded_protein = esmc.decode(input=encoded_protein) + decoded_protein = esmc_client.decode(input=encoded_protein) assert isinstance(decoded_protein, ESMProtein) logits_config = LogitsConfig( sequence=True, return_embeddings=True, return_hidden_states=True ) - result = esmc.logits(input=encoded_protein, config=logits_config) + result = esmc_client.logits(input=encoded_protein, config=logits_config) assert isinstance(result, LogitsOutput) assert result.logits is not None assert isinstance(result.logits.sequence, torch.Tensor)