diff --git a/docs/md_simulation_methodology.md b/docs/atomistic_md_simulation_methodology.md similarity index 93% rename from docs/md_simulation_methodology.md rename to docs/atomistic_md_simulation_methodology.md index 8ba7d65..427178b 100644 --- a/docs/md_simulation_methodology.md +++ b/docs/atomistic_md_simulation_methodology.md @@ -29,7 +29,7 @@ The goal of this pipeline is to compute **pairwise residue-residue interaction e - **Van der Waals (VdW)**: Attractive and repulsive components - **Electrostatic (ES)**: Attractive and repulsive components -The pipeline uses the **AMBER ff19SB** force field [4] with **TIP3P-FB** water model, implemented in OpenMM [3]. +The pipeline uses the **AMBER ff19SB** force field [3] with **TIP3P-FB** water model, implemented in OpenMM [2]. --- @@ -50,7 +50,7 @@ Before simulation, the PDB structure is "fixed" using PDBFixer: ## System Setup ### Force Field -- **Protein**: AMBER ff19SB (`amber19-all.xml`) [2, 4] +- **Protein**: AMBER ff19SB (`amber19-all.xml`) [1, 3] - **Water**: TIP3P-FB (`amber19/tip3pfb.xml`) ### Solvation @@ -67,7 +67,7 @@ When `forcefield.createSystem()` is called, OpenMM automatically assigns **parti #### RESP Charges (Restrained Electrostatic Potential) -AMBER force fields use **RESP charges** [8], which are derived from: +AMBER force fields use **RESP charges** [4], which are derived from: 1. **Quantum mechanical (QM) calculations** at the HF/6-31G* level of theory 2. 
**Electrostatic potential (ESP) fitting** - charges are optimized to reproduce the QM electrostatic potential around the molecule @@ -176,7 +176,7 @@ Alpha-carbon position restraints are maintained during temperature ramping to pr ### Van der Waals (Lennard-Jones) Interactions -The Lennard-Jones potential [1] describes van der Waals interactions: +The Lennard-Jones potential describes van der Waals interactions: #### Potential Energy @@ -187,7 +187,7 @@ Where: - $\epsilon_{ij}$ = well depth (combined) - $\sigma_{ij}$ = collision diameter (combined) -#### Combining Rules (Lorentz-Berthelot) [5, 6] +#### Combining Rules (Lorentz-Berthelot) $$\sigma_{ij} = \frac{\sigma_i + \sigma_j}{2}$$ @@ -202,7 +202,7 @@ $$\epsilon_{ij} = \sqrt{\epsilon_i \cdot \epsilon_j}$$ --- -### Electrostatic (Coulomb) Interactions [7] +### Electrostatic (Coulomb) Interactions #### Potential Energy @@ -448,7 +448,7 @@ NPT equilibration complete. ### Appendix B: Alpha-Carbon Restraint Implementation -During energy minimization, NPT equilibration, and NVT equilibration, alpha-carbon (CA) atoms are restrained using OpenMM's [3] `CustomExternalForce` with a **Cartesian harmonic** potential: +During energy minimization, NPT equilibration, and NVT equilibration, alpha-carbon (CA) atoms are restrained using OpenMM's [2] `CustomExternalForce` with a **Cartesian harmonic** potential: $$U_{restraint} = \frac{1}{2} k \left[(x - x_0)^2 + (y - y_0)^2 + (z - z_0)^2\right]$$ @@ -492,7 +492,7 @@ for idx in ca_indices: system.addForce(restraint_force) ``` -### Appendix C: RESP Charges [8] +### Appendix C: RESP Charges [4] #### RESP Charges Lookup Table Approach @@ -516,7 +516,7 @@ For standard amino acids, charges are **pre-computed and stored** in the force f | **Consistency** | Matched to force field VdW parameters | Standalone method | | **Use case** | MD simulations | Cheminformatics, docking | -RESP charges are specifically parameterized to work synergistically with the other AMBER force field terms (VdW, 
bonds, angles, dihedrals), ensuring accurate reproduction of experimental properties like solvation free energies and protein folding thermodynamics [8, 9]. +RESP charges are specifically parameterized to work synergistically with the other AMBER force field terms (VdW, bonds, angles, dihedrals), ensuring accurate reproduction of experimental properties like solvation free energies and protein folding thermodynamics [4, 5]. --- @@ -524,7 +524,7 @@ RESP charges are specifically parameterized to work synergistically with the oth This appendix provides detailed derivations and unit analysis for the force and energy calculations used in the pipeline. -#### Van der Waals (Lennard-Jones) Interactions [1] +#### Van der Waals (Lennard-Jones) Interactions ##### Potential Energy @@ -555,7 +555,7 @@ $$F_{LJ}(r) = \frac{24\epsilon_{ij}}{r} \left[ 2\left(\frac{\sigma_{ij}}{r}\righ - **Repulsive**: $F_{rep} = \frac{48\epsilon_{ij}}{r} \left(\frac{\sigma_{ij}}{r}\right)^{12}$ (positive, pushes apart) - **Attractive**: $F_{att} = -\frac{24\epsilon_{ij}}{r} \left(\frac{\sigma_{ij}}{r}\right)^{6}$ (negative, pulls together) -##### Combining Rules (Lorentz-Berthelot) [5, 6] +##### Combining Rules (Lorentz-Berthelot) $$\sigma_{ij} = \frac{\sigma_i + \sigma_j}{2} \quad \text{(arithmetic mean)}$$ @@ -563,13 +563,13 @@ $$\epsilon_{ij} = \sqrt{\epsilon_i \cdot \epsilon_j} \quad \text{(geometric mean --- -#### Electrostatic (Coulomb) Interactions [7] +#### Electrostatic (Coulomb) Interactions ##### Potential Energy $$U_{elec}(r) = \frac{k_e \cdot q_i \cdot q_j}{r}$$ -Where $k_e = 138.935456$ kJ·nm/(mol·e²) is the Coulomb constant in OpenMM units [3]. +Where $k_e = 138.935456$ kJ·nm/(mol·e²) is the Coulomb constant in OpenMM units [2]. ##### Force Derivation @@ -624,12 +624,8 @@ $$[k_e] \cdot [q]^2 / [r]^2 = \frac{\text{kJ} \cdot \text{nm}}{\text{mol} \cdot ## References -1. Lennard-Jones, J. E. (1931). "Cohesion". Proceedings of the Physical Society. 43 (5): 461–482. -2. Case, D. 
A., et al. (2020). "AMBER 2020 Reference Manual". -3. Eastman, P., et al. (2017). "OpenMM 7: Rapid development of high performance algorithms for molecular dynamics". PLOS Computational Biology. -4. Tian, C., et al. (2020). "ff19SB: Amino-Acid-Specific Protein Backbone Parameters Trained against Quantum Mechanics Energy Surfaces in Solution". Journal of Chemical Theory and Computation. 16 (1): 528–552. -5. Lorentz, H. A. (1881). "Ueber die Anwendung des Satzes vom Virial in der kinetischen Theorie der Gase". Annalen der Physik. 248 (1): 127–136. -6. Berthelot, D. (1898). "Sur le mélange des gaz". Comptes Rendus. 126: 1703–1706. -7. Coulomb, C. A. (1785). "Premier mémoire sur l'électricité et le magnétisme". Histoire de l'Académie Royale des Sciences. 569–577. -8. Bayly, C. I., et al. (1993). "A well-behaved electrostatic potential based method using charge restraints for deriving atomic charges: the RESP model". The Journal of Physical Chemistry. 97 (40): 10269–10280. -9. Cornell, W. D., et al. (1995). "A Second Generation Force Field for the Simulation of Proteins, Nucleic Acids, and Organic Molecules". Journal of the American Chemical Society. 117 (19): 5179–5197. +1. Case, D. A., et al. (2020). "AMBER 2020 Reference Manual". +2. Eastman, P., et al. (2024). "OpenMM 8: Molecular Dynamics Simulation with Machine Learning Potentials". Journal of Physical Chemistry B. 128 (1): 109–116. +3. Tian, C., et al. (2020). "ff19SB: Amino-Acid-Specific Protein Backbone Parameters Trained against Quantum Mechanics Energy Surfaces in Solution". Journal of Chemical Theory and Computation. 16 (1): 528–552. +4. Bayly, C. I., et al. (1993). "A well-behaved electrostatic potential based method using charge restraints for deriving atomic charges: the RESP model". The Journal of Physical Chemistry. 97 (40): 10269–10280. +5. Cornell, W. D., et al. (1995). "A Second Generation Force Field for the Simulation of Proteins, Nucleic Acids, and Organic Molecules". 
Journal of the American Chemical Society. 117 (19): 5179–5197. diff --git a/docs/martini_md_simulation_methodology.md b/docs/martini_md_simulation_methodology.md new file mode 100644 index 0000000..2b181f3 --- /dev/null +++ b/docs/martini_md_simulation_methodology.md @@ -0,0 +1,638 @@ +# Coarse-Grained Molecular Dynamics Simulation Methodology: Martini 3-Inspired Multi-Bead Model + +This document describes the Martini 3-inspired coarse-grained (CG) pipeline for calculating pairwise residue-residue interaction energies from molecular dynamics simulations using OpenMM. It is a companion to the atomistic methodology document and covers the scientific basis, force-field parameterisation, explicit solvent treatment, simulation pipeline, and usage of the `MartiniNonBondedForceModel` class. + +## Table of Contents + +1. [Overview](#overview) +2. [Motivation: Why Martini?](#motivation-why-martini) +3. [System Representation](#system-representation) + - [Multi-Bead Residue Mapping](#multi-bead-residue-mapping) + - [Bead Types and LJ Parameters](#bead-types-and-lj-parameters) + - [Bead Charges](#bead-charges) + - [Explicit Solvent and Ions](#explicit-solvent-and-ions) +4. [Force Field](#force-field) + - [Bonded Interactions](#bonded-interactions) + - [Van der Waals Interactions](#van-der-waals-lennard-jones-interactions) + - [Electrostatic Interactions: Reaction-Field Coulomb](#electrostatic-interactions-reaction-field-coulomb) + - [Exclusions and Cutoff](#exclusions-and-cutoff) + - [Barostat](#barostat) +5. [Solvation Protocol](#solvation-protocol) + - [Water Bead Placement](#water-bead-placement) + - [Ion Placement](#ion-placement) +6. [Simulation Pipeline](#simulation-pipeline) + - [Energy Minimization](#1-energy-minimization) + - [NVT Equilibration](#2-nvt-equilibration) + - [NPT Equilibration](#3-npt-equilibration) + - [Production MD](#4-production-md) +7. 
[Energy Calculations](#energy-calculations) + - [Bead-Level Pairwise Energies](#bead-level-pairwise-energies) + - [Aggregation to Residue Level](#aggregation-to-residue-level) + - [Distance Matrix](#distance-matrix) +8. [Output Matrices](#output-matrices) +9. [Comparison with Atomistic Model](#comparison-with-atomistic-model) +10. [Usage Example](#usage-example) +11. [Appendix](#appendix) +12. [References](#references) + +--- + +## Overview + +The Martini pipeline computes the same **pairwise residue-residue interaction energies** as the atomistic pipeline — Van der Waals (attractive and repulsive) and electrostatic (attractive and repulsive) — using a multi-bead CG representation inspired by the Martini 3 force field [1]. + +Each residue is represented by **1–4 beads** depending on sidechain complexity. Bead positions are computed as centroids of the corresponding heavy atoms in the input PDB file. Compared to the atomistic model, this model: + +- Encodes sidechain geometry explicitly through multiple beads per residue +- Uses distinct bead types with physically meaningful LJ parameters from Martini 3 +- Simulates **explicit CG water** (W beads, each representing ~4 H₂O molecules) and **NaCl ions** +- Applies **reaction-field electrostatics** with proper solvent screening at the periodic boundary + +The `MartiniNonBondedForceModel` class implements an identical `run_full_pipeline()` interface to `AtomisticNonBondedForceModel`, returning the same five output matrices. `ProteogramV2` selects between models via the `cg_method` argument (`'martini'` or `None` for atomistic). + +--- + +## Motivation: Why Martini? 
+ +| Property | Atomistic | Martini | +|---|---|---| +| Beads/atoms per residue | ~10 heavy + H | 1–4 | +| Sidechain geometry | Full | Coarse (1–3 SC beads) | +| Solvent | Explicit TIP3P water | Explicit CG W beads | +| Electrostatics | PME (full long-range) | Reaction-field, 1.1 nm cutoff | +| Timestep | 2 fs | 10 fs (eq) / 20 fs (prod) | +| Typical speedup vs atomistic | 1× | 10–30× | + +Martini provides a faster alternative to the atomistic model while retaining explicit solvent and chemically distinct sidechain representations: + +- **Faster than atomistic**: the system has ~5–10× fewer particles than all-atom + TIP3P, and the 10–20 fs timestep (vs. 2 fs) further reduces wall time. +- **Explicit solvent with periodic boundaries**: solvates the protein in a periodic CG water box and equilibrates the box volume via NPT before production, providing realistic dielectric screening. +- **Chemically distinct sidechain beads**: bead types (apolar C, polar N, charged Q, aromatic TC) encode sidechain chemical identity more explicitly than a single epsilon value. + +--- + +## System Representation + +### Multi-Bead Residue Mapping + +Each residue contributes one **backbone (BB)** bead placed at the centroid of `N, CA, C, O` backbone atoms, plus 0–3 **sidechain (SC)** beads at sidechain atom centroids. If an expected atom is missing from the PDB, the code falls back to the Cα position automatically. + +| Residue(s) | Beads | Labels | +|---|---|---| +| GLY | 1 | BB | +| ALA, VAL, LEU, ILE, PRO, MET, CYS | 2 | BB, SC1 | +| SER, THR, ASN, GLN, ASP, GLU | 2 | BB, SC1 | +| LYS, ARG, HIS, PHE, TYR | 3 | BB, SC1, SC2 | +| TRP | 4 | BB, SC1, SC2, SC3 | + +The BB bead is always the first bead within a residue. This ordering is used throughout the aggregation pipeline. 
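The bead placement rule described above (centroid of the listed heavy atoms, with a Cα fallback) can be sketched as follows. This is a minimal illustration, not the pipeline's own code: `bead_position` and `BACKBONE_ATOMS` are hypothetical names, the coordinates are made up, and the fallback is assumed to trigger whenever any expected atom is missing.

```python
import numpy as np

# Hypothetical constant: backbone atoms that define the BB bead.
BACKBONE_ATOMS = ("N", "CA", "C", "O")

def bead_position(atoms, names):
    """Return the centroid of the named atoms, or the CA position as a fallback.

    atoms: dict mapping atom name -> (x, y, z) in nm for one residue.
    names: atom names that define the bead.
    Assumption: the fallback triggers when any expected atom is missing.
    """
    if all(name in atoms for name in names):
        return np.mean([atoms[name] for name in names], axis=0)
    return np.asarray(atoms["CA"], dtype=float)

# Toy alanine residue (coordinates are illustrative only).
ala = {
    "N": (0.0, 0.0, 0.0),
    "CA": (0.1, 0.0, 0.0),
    "C": (0.2, 0.1, 0.0),
    "O": (0.3, 0.1, 0.0),
    "CB": (0.1, -0.1, 0.1),
}
bb = bead_position(ala, BACKBONE_ATOMS)  # centroid of N, CA, C, O
sc1 = bead_position(ala, ("CB",))        # ALA SC1 is the CB atom alone
```

Appending the BB bead before the SC beads for each residue would preserve the BB-first ordering that the aggregation step relies on.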
+ +--- + +### Bead Types and LJ Parameters + +Five protein bead types encode chemical identity through different LJ well depths and radii, following approximate Martini 3 parameters [1]: + +| Bead type | Residues / role | σ (nm) | ε (kJ/mol) | +|---|---|---|---| +| BB | All backbone beads | 0.47 | 5.6 | +| C | Apolar sidechains (ALA, VAL, LEU, ILE, PRO, MET, CYS) | 0.47 | 4.5 | +| N | Polar sidechains (SER, THR, ASN, GLN; linker beads of LYS, ARG, HIS) | 0.47 | 3.6 | +| Q | Charged sidechains (ASP SC1, GLU SC1, LYS SC2, ARG SC2) | 0.47 | 5.6 | +| TC | Tiny cyclic — aromatic rings (HIS SC2, PHE SC1/SC2, TYR SC1/SC2, TRP SC1/SC2/SC3) | 0.38 | 3.1 | + +Three additional bead types are used for explicit solvent and ions (not included in the residue-level energy maps): + +| Bead type | Role | σ (nm) | ε (kJ/mol) | Mass (Da) | +|---|---|---|---|---| +| W | CG water (~4 H₂O per bead) | 0.47 | 1.00 | 72.0 | +| ION_NA | Na⁺ | 0.258 | 0.063 | 22.99 | +| ION_CL | Cl⁻ | 0.440 | 0.830 | 35.45 | + +Pairwise parameters use Lorentz-Berthelot combining rules: + +$$\sigma_{ij} = \frac{\sigma_i + \sigma_j}{2}, \qquad \epsilon_{ij} = \sqrt{\epsilon_i \cdot \epsilon_j}$$ + +Protein beads all have a uniform mass of **72 Da**, the approximate mass of an average amino acid fragment at this level of coarse-graining. + +--- + +### Bead Charges + +Only explicitly ionised sidechains carry charge. All backbone beads have charge 0. 
The charge assignments follow Martini 3 conventions: + +| Residue | Charged bead | Charge (e) | Basis | +|---|---|---|---| +| ASP | SC1 (Q type) | −1.0 | Deprotonated at pH 7 | +| GLU | SC1 (Q type) | −1.0 | Deprotonated at pH 7 | +| LYS | SC2 (Q type) | +1.0 | Fully protonated at pH 7 | +| ARG | SC2 (Q type) | +1.0 | Fully protonated at pH 7 | +| HIS | SC2 (TC type) | +0.5 | ~30% protonated at pH 7 (pKa ≈ 6.5) | +| All others | All beads | 0.0 | Neutral at pH 7 | + +--- + +### Explicit Solvent and Ions + +The Martini simulation runs in an explicit solvent box: + +- **W beads** representing ~4 H₂O each are placed on a cubic grid with 1.2 nm padding on each face of the protein bounding box. +- **Na⁺ and Cl⁻** ions are placed by randomly replacing W beads — first to neutralise the net protein charge, then to reach 0.15 M physiological NaCl. +- Periodic boundary conditions (PBC) are applied to all forces. + +The solvent participates fully in the dynamics (driving realistic thermal fluctuations and dielectric screening) but is **excluded from the residue-level energy maps**: the numpy pairwise energy calculation slices only the first $B_{prot}$ bead positions, so water-protein, water-water, and ion-protein interactions do not appear in the output matrices. 
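The two-step ion placement above (neutralise first, then add background salt) reduces to a small counting exercise. The sketch below is an illustration under stated assumptions: `ion_counts` is a hypothetical helper, the salt-pair count is taken as concentration × box volume × Avogadro's number, and the choice of which W beads to replace is left out.

```python
AVOGADRO = 6.02214076e23   # 1/mol
NM3_TO_LITRE = 1.0e-24     # 1 nm^3 = 1e-24 L

def ion_counts(net_charge, box_volume_nm3, molarity=0.15):
    """Return (n_na, n_cl): neutralising counterions plus background NaCl pairs."""
    q = int(round(net_charge))   # net protein charge rounded to an integer
    n_na = max(-q, 0)            # Na+ neutralise a negatively charged protein
    n_cl = max(q, 0)             # Cl- neutralise a positively charged protein
    n_pairs = int(round(molarity * box_volume_nm3 * NM3_TO_LITRE * AVOGADRO))
    return n_na + n_pairs, n_cl + n_pairs

# Box volume and net charge taken from the sample setup summary in Appendix A.
n_na, n_cl = ion_counts(net_charge=-2, box_volume_nm3=378.4)
```

For that sample system the helper gives 36 Na⁺ and 34 Cl⁻, matching the counts printed in the Appendix A setup summary.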
+ +--- + +## Force Field + +### Bonded Interactions + +#### Backbone bonds (BB–BB, inter-residue) + +$$U_{\text{bond}}(r) = \frac{1}{2} k_{bb} (r - r_0)^2$$ + +| Parameter | Value | +|---|---| +| Force constant $k_{bb}$ | 3800 kJ/(mol·nm²) | +| Equilibrium length $r_0$ | 0.35 nm | + +#### Backbone–sidechain bonds (BB–SC1, intra-residue) + +$$U_{\text{bond}}(r) = \frac{1}{2} k_{bs} (r - r_0)^2$$ + +| Parameter | Value | +|---|---| +| Force constant $k_{bs}$ | 3800 kJ/(mol·nm²) | +| Equilibrium length $r_0$ | 0.27 nm | + +#### Sidechain–sidechain bonds (SC–SC, intra-residue) + +$$U_{\text{bond}}(r) = \frac{1}{2} k_{ss} (r - r_0)^2$$ + +| Parameter | Value | +|---|---| +| Force constant $k_{ss}$ | 2500 kJ/(mol·nm²) | +| Equilibrium length $r_0$ | 0.27 nm | + +#### Backbone angles (BB–BB–BB) + +$$U_{\text{angle}}(\theta) = \frac{1}{2} k_\theta (\theta - \theta_0)^2$$ + +| Parameter | Value | +|---|---| +| Force constant $k_\theta$ | 40 kJ/(mol·rad²) | +| Equilibrium angle $\theta_0$ | 127° (2.217 rad) | + +127° is a standard Martini backbone angle for a generic / random-coil protein chain. + +No torsion (dihedral) terms are included — a simplification relative to full Martini 3, acceptable for the proteogram application where the structural signal is the frame-averaged energy pattern rather than accurate free energies of specific conformations. + +--- + +### Van der Waals (Lennard-Jones) Interactions + +$$U_{LJ}(r) = 4\epsilon_{ij} \left[ \left(\frac{\sigma_{ij}}{r}\right)^{12} - \left(\frac{\sigma_{ij}}{r}\right)^{6} \right]$$ + +Applied as `CutoffPeriodic` with a 1.1 nm cutoff, matching the standard Martini 3 LJ cutoff [1]. The cutoff is applied to both protein–protein and protein–solvent pairs; the latter drives realistic solvation dynamics. 
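As a worked instance of the potential above, the sketch below combines the BB and TC parameters from the bead-type table with the Lorentz-Berthelot rules and evaluates the truncated LJ energy. The helper names are hypothetical, and plain truncation without a potential shift is an assumption; the source does not say whether the pipeline shifts the potential at the cutoff.

```python
import math

R_CUT = 1.1  # nm, non-bonded cutoff

def combine(sigma_i, eps_i, sigma_j, eps_j):
    """Lorentz-Berthelot: arithmetic mean for sigma, geometric mean for epsilon."""
    return 0.5 * (sigma_i + sigma_j), math.sqrt(eps_i * eps_j)

def lj_energy(r, sigma_ij, eps_ij, r_cut=R_CUT):
    """Truncated LJ energy in kJ/mol (assumed unshifted); zero at or beyond the cutoff."""
    if r >= r_cut:
        return 0.0
    sr6 = (sigma_ij / r) ** 6
    return 4.0 * eps_ij * (sr6 * sr6 - sr6)

# BB bead (sigma 0.47 nm, eps 5.6 kJ/mol) against a TC bead (0.38 nm, 3.1 kJ/mol).
sigma_ij, eps_ij = combine(0.47, 5.6, 0.38, 3.1)
r_min = 2 ** (1 / 6) * sigma_ij              # location of the energy minimum
u_min = lj_energy(r_min, sigma_ij, eps_ij)   # equals -eps_ij at the minimum
```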
+ +#### Separated energy terms + +| Component | Formula | Sign | Physical meaning | +|---|---|---|---| +| **Repulsive** | $4\epsilon_{ij}(\sigma_{ij}/r)^{12}$ | + | Excluded volume / steric clash | +| **Attractive** | $-4\epsilon_{ij}(\sigma_{ij}/r)^{6}$ | − | Dispersion / hydrophobic contact | + +--- + +### Electrostatic Interactions: Reaction-Field Coulomb + +Standard Martini 3 uses **reaction-field (RF) electrostatics** [2] rather than bare Coulomb or PME. The RF formula models the dielectric response of the medium beyond the cutoff $r_c$ as a uniform continuum with permittivity $\varepsilon_s$ (bulk water): + +$$U_{\text{RF}}(r) = \frac{k_e^* \cdot q_i q_j}{\varepsilon_r} \left( \frac{1}{r} + k_{\text{rf}} r^2 - c_{\text{rf}} \right), \quad r < r_c$$ + +Where: +- $k_e^* = 138.935456$ kJ·nm/(mol·e²) — vacuum Coulomb constant +- $\varepsilon_r = 15$ — protein-interior relative permittivity (electronic polarizability screening) +- $k_{\text{rf}} = \frac{\varepsilon_s - \varepsilon_r}{(2\varepsilon_s + \varepsilon_r) r_c^3}$ — RF correction curvature term +- $c_{\text{rf}} = \frac{3\varepsilon_s}{(2\varepsilon_s + \varepsilon_r) r_c}$ — RF correction constant (ensures continuity at $r_c$) + +#### Parameter values + +| Parameter | Value | Description | +|---|---|---| +| $r_c$ | 1.1 nm | LJ and RF cutoff | +| $\varepsilon_r$ | 15 | Protein-interior permittivity | +| $\varepsilon_s$ | 80 | Bulk water permittivity | +| $k_{\text{rf}}$ | 0.2791 nm⁻³ | RF curvature coefficient | +| $c_{\text{rf}}$ | 1.2468 nm⁻¹ | RF continuity constant | +| Effective $k_e$ | 138.935456 / 15 = 9.262 kJ·nm/(mol·e²) | Pre-screened Coulomb prefactor | + +The full expression used in the `CustomNonbondedForce` and in the numpy energy calculation is: + +$$U_{\text{RF}}(r) = 9.262 \cdot q_i q_j \left( \frac{1}{r} + 0.2791 \cdot r^2 - 1.2468 \right), \quad r < 1.1 \text{ nm}$$ + +The $k_{\text{rf}} r^2$ term grows without bound beyond the cutoff, so the 1.1 nm cutoff is applied strictly in both 
the OpenMM `CutoffPeriodic` force and in the numpy pairwise energy calculation during production snapshots. + +#### Energy classification + +| Condition | Energy | Type | +|---|---|---| +| $q_i q_j > 0$ (like charges) | Positive | **Repulsive** | +| $q_i q_j < 0$ (opposite charges) | Negative | **Attractive** | + +--- + +### Exclusions and Cutoff + +**Exclusions**: All intra-residue bead pairs are excluded from both LJ and Coulomb forces (backbone and sidechain beads within the same residue interact only through the bonded terms). Additionally, 1-2 and 1-3 backbone BB pairs across sequential residues are excluded. + +**Cutoff**: `CutoffPeriodic` at 1.1 nm for both LJ and Coulomb. This is the standard Martini 3 non-bonded cutoff. + +--- + +### Barostat + +A `MonteCarloBarostat` is added to the system during `setup_system()` at 1 bar / 25-step update frequency. Its frequency is managed across pipeline stages: + +| Stage | Barostat frequency | +|---|---| +| NVT equilibration | **0** (disabled — constant volume) | +| NPT equilibration | 25 (active — box volume adjusts) | +| Production | 25 (active — maintains pressure during sampling) | + +Disabling the barostat during NVT is essential: the solvation box starts slightly under-dense (grid spacing at the LJ minimum rather than liquid density). Allowing volume rescaling before the system is thermally equilibrated causes extreme local forces that drive coordinates to NaN. + +--- + +## Solvation Protocol + +### Water Bead Placement + +W beads are placed on a **cubic grid** with spacing set to the W–W LJ minimum distance: + +$$d_{\text{grid}} = 2^{1/6} \cdot \sigma_W = 2^{1/6} \times 0.47 \approx 0.528 \text{ nm}$$ + +This spacing ensures nearest-neighbour W beads start at zero force — placing them at $\sigma_W$ (0.47 nm, the LJ zero-crossing) would put them in the repulsive region, launching beads during minimisation. 
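The zero-force argument for the grid spacing can be checked numerically. The sketch below uses the W-bead parameters from the bead-type table; `lj_force` is a hypothetical helper, and the density figure assumes a simple cubic lattice before any beads are removed near the protein.

```python
SIGMA_W = 0.47  # nm, W bead sigma
EPS_W = 1.00    # kJ/mol, W bead epsilon

def lj_force(r, sigma=SIGMA_W, eps=EPS_W):
    """Radial LJ force -dU/dr in kJ/(mol nm); positive values push beads apart."""
    sr6 = (sigma / r) ** 6
    return 24.0 * eps / r * (2.0 * sr6 * sr6 - sr6)

d_grid = 2 ** (1 / 6) * SIGMA_W        # LJ minimum distance used as grid spacing
force_at_grid = lj_force(d_grid)       # vanishes at the LJ minimum
force_at_sigma = lj_force(SIGMA_W)     # repulsive at the LJ zero-crossing
grid_density = 1.0 / d_grid ** 3       # beads per nm^3 on a simple cubic grid
```

The resulting density is about 6.8 beads/nm³, the under-dense starting point that NPT equilibration later compresses.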
+ +The box extends 1.2 nm beyond the protein bounding box on each face: + +$$L_x = x_{\max} - x_{\min} + 2 \times 1.2 \text{ nm}, \quad \text{etc.}$$ + +W beads within **0.53 nm** of any protein bead are removed to prevent high-energy initial contacts. The resulting grid density is ~6.8 W beads/nm³, slightly below the Martini liquid-water target of ~8.4 W/nm³ — NPT equilibration shrinks the box to the correct density. + +### Ion Placement + +Ions are placed by randomly sampling W bead slots (fixed seed 42 for reproducibility): + +1. **Neutralisation**: if the protein has net charge $q_{\text{net}}$ (rounded to nearest integer), add $|q_{\text{net}}|$ counterions (Na⁺ for negative protein, Cl⁻ for positive). +2. **Physiological salt**: add NaCl pairs to reach 0.15 M. The number of pairs is: + +$$n_{\text{pairs}} = \text{round}\left( 0.15 \text{ mol/L} \times V_{\text{box}} \text{ (L)} \times N_A \right)$$ + +The selected W beads are replaced by ion particles; the remaining W beads stay in place. + +--- + +## Simulation Pipeline + +### Default Parameters + +| Parameter | Value | Description | +|---|---|---| +| Temperature | 310.15 K (37 °C) | Physiological temperature | +| Equilibration timestep | 10 fs | Conservative while system settles (NVT + NPT) | +| Production timestep | 20 fs | Standard Martini 3 W-model timestep | +| Integrator | Langevin Middle | 1 ps⁻¹ friction coefficient | +| Protein bead mass | 72 Da | Uniform for all protein beads | +| W bead mass | 72 Da | ~4 × 18 Da | +| Box padding | 1.2 nm | Solvent buffer on each face | + +> **State continuity**: bead positions are propagated between all pipeline stages. After each stage, `getState(enforcePeriodicBox=True)` is called to wrap positions back into the primary box before creating the next simulation object. 
After NPT equilibration, the equilibrated box vectors are explicitly propagated back to the topology, system defaults, and `_box_lengths` so that the production simulation starts with the correct box geometry. + +--- + +### 1. Energy Minimization + +**Purpose**: Relax any strained initial bead geometry and high-energy water–protein contacts. + +| Parameter | Value | +|---|---| +| Algorithm | L-BFGS (OpenMM default) | +| Max iterations | 2,000 | + +--- + +### 2. NVT Equilibration + +**Purpose**: Thermalise the system to 310 K at constant volume before allowing box relaxation. + +| Parameter | Value | +|---|---| +| Ensemble | NVT (barostat disabled) | +| Steps | 25,000 (250 ps at 10 fs/step) | +| Temperature | 310.15 K | +| Reporting interval | 2,500 steps (25 ps) | + +Velocities are initialised from a Maxwell-Boltzmann distribution at 310 K. The barostat is disabled (frequency = 0) throughout this stage. Allowing box rescaling before thermalisation is complete drives coordinates to NaN: the slightly under-dense initial grid generates large inter-bead forces that barostat volume moves would amplify before Langevin friction has damped them. + +--- + +### 3. NPT Equilibration + +**Purpose**: Allow the simulation box volume to relax to the correct liquid-water density under barostat control. + +| Parameter | Value | +|---|---| +| Ensemble | NPT (barostat frequency = 25) | +| Steps | 25,000 (250 ps at 10 fs/step) | +| Pressure | 1.0 bar | +| Temperature | 310.15 K | +| Reporting interval | 2,500 steps (25 ps) with volume output | + +After NPT equilibration, the converged box vectors are **propagated back** to the OpenMM topology, system defaults, and the internal `_box_lengths` array. Without this propagation, the subsequent production simulation is created with the original (over-large) box while bead positions correspond to the compressed NPT box — this creates extreme local density at one corner and immediately produces NaN. + +--- + +### 4. 
Production MD + +**Purpose**: Generate a thermally and mechanically equilibrated trajectory for residue-level energy sampling. + +| Parameter | Value | +|---|---| +| Ensemble | NPT (barostat frequency = 25) | +| Steps | 250,000 (5 ns at 20 fs/step) | +| Energy snapshot interval | 5,000 steps (100 ps) | +| Frames collected | 50 | + +At each snapshot, all bead positions are extracted and protein-only pairwise energies are computed in numpy (see [Energy Calculations](#energy-calculations)). Solvent and ion beads drive the dynamics but do not contribute to the output energy maps. + +--- + +## Energy Calculations + +### Bead-Level Pairwise Energies + +At each production snapshot, bead-level $B_{prot} \times B_{prot}$ interaction matrices are computed in vectorised numpy, **consistent with the OpenMM force expressions**: + +#### LJ energy (protein beads only, upper triangle) + +$$U_{LJ,b_1 b_2} = \begin{cases} +4\epsilon_{b_1 b_2}\left[\left(\frac{\sigma_{b_1 b_2}}{r_{b_1 b_2}}\right)^{12} - \left(\frac{\sigma_{b_1 b_2}}{r_{b_1 b_2}}\right)^{6}\right] & r < 1.1 \text{ nm, non-excluded} \\ +0 & \text{otherwise} +\end{cases}$$ + +Split into repulsive (positive $r^{-12}$ term) and attractive (negative $r^{-6}$ term) components. + +#### Reaction-field Coulomb (protein beads only, upper triangle) + +$$U_{\text{RF},b_1 b_2} = \begin{cases} +9.262 \cdot q_{b_1} q_{b_2} \left(\frac{1}{r} + 0.2791 r^2 - 1.2468\right) & r < 1.1 \text{ nm, non-excluded} \\ +0 & \text{otherwise} +\end{cases}$$ + +Split into attractive ($U < 0$) and repulsive ($U > 0$) components. + +#### Minimum image convention + +Bead-pair distances use the minimum image convention (MIC) for periodic boundary conditions: + +$$\Delta\vec{r}_{ij} \leftarrow \Delta\vec{r}_{ij} - \text{round}\!\left(\frac{\Delta\vec{r}_{ij}}{L}\right) \cdot L$$ + +where $L$ is the box edge length vector. This matches OpenMM's `CutoffPeriodic` treatment. 
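The minimum image convention and the reaction-field expression above can be combined into a short vectorised sketch. It is an illustration under assumptions rather than the pipeline's code: the function names are hypothetical, the box is taken to be orthorhombic, and the exclusion mask is omitted (a boolean `excluded` matrix would simply be combined with `mask`).

```python
import numpy as np

KE = 138.935456                        # kJ nm / (mol e^2), vacuum Coulomb constant
EPS_R, EPS_S, R_CUT = 15.0, 80.0, 1.1  # permittivities and cutoff (nm)

# Reaction-field coefficients derived from the permittivities and cutoff.
K_RF = (EPS_S - EPS_R) / ((2 * EPS_S + EPS_R) * R_CUT**3)
C_RF = 3 * EPS_S / ((2 * EPS_S + EPS_R) * R_CUT)

def mic_distances(pos, box):
    """All pairwise distances (nm) under the minimum image convention."""
    delta = pos[:, None, :] - pos[None, :, :]
    delta -= np.round(delta / box) * box   # wrap each component into [-L/2, L/2]
    return np.linalg.norm(delta, axis=-1)

def rf_energy(pos, charges, box):
    """Upper-triangle RF Coulomb energies, split into (attractive, repulsive)."""
    r = mic_distances(pos, box)
    qq = np.outer(charges, charges)
    mask = np.triu(np.ones_like(r, dtype=bool), k=1) & (r < R_CUT)
    u = np.zeros_like(r)
    u[mask] = KE / EPS_R * qq[mask] * (1.0 / r[mask] + K_RF * r[mask] ** 2 - C_RF)
    return np.minimum(u, 0.0), np.maximum(u, 0.0)

# Two opposite charges on either side of the periodic boundary of a 5 nm box.
pos = np.array([[0.1, 0.0, 0.0], [4.9, 0.0, 0.0]])
box = np.array([5.0, 5.0, 5.0])
es_att, es_rep = rf_energy(pos, [-1.0, 1.0], box)
```

The two beads are 4.8 nm apart in raw coordinates but 0.2 nm apart under MIC, so they fall inside the cutoff and contribute an attractive energy.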
+ +--- + +### Aggregation to Residue Level + +Bead-level $B_{prot} \times B_{prot}$ matrices are collapsed to residue-level $N \times N$ matrices via the **indicator matrix** $\mathbf{I} \in \{0,1\}^{N \times B_{prot}}$, where $I_{ib} = 1$ iff bead $b$ belongs to residue $i$: + +$$E^{\text{residue}}[i,j] = \sum_{b_1 \in i,\; b_2 \in j} E^{\text{bead}}[b_1, b_2] = \left(\mathbf{I} \cdot E^{\text{bead}} \cdot \mathbf{I}^T\right)_{ij}$$ + +This matrix multiply aggregates all bead-pair contributions (BB–BB, BB–SC, SC–SC) between residue $i$ and residue $j$ into a single residue-pair energy. + +> **Contrast with atomistic**: the atomistic model iterates over all atom pairs within each residue pair and normalises by the number of atom pairs. The Martini model sums bead-pair energies without normalisation — each bead already represents multiple heavy atoms, so the energies are inherently coarser. + +--- + +### Distance Matrix + +The distance map uses **BB bead positions only**, providing a residue-level backbone distance matrix analogous to the Cα distance matrix in `ProteogramV2`: + +$$d_{ij}^{\text{BB}} = \|\vec{r}_{BB,i} - \vec{r}_{BB,j}\| \times 10 \quad \text{(Å)}$$ + +Upper triangle only ($j \geq i + 3$); lower triangle is zero. 
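The indicator-matrix aggregation and the BB distance map described above can be sketched in a few lines of numpy. The function names and the `bead_residue` index array are hypothetical; the distance helper follows the plain Euclidean formula given above.

```python
import numpy as np

def aggregate_to_residues(e_bead, bead_residue, n_res):
    """Collapse a B x B bead matrix to N x N via the indicator matrix I (N x B)."""
    n_beads = len(bead_residue)
    indicator = np.zeros((n_res, n_beads))
    indicator[bead_residue, np.arange(n_beads)] = 1.0   # I[i, b] = 1 iff bead b is in residue i
    return indicator @ e_bead @ indicator.T

def bb_distance_map(bb_pos_nm, min_sep=3):
    """BB-BB distances in angstrom, kept only for j >= i + min_sep."""
    diff = bb_pos_nm[:, None, :] - bb_pos_nm[None, :, :]
    return np.triu(np.linalg.norm(diff, axis=-1) * 10.0, k=min_sep)

# Residue 0 owns beads 0-1, residue 1 owns beads 2-4 (BB bead first in each).
bead_residue = [0, 0, 1, 1, 1]
e_bead = np.zeros((5, 5))
e_bead[0, 2], e_bead[0, 3], e_bead[1, 4] = 1.0, 2.0, 3.0   # upper-triangle bead-pair energies
e_res = aggregate_to_residues(e_bead, bead_residue, n_res=2)

# Four BB beads spaced 0.35 nm apart along x.
bb_pos = np.array([[0.35 * i, 0.0, 0.0] for i in range(4)])
dist = bb_distance_map(bb_pos)
```

Because beads are ordered by residue, an upper-triangular bead matrix stays upper-triangular after aggregation, and the three bead-pair energies sum to a single residue-pair value.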
+ +### Accumulation and averaging + +All five matrices are accumulated over all production frames and averaged: + +$$\bar{E}_{ij} = \frac{1}{N_{\text{frames}}} \sum_{f=1}^{N_{\text{frames}}} E_{ij}^{(f)}$$ + +--- + +## Output Matrices + +The pipeline produces **5 N×N matrices** (where N = number of protein residues), identical in format to `AtomisticNonBondedForceModel`: + +| Matrix | Formula | Units | +|---|---|---| +| `vdw_energy_attractive` | $-4\epsilon_{ij}(\sigma_{ij}/r)^6$, summed over bead pairs | kJ/mol | +| `vdw_energy_repulsive` | $4\epsilon_{ij}(\sigma_{ij}/r)^{12}$, summed over bead pairs | kJ/mol | +| `es_energy_attractive` | $U_{\text{RF}}$ when $q_{b_1}q_{b_2} < 0$, summed over bead pairs | kJ/mol | +| `es_energy_repulsive` | $U_{\text{RF}}$ when $q_{b_1}q_{b_2} > 0$, summed over bead pairs | kJ/mol | +| `dist_avg` | BB–BB bead distance | Å | + +### Matrix properties + +- **Dimensions**: N × N (protein residues only; solvent excluded) +- **Storage**: Upper triangle only +- **Averaging**: Frame-averaged over all production snapshots +- **Normalisation**: None — bead-pair energies are summed, not averaged per pair + +### Normalisation convention in ProteogramV2 + +`ProteogramV2.normalize_map()` rescales each matrix independently to [0–255]. Attractive energy channels (`vdw_att`, `es_att`) have values ≤ 0; their absolute value is taken first so that zero (no interaction) maps to 0 (dark) and large-magnitude interactions map to 255 (bright). Repulsive and distance channels are already ≥ 0 and normalise correctly without transformation. 
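The normalisation convention above can be illustrated with a small stand-in. This sketch assumes min-max scaling to 8-bit integers; `normalize_channel` is a hypothetical function, and the actual `ProteogramV2.normalize_map()` may scale differently.

```python
import numpy as np

def normalize_channel(matrix, attractive=False):
    """Rescale one channel to 0..255 (uint8).

    Assumption: min-max scaling. Attractive channels (values <= 0) are
    converted to magnitudes first so that zero interaction maps to 0.
    """
    values = np.abs(matrix) if attractive else np.asarray(matrix, dtype=float)
    span = values.max() - values.min()
    if span == 0:
        return np.zeros(values.shape, dtype=np.uint8)   # flat channel maps to all zeros
    scaled = (values - values.min()) / span * 255.0
    return np.round(scaled).astype(np.uint8)

# Toy attractive channel: the strongest interaction is -4 kJ/mol.
vdw_att = np.array([[0.0, -1.0], [0.0, -4.0]])
img = normalize_channel(vdw_att, attractive=True)
```

Taking the absolute value before scaling is what makes the strongest attraction the brightest pixel; scaling the raw negative values would invert the image.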
+ +--- + +## Comparison with Atomistic Model + +| Property | Atomistic | Martini | +|---|---|---| +| Class | `AtomisticNonBondedForceModel` | `MartiniNonBondedForceModel` | +| Particles | All heavy + H + TIP3P water | 1–4 CG beads/residue + W + ions | +| Solvent | Explicit TIP3P | Explicit W beads (CG) | +| Electrostatics | PME, long-range | Reaction-field, 1.1 nm cutoff | +| Dielectric | RESP partial charges | ε_r=15 protein, ε_s=80 RF boundary | +| Charges | RESP (all atoms) | Ionised sidechains only (on SC bead) | +| Timestep | 2 fs | 10 fs (eq) / 20 fs (prod) | +| Equilibration | NVT (100 ps) + NPT (100 ps) | NVT (250 ps) + NPT (250 ps) | +| Production | 1 ns (500,000 steps) | 5 ns (250,000 steps) | +| Energy interval | 20 ps (50 frames) | 100 ps (50 frames) | +| Typical speedup | 1× baseline | 10–30× | +| API / output | `run_full_pipeline()` → 5 matrices | identical | + +--- + +## Usage Example + +### Minimal API usage + +```python +from proteogram.v2 import MartiniNonBondedForceModel + +model = MartiniNonBondedForceModel( + pdb_path='protein.pdb', + output_dir='output', + temperature=310.15, # Kelvin + use_gpu=False, +) + +vdw_att, vdw_rep, es_att, es_rep, dist_avg = model.run_full_pipeline( + nvt_steps=25000, # 250 ps NVT equilibration + npt_steps=25000, # 250 ps NPT equilibration + production_steps=250000, # 5 ns production (at 20 fs/step) + energy_calc_interval=5000, # snapshot every 100 ps (50 frames) + debug=False, +) + +model.cleanup_all_resources(final_run=True) +``` + +### Via ProteogramV2 (recommended) + +```python +from proteogram.v2 import ProteogramV2 + +# Set cg_method at construction time — all calls use Martini CG +pg = ProteogramV2( + pdb_path='protein.pdb', + output_dir='output', + chain_id='A', + cg_method='martini', +) +proteogram_array, errors = pg.calculate_proteogram() + +# Or override per-call to compare models on the same protein +pg = ProteogramV2('protein.pdb', 'output', 'A') +array_atomistic, _ = 
pg.calculate_proteogram(cg_method=None) +array_martini, _ = pg.calculate_proteogram(cg_method='martini') +``` + +--- + +## Appendix + +### Appendix A: Energy Monitoring + +#### Expected energy ranges (Martini) + +| Stage | Expectation | +|---|---| +| After minimization | Large negative; W bead repulsions resolved | +| NVT equilibration | Energy decreases and stabilises; temperature converges to 310 K | +| NPT equilibration | Volume decreases from initial grid density (~6.8 W/nm³) to Martini water density (~8.4 W/nm³) | +| Production | Fluctuates around a stable mean; no systematic drift | + +#### Per-bead energy reference + +| System size | Approximate system particles | Typical potential energy | +|---|---|---| +| Small protein (50 residues) | ~150 protein + ~2,000 W + ions | −10,000 to −25,000 kJ/mol | +| Medium protein (100 residues) | ~300 protein + ~4,000 W + ions | −20,000 to −50,000 kJ/mol | +| Large protein (200 residues) | ~600 protein + ~8,000 W + ions | −40,000 to −100,000 kJ/mol | + +Martini energies are larger in absolute magnitude than atomistic energies on a per-residue basis due to the coarser force field, but smaller in total because the system has fewer particles. Track stability and trends rather than specific values. + +#### Setup summary printed by `setup_system()` + +``` + Solvation box: 7.40 × 7.10 × 7.20 nm (378.4 nm³) + Placed 2431 W beads (89 removed for clashes with protein) + Net protein charge = -2 → neutralization: 2 Na⁺, 0 Cl⁻ + NaCl (0.15 M, 34 pairs): 36 Na⁺ total, 34 Cl⁻ total [system net charge = 0] +Martini CG system: 153 residues, 287 protein beads + 2431 W + 36 Na⁺ + 34 Cl⁻ = 2788 total; 286 bonds, 18 charged protein beads +``` + +Sanity checks: +- System net charge should be 0 (or ±1 for rounding of His fractional charges) +- Box volume should grow with protein size +- W bead count should be several times the protein bead count + +--- + +### Appendix B: Residue Bead Definitions + +Full bead assignment table. 
Format: bead type / label / charge (e) / centroid atom names. Bead types follow `MARTINI_RESIDUE_BEADS` in `proteogram/common/constants.py`. + +| Residue | BB | SC1 | SC2 | SC3 | +|---|---|---|---|---| +| GLY | BB/BB/0.0/N,CA,C,O | — | — | — | +| ALA | BB/BB/0.0/N,CA,C,O | TC/SC1/0.0/CB | — | — | +| VAL | BB/BB/0.0/N,CA,C,O | SC/SC1/0.0/CB,CG1,CG2 | — | — | +| LEU | BB/BB/0.0/N,CA,C,O | SC/SC1/0.0/CB,CG,CD1,CD2 | — | — | +| ILE | BB/BB/0.0/N,CA,C,O | SC/SC1/0.0/CB,CG1,CG2,CD1 | — | — | +| PRO | BB/BB/0.0/N,CA,C,O | SC/SC1/0.0/CB,CG,CD | — | — | +| MET | BB/BB/0.0/N,CA,C,O | C/SC1/0.0/CB,CG,SD,CE | — | — | +| CYS | BB/BB/0.0/N,CA,C,O | TC/SC1/0.0/CB,SG | — | — | +| SER | BB/BB/0.0/N,CA,C,O | N/SC1/0.0/CB,OG | — | — | +| THR | BB/BB/0.0/N,CA,C,O | N/SC1/0.0/CB,OG1,CG2 | — | — | +| ASN | BB/BB/0.0/N,CA,C,O | N/SC1/0.0/CB,CG,OD1,ND2 | — | — | +| GLN | BB/BB/0.0/N,CA,C,O | N/SC1/0.0/CB,CG,CD,OE1,NE2 | — | — | +| ASP | BB/BB/0.0/N,CA,C,O | Q/SC1/−1.0/CB,CG,OD1,OD2 | — | — | +| GLU | BB/BB/0.0/N,CA,C,O | Q/SC1/−1.0/CB,CG,CD,OE1,OE2 | — | — | +| LYS | BB/BB/0.0/N,CA,C,O | SC/SC1/0.0/CB,CG,CD | Q/SC2/+1.0/CE,NZ | — | +| ARG | BB/BB/0.0/N,CA,C,O | N/SC1/0.0/CB,CG,CD | Q/SC2/+1.0/NE,CZ,NH1,NH2 | — | +| HIS | BB/BB/0.0/N,CA,C,O | TC/SC1/0.0/CB,CG | TC/SC2/+0.5/ND1,CD2,CE1,NE2 | — | +| PHE | BB/BB/0.0/N,CA,C,O | SC/SC1/0.0/CB,CG,CD1,CD2 | SC/SC2/0.0/CE1,CE2,CZ | — | +| TYR | BB/BB/0.0/N,CA,C,O | SC/SC1/0.0/CB,CG,CD1,CD2 | SC/SC2/0.0/CE1,CE2,CZ,OH | — | +| TRP | BB/BB/0.0/N,CA,C,O | SC/SC1/0.0/CB,CG,CD1,NE1 | SC/SC2/0.0/CD2,CE2,CZ2,CH2 | TC/SC3/0.0/CE3,CZ3 | + +--- + +### Appendix C: Reaction-Field Parameter Derivation + +The RF parameters follow Tironi et al. 
[2]: + +$$k_{\text{rf}} = \frac{\varepsilon_s - \varepsilon_r}{2\varepsilon_s + \varepsilon_r} \cdot \frac{1}{r_c^3} = \frac{80 - 15}{2 \times 80 + 15} \cdot \frac{1}{1.1^3} = \frac{65}{175} \cdot \frac{1}{1.331} \approx 0.2791 \text{ nm}^{-3}$$ + +$$c_{\text{rf}} = \frac{3\varepsilon_s}{2\varepsilon_s + \varepsilon_r} \cdot \frac{1}{r_c} = \frac{3 \times 80}{2 \times 80 + 15} \cdot \frac{1}{1.1} = \frac{240}{175} \cdot \frac{1}{1.1} \approx 1.2468 \text{ nm}^{-1}$$ + +These terms ensure that: +1. The electrostatic potential is continuous at $r = r_c$ +2. The electrostatic force at $r = r_c$ is reduced to $3\varepsilon_r / (2\varepsilon_s + \varepsilon_r) \approx 26\%$ of the bare Coulomb force, which softens but does not fully remove the truncation artefact +3. Long-range interactions ($r > r_c$) are approximated by the response of a dielectric continuum with $\varepsilon_s = 80$ + +The $c_{\text{rf}}$ constant shifts the potential so that $U_{\text{RF}}(r_c) = 0$, preventing a discontinuous energy jump at the cutoff that would introduce systematic errors in the simulation. + +--- + +### Appendix D: Scientific Basis and Limitations + +#### What the Martini CG model captures + +- **Sidechain chemical identity**: distinct bead types (C, SC, N, Q, TC) encode regular apolar, small apolar, polar, charged, and tiny apolar/ring character, providing chemically meaningful residue differentiation. +- **Salt bridges and electrostatic interactions**: charges placed on the correct sidechain bead give geometrically meaningful electrostatic interactions between charged pairs. +- **Solvent screening via explicit water**: W beads provide realistic dielectric boundary conditions and drive protein conformational sampling through collisions, analogously to TIP3P in the atomistic model. +- **Aromatic interactions**: the SC and TC beads used for ring fragments have smaller σ (0.41 nm and 0.34 nm vs. 0.47 nm) and lower ε (2.35 and 1.51 kJ/mol), partly encoding the reduced effective size of aromatic ring contacts. +- **Chain geometry**: backbone bonds (0.35 nm), BB-SC bonds (0.27 nm), and the 127° backbone angle maintain realistic protein topology. 
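
To make the salt-bridge point in the list above concrete, the following sketch recomputes the Appendix C constants and evaluates the reaction-field Coulomb energy of a +1/−1 bead pair at a few separations. It assumes the usual reaction-field form, $(k_e/\varepsilon_r)\, q_1 q_2 \,(1/r + k_{\text{rf}} r^2 - c_{\text{rf}})$, truncated at $r_c$. The function and constant names are illustrative only and are not taken from the proteogram package.

```python
# Illustrative sketch; names are hypothetical, not taken from the proteogram package.
COULOMB_CONST = 138.935456  # kJ·nm/(mol·e²), vacuum Coulomb prefactor
EPS_R = 15.0                # dielectric used inside the cutoff
EPS_S = 80.0                # dielectric continuum beyond the cutoff
R_CUT = 1.1                 # cutoff radius in nm


def rf_constants(eps_s: float = EPS_S, eps_r: float = EPS_R, r_c: float = R_CUT):
    """Return (k_rf in nm^-3, c_rf in nm^-1) from the Appendix C formulas."""
    denom = 2.0 * eps_s + eps_r
    return (eps_s - eps_r) / denom / r_c**3, 3.0 * eps_s / denom / r_c


def rf_pair_energy(q1: float, q2: float, r: float) -> float:
    """Reaction-field Coulomb energy in kJ/mol (assumed functional form).

    Uses (k_e/eps_r) * q1*q2 * (1/r + k_rf*r^2 - c_rf) inside the cutoff
    and returns zero at or beyond it.
    """
    if r >= R_CUT:
        return 0.0
    k_rf, c_rf = rf_constants()
    return COULOMB_CONST / EPS_R * q1 * q2 * (1.0 / r + k_rf * r**2 - c_rf)


k_rf, c_rf = rf_constants()
print(f"k_rf = {k_rf:.4f} nm^-3, c_rf = {c_rf:.4f} nm^-1")
for r in (0.5, 0.8, 1.0999):
    energy = rf_pair_energy(+1.0, -1.0, r)
    print(f"r = {r:.4f} nm -> E(+1, -1) = {energy:.3f} kJ/mol")
```

The energy approaches zero as the separation nears the cutoff, which is the behaviour the $c_{\text{rf}}$ shift is meant to guarantee.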
+ +#### What the Martini CG model does not capture + +- **Dihedral terms**: the implementation omits backbone and sidechain torsion potentials present in full Martini 3. Secondary structure propensity is therefore weaker than the full force field. +- **Specific hydrogen bonding**: no explicit H-bond terms. Polar interactions are represented only through N-bead LJ parameters. +- **Desolvation penalties beyond LJ**: the cost of burying charged residues is partially captured through explicit water competition but not through an explicit transfer free energy term. +- **Side-chain conformational detail**: SC bead centroids cannot represent rotamer diversity; packing interactions are averaged into a single centroid position per bead. + +#### Why these limitations are acceptable for proteograms + +`ProteogramV2.normalize_map()` rescales each energy channel independently to [0–255] (attractive channels via `abs()` first). The **relative spatial pattern** of residue-residue interactions — which pairs are strongly interacting relative to others — is what the downstream model uses for structure comparison. The Martini model captures this pattern while running significantly faster than the atomistic model. + +--- + +## References + +1. Souza, P. C. T., et al. (2021). "Martini 3: a general purpose force field for coarse-grained molecular dynamics". *Nature Methods*. 18(4): 382–388. +2. Tironi, I. G., et al. (1995). "A generalized reaction field method for molecular dynamics simulations". *The Journal of Chemical Physics*. 102(13): 5451–5459. +3. Eastman, P., et al. (2023). "OpenMM 8: Molecular Dynamics Simulation with Machine Learning Potentials". *Journal of Physical Chemistry B*. 128(1): 109-116. 
diff --git a/proteogram/common/constants.py b/proteogram/common/constants.py index 6cec440..63b1504 100644 --- a/proteogram/common/constants.py +++ b/proteogram/common/constants.py @@ -1,3 +1,6 @@ +import math +import numpy as np + RESIDUE_LIST = [ ("A", "ALA"), ("R", "ARG"), @@ -250,4 +253,185 @@ 'E': -31, 'P': -46, 'D': -55 -} \ No newline at end of file +} + +# ── Martini 3 CG force field constants ─────────────────────────────────────── +# Sources: https://cgmartini.nl/docs/downloads/force-field-parameters/martini3/particle-definitions.html, https://github.com/maccallumlab/martini_openmm/tree/master/tutorial/martini_v3.0.0.itp and Tironi et al. (J. Chem. Phys. 1995) for RF parameters. +# Coulomb prefactor: k_e / ε_r where ε_r = 15 (protein-interior dielectric). +# k_rf and c_rf follow Tironi et al. (J. Chem. Phys. 1995). +MARTINI_COULOMB_K = 138.935456 / 15.0 # kJ·nm/(mol·e²), ε_r = 15 +MARTINI_SOLVENT_CUTOFF_NM = 1.1 # nm — LJ and RF cutoff +MARTINI_EPS_WATER = 80.0 # bulk water dielectric +MARTINI_EPS_PROT = 15.0 # protein-interior dielectric +MARTINI_RF_K = ((MARTINI_EPS_WATER - MARTINI_EPS_PROT) + / (2 * MARTINI_EPS_WATER + MARTINI_EPS_PROT) + / MARTINI_SOLVENT_CUTOFF_NM ** 3) +MARTINI_RF_C = ((3 * MARTINI_EPS_WATER) + / (2 * MARTINI_EPS_WATER + MARTINI_EPS_PROT) + / MARTINI_SOLVENT_CUTOFF_NM) +MARTINI_SOLVENT_BUFFER_NM = 1.2 # nm padding around protein +MARTINI_SALT_CONC_M = 0.15 # mol/L — physiological NaCl + +# Bead type → (sigma nm, epsilon kJ/mol); values from martini_v3.0.0.itp nonbond_params. +# Combination rule 2 (Lorentz-Berthelot): σ_ij = (σ_i+σ_j)/2, ε_ij = sqrt(ε_i*ε_j). 
+MARTINI_BEAD_TYPE_PARAMS = { + 'BB': (0.47, 4.06), # backbone: P2 random-coil self-ε from ITP + 'C': (0.47, 3.39), # apolar regular (C1-C5): self-ε from ITP + 'SC': (0.41, 2.35), # apolar small (SC1-SC4): σ & ε from ITP + 'N': (0.47, 3.52), # polar (N1-N2): self-ε from ITP + 'Q': (0.47, 5.95), # charged (Q4): self-ε from ITP + 'TC': (0.34, 1.51), # tiny apolar (TC1-TC5): σ & ε from ITP + 'W': (0.47, 1.00), # water bead (~4 H₂O) + 'ION_NA': (0.354, 1.18), # Na+: TQ5 self-ε from ITP + 'ION_CL': (0.354, 1.18), # Cl-: TQ5 self-ε from ITP +} + +# Bond force constants (kJ/(mol·nm²) and nm) +MARTINI_BB_BB_K = 3800.0 +MARTINI_BB_BB_R0 = 0.35 +MARTINI_BB_SC_K = 3800.0 +MARTINI_BB_SC_R0 = 0.27 +MARTINI_SC_SC_K = 2500.0 +MARTINI_SC_SC_R0 = 0.27 + +# Backbone BB–BB–BB angle (127° generic/random-coil) +MARTINI_BB_ANGLE_K = 40.0 # kJ/(mol·rad²) +MARTINI_BB_ANGLE_THETA = math.radians(127.0) # radians + +# Bead masses (Da) +MARTINI_BEAD_MASS_DA = 72.0 +MARTINI_WATER_MASS = 72.0 # W bead (4 × 18 Da) +MARTINI_ION_MASS_NA = 22.99 +MARTINI_ION_MASS_CL = 35.45 + +# Per-residue bead definitions: (bead_type, bead_label, charge_e, [atom_names]) +# BB bead is always first. Missing atoms fall back to CA automatically. 
+MARTINI_RESIDUE_BEADS = { + 'GLY': [ + ('BB', 'BB', 0.0, ['N', 'CA', 'C', 'O']), + ], + 'ALA': [ + ('BB', 'BB', 0.0, ['N', 'CA', 'C', 'O']), + ('TC', 'SC1', 0.0, ['CB']), # TC1 in Martini 3 + ], + 'VAL': [ + ('BB', 'BB', 0.0, ['N', 'CA', 'C', 'O']), + ('SC', 'SC1', 0.0, ['CB', 'CG1', 'CG2']), # SC2 in Martini 3 + ], + 'LEU': [ + ('BB', 'BB', 0.0, ['N', 'CA', 'C', 'O']), + ('SC', 'SC1', 0.0, ['CB', 'CG', 'CD1', 'CD2']), # SC3 in Martini 3 + ], + 'ILE': [ + ('BB', 'BB', 0.0, ['N', 'CA', 'C', 'O']), + ('SC', 'SC1', 0.0, ['CB', 'CG1', 'CG2', 'CD1']), # SC4 in Martini 3 + ], + 'PRO': [ + ('BB', 'BB', 0.0, ['N', 'CA', 'C', 'O']), + ('SC', 'SC1', 0.0, ['CB', 'CG', 'CD']), # SC3 in Martini 3 + ], + 'MET': [ + ('BB', 'BB', 0.0, ['N', 'CA', 'C', 'O']), + ('C', 'SC1', 0.0, ['CB', 'CG', 'SD', 'CE']), + ], + 'CYS': [ + ('BB', 'BB', 0.0, ['N', 'CA', 'C', 'O']), + ('TC', 'SC1', 0.0, ['CB', 'SG']), # TC4v in Martini 3 + ], + 'SER': [ + ('BB', 'BB', 0.0, ['N', 'CA', 'C', 'O']), + ('N', 'SC1', 0.0, ['CB', 'OG']), + ], + 'THR': [ + ('BB', 'BB', 0.0, ['N', 'CA', 'C', 'O']), + ('N', 'SC1', 0.0, ['CB', 'OG1', 'CG2']), + ], + 'ASN': [ + ('BB', 'BB', 0.0, ['N', 'CA', 'C', 'O']), + ('N', 'SC1', 0.0, ['CB', 'CG', 'OD1', 'ND2']), + ], + 'GLN': [ + ('BB', 'BB', 0.0, ['N', 'CA', 'C', 'O']), + ('N', 'SC1', 0.0, ['CB', 'CG', 'CD', 'OE1', 'NE2']), + ], + 'ASP': [ + ('BB', 'BB', 0.0, ['N', 'CA', 'C', 'O']), + ('Q', 'SC1', -1.0, ['CB', 'CG', 'OD1', 'OD2']), + ], + 'GLU': [ + ('BB', 'BB', 0.0, ['N', 'CA', 'C', 'O']), + ('Q', 'SC1', -1.0, ['CB', 'CG', 'CD', 'OE1', 'OE2']), + ], + 'LYS': [ + ('BB', 'BB', 0.0, ['N', 'CA', 'C', 'O']), + ('SC', 'SC1', 0.0, ['CB', 'CG', 'CD']), # SC3 aliphatic linker in Martini 3 + ('Q', 'SC2', +1.0, ['CE', 'NZ']), + ], + 'ARG': [ + ('BB', 'BB', 0.0, ['N', 'CA', 'C', 'O']), + ('N', 'SC1', 0.0, ['CB', 'CG', 'CD']), + ('Q', 'SC2', +1.0, ['NE', 'CZ', 'NH1', 'NH2']), + ], + 'HIS': [ + ('BB', 'BB', 0.0, ['N', 'CA', 'C', 'O']), + ('TC', 'SC1', 0.0, ['CB', 'CG']), + ('TC', 
'SC2', +0.5, ['ND1', 'CD2', 'CE1', 'NE2']), + ], + 'PHE': [ + ('BB', 'BB', 0.0, ['N', 'CA', 'C', 'O']), + ('SC', 'SC1', 0.0, ['CB', 'CG', 'CD1', 'CD2']), # SC4 ring bead in Martini 3 + ('SC', 'SC2', 0.0, ['CE1', 'CE2', 'CZ']), # SC4 ring bead in Martini 3 + ], + 'TYR': [ + ('BB', 'BB', 0.0, ['N', 'CA', 'C', 'O']), + ('SC', 'SC1', 0.0, ['CB', 'CG', 'CD1', 'CD2']), # SC4 ring bead in Martini 3 + ('SC', 'SC2', 0.0, ['CE1', 'CE2', 'CZ', 'OH']), # SC4 ring bead in Martini 3 + ], + 'TRP': [ + ('BB', 'BB', 0.0, ['N', 'CA', 'C', 'O']), + ('SC', 'SC1', 0.0, ['CB', 'CG', 'CD1', 'NE1']), # SC4 indole in Martini 3 + ('SC', 'SC2', 0.0, ['CD2', 'CE2', 'CZ2', 'CH2']), # SC4 indole in Martini 3 + ('TC', 'SC3', 0.0, ['CE3', 'CZ3']), # TC4 small ring in Martini 3 + ], +} + +# Integer index for each bead type; used to index into the pair tables below. +# ION_NA and ION_CL both map to ION (same TQ5 type in Martini 3). +MARTINI_BEAD_TYPE_INDEX: dict[str, int] = { + 'BB': 0, 'C': 1, 'SC': 2, 'N': 3, 'Q': 4, 'TC': 5, 'W': 6, 'ION': 7, + 'ION_NA': 7, 'ION_CL': 7, +} + +# Explicit pairwise LJ parameters from martini_v3.0.0.itp [nonbond_params]. +# Representative ITP types: BB=P2, C=C3, SC=SC3, N=N2, Q=Q4, TC=TC3, W=W, ION=TQ5. +# Row/column order matches MARTINI_BEAD_TYPE_INDEX: BB, C, SC, N, Q, TC, W, ION. +# Stored as float64 numpy arrays; symmetric (shape 8×8). 
+_M3_SIGMA = [ + # BB C SC N Q TC W ION + [0.470, 0.470, 0.430, 0.470, 0.470, 0.398, 0.465, 0.395], # BB + [0.470, 0.470, 0.430, 0.470, 0.470, 0.395, 0.465, 0.465], # C + [0.430, 0.430, 0.410, 0.430, 0.430, 0.365, 0.425, 0.484], # SC + [0.470, 0.470, 0.430, 0.470, 0.470, 0.395, 0.465, 0.395], # N + [0.470, 0.470, 0.430, 0.470, 0.470, 0.401, 0.465, 0.405], # Q + [0.398, 0.395, 0.365, 0.395, 0.401, 0.340, 0.393, 0.505], # TC + [0.465, 0.465, 0.425, 0.465, 0.465, 0.393, 0.470, 0.385], # W + [0.395, 0.465, 0.484, 0.395, 0.405, 0.505, 0.385, 0.354], # ION +] +_M3_EPS = [ + # BB C SC N Q TC W ION + [ 4.060, 2.790, 2.160, 3.390, 5.148, 1.450, 4.330, 9.642], # BB + [ 2.790, 3.390, 2.920, 3.240, 2.787, 2.310, 2.420, 2.946], # C + [ 2.160, 2.920, 2.350, 2.770, 1.961, 1.910, 1.800, 2.506], # SC + [ 3.390, 3.240, 2.770, 3.520, 4.477, 2.110, 3.290, 8.247], # N + [ 5.148, 2.787, 1.961, 4.477, 5.950, 1.168, 5.960, 6.360], # Q + [ 1.450, 2.310, 1.910, 2.110, 1.168, 1.510, 1.120, 1.818], # TC + [ 4.330, 2.420, 1.800, 3.290, 5.960, 1.120, 4.650, 11.460], # W + [ 9.642, 2.946, 2.506, 8.247, 6.360, 1.818, 11.460, 1.180], # ION +] +MARTINI_PAIR_SIGMA_NM = np.array(_M3_SIGMA, dtype=np.float64) +MARTINI_PAIR_EPS_KJ = np.array(_M3_EPS, dtype=np.float64) + +# Secondary-structure-dependent backbone BB–BB–BB angles (kJ/(mol·rad²), radians). +# Coil = default; helix (H/G/I) and sheet (E/B) from Martini 3 protein FF. 
+MARTINI_BB_ANGLE_HELIX = math.radians(96.0) +MARTINI_BB_ANGLE_SHEET = math.radians(134.0) \ No newline at end of file diff --git a/proteogram/v2/__init__.py b/proteogram/v2/__init__.py index 6fb6384..d6dc97a 100644 --- a/proteogram/v2/__init__.py +++ b/proteogram/v2/__init__.py @@ -1,6 +1,7 @@ from .proteogram import ProteogramV2 from .image_similarity import Img2Vec -from .nonbonded_forces import NonBondedForceModel +from .atomistic_nonbonded_forces import AtomisticNonBondedForceModel +from .martini_nonbonded_forces import MartiniNonBondedForceModel -__all__ = ['ProteogramV2', 'Img2Vec', 'NonBondedForceModel'] \ No newline at end of file +__all__ = ['ProteogramV2', 'Img2Vec', 'AtomisticNonBondedForceModel', 'MartiniNonBondedForceModel'] \ No newline at end of file diff --git a/proteogram/v2/nonbonded_forces.py b/proteogram/v2/atomistic_nonbonded_forces.py similarity index 98% rename from proteogram/v2/nonbonded_forces.py rename to proteogram/v2/atomistic_nonbonded_forces.py index 171d3d3..3917915 100644 --- a/proteogram/v2/nonbonded_forces.py +++ b/proteogram/v2/atomistic_nonbonded_forces.py @@ -34,7 +34,7 @@ from ..common.constants import MODIFIED_RESIDUES_TO_STANDARD, SOLVENT_RESIDUES -class NonBondedForceModel: +class AtomisticNonBondedForceModel: """Model for computing non-bonded forces between residues using MD simulation. This class provides a complete pipeline for: @@ -175,18 +175,10 @@ def fix_pdb_file(pdb_path: str) -> io.StringIO: # removeHeterogens. Modified residues are stored as HETATM records in PDB # files, so removeHeterogens would silently delete them if they haven't # already been converted to a standard residue name. 
- _MODIFIED_TO_STANDARD = { - 'MSE': 'MET', 'FME': 'MET', 'CXM': 'MET', - 'M3L': 'LYS', 'MLY': 'LYS', 'MLZ': 'LYS', 'KCX': 'LYS', 'ALY': 'LYS', 'LLP': 'LYS', - 'CSO': 'CYS', 'CME': 'CYS', 'OCS': 'CYS', 'SEC': 'CYS', 'SMC': 'CYS', 'CSD': 'CYS', - 'SEP': 'SER', 'TPO': 'THR', 'PTR': 'TYR', 'TYS': 'TYR', - 'HYP': 'PRO', 'CGU': 'GLU', 'PCA': 'GLN', 'NEP': 'HIS', 'HIC': 'HIS', - 'BHD': 'ASP', - } for res in fixer.topology.residues(): - if res.name in _MODIFIED_TO_STANDARD: - print(f" INFO: Pre-renaming {res.name} → {_MODIFIED_TO_STANDARD[res.name]} before hetatm removal") - res.name = _MODIFIED_TO_STANDARD[res.name] + if res.name in MODIFIED_RESIDUES_TO_STANDARD: + print(f" INFO: Pre-renaming {res.name} → {MODIFIED_RESIDUES_TO_STANDARD[res.name]} before hetatm removal") + res.name = MODIFIED_RESIDUES_TO_STANDARD[res.name] fixer.removeHeterogens(keepWater=False) fixer.findMissingAtoms() @@ -1342,10 +1334,11 @@ def equilibrate_nvt( time_ps = steps_run * timestep_ps self._log_energy('nvt', time_ps, current_energy) - # Validate energy + # Validate energy — skip first-chunk comparison: the jump from a + # zero-temperature minimized state to 310 K is expected and large. warnings_list = self._validate_energy( - current_energy, 'NVT', - prev_energy=energy_history[-2] if len(energy_history) > 1 else None, + current_energy, 'NVT', + prev_energy=energy_history[-2] if len(energy_history) > 2 else None, n_atoms=n_atoms ) for w in warnings_list: @@ -1358,14 +1351,13 @@ def equilibrate_nvt( print(f" Final potential energy: {final_energy:.1f} kJ/mol " f"({final_energy/n_atoms:.2f} kJ/mol/atom)") - # Check overall trend - if final_energy > initial_energy: - warnings.warn( - f"NVT equilibration: Energy increased from {initial_energy:.1f} to {final_energy:.1f} kJ/mol" - ) - print(f" WARNING: Energy increased during NVT equilibration") + # Energy increase from minimization → NVT is expected (system thermalizes + # from 0 K to 310 K), so only report the trend without raising a warning. 
+ delta_nvt = final_energy - initial_energy + if delta_nvt > 0: + print(f" Energy increased by {delta_nvt:.1f} kJ/mol (normal thermalization)") else: - print(f" Energy decreased by {initial_energy - final_energy:.1f} kJ/mol (good)") + print(f" Energy decreased by {-delta_nvt:.1f} kJ/mol (good)") self._get_positions_and_cleanup() @@ -2335,13 +2327,21 @@ def run_full_pipeline( add_barostat=False, add_calpha_restraint=True) self.minimize_energy() - # Step 3: NPT equilibration with new system including barostat force - print("\n[Step 3/5] NPT equilibration...") - self.equilibrate_npt(steps=npt_steps) - - # Step 4: NVT equilibration using same system without barostat force - print("\n[Step 4/5] NVT equilibration...") - self.equilibrate_nvt_with_warming(steps=nvt_steps) + # Step 3: NVT equilibration using same system without barostat force + print("\n[Step 3/5] NVT equilibration...") + self.equilibrate_nvt(steps=nvt_steps) + + # Step 4: NPT equilibration with new system including barostat force. + # Skipped for small proteins (< 50 residues): tiny simulation boxes + # produce pressure fluctuations large enough to cause NaN coordinates, + # and box-volume equilibration adds no value at that scale. + print("\n[Step 4/5] NPT equilibration...") + n_protein_residues = len(self.protein_residue_indices) + if n_protein_residues < 50: + print(f" Skipping NPT for small protein ({n_protein_residues} residues < 50): " + "pressure coupling is unstable at this scale.") + else: + self.equilibrate_npt(steps=npt_steps) # Step 5: Production MD using new system and simulation with energy calculations print("\n[Step 5/5] Production MD...") diff --git a/proteogram/v2/martini_nonbonded_forces.py b/proteogram/v2/martini_nonbonded_forces.py new file mode 100644 index 0000000..f3e682f --- /dev/null +++ b/proteogram/v2/martini_nonbonded_forces.py @@ -0,0 +1,1289 @@ +"""Martini 3-inspired coarse-grained non-bonded force model for proteogram calculation. 
+ +Multi-bead-per-residue CG model where each amino acid is represented by 1–4 +beads depending on residue type. Inspired by the Martini 3 force field +(Souza et al., Nat. Methods 2021) but simplified for the proteogram pipeline: +explicit CG water (W beads, ~4 H₂O each), Na⁺/Cl⁻ ions for charge +neutralization and 0.15 M physiological salt, no dihedral terms. + +Residue representation +---------------------- +Each residue contributes one backbone (BB) bead at the N/CA/C/O centroid plus +0–3 sidechain beads placed at sidechain atom centroids. Missing atoms fall back +to CA automatically (common in sidechain-stripped or coarse PDBs). + + GLY 1 bead (BB only) + ALA/VAL/LEU/ILE/PRO/MET/CYS/ + SER/THR/ASN/GLN/ASP/GLU 2 beads (BB + SC1) + LYS/ARG/HIS/PHE/TYR 3 beads (BB + SC1 + SC2) + TRP 4 beads (BB + SC1 + SC2 + SC3) + +Bead types and LJ parameters (Martini 3 ITP values, Souza et al. 2021) +----------------------------------------------------------------------- + BB — backbone (P2 coil): σ=0.47 nm, ε=4.06 kJ/mol + C — apolar regular (C1–C5): σ=0.47 nm, ε=3.39 kJ/mol (MET) + SC — apolar small (SC1–SC4): σ=0.41 nm, ε=2.35 kJ/mol (VAL/LEU/ILE/PRO/LYS-SC1/PHE/TYR/TRP rings) + N — polar (N1–N2): σ=0.47 nm, ε=3.52 kJ/mol (SER/THR/ASN/GLN/ARG-SC1) + Q — charged (Q4): σ=0.47 nm, ε=5.95 kJ/mol (ASP/GLU q=−1; LYS/ARG q=+1) + TC — tiny apolar (TC1–TC5): σ=0.34 nm, ε=1.51 kJ/mol (ALA/CYS sidechains; HIS/TRP small rings) + ION — Na⁺/Cl⁻ (TQ5): σ=0.354 nm, ε=1.18 kJ/mol + +Bonded interactions +------------------- + HarmonicBondForce BB–BB inter-residue (k=3800, r0=0.35 nm) + BB–SC intra-residue (k=3800, r0=0.27 nm) + SC–SC intra-residue (k=2500, r0=0.27 nm) + HarmonicAngleForce BB–BB–BB backbone angle (k=40 kJ/(mol·rad²)), + θ₀ per-residue from DSSP: coil=127°, helix=96°, sheet=134° + (falls back to coil for all if DSSP binary unavailable) + +Non-bonded interactions +----------------------- + CustomNonbondedForce LJ-12-6 with explicit pairwise σ/ε from martini_v3.0.0.itp + via 
Discrete2DFunction (type-index lookup); replaces Lorentz- + Berthelot combining rules, which deviate by up to 9 kJ/mol + for key pairs (BB–ION, SC–Q, TC–Q). + CustomNonbondedForce Reaction-field Coulomb for charged beads + Exclusions all intra-residue pairs + 1-2/1-3 backbone BB pairs + +Output format +------------- +Returns 5 NxN matrices (N = number of residues, upper triangle populated): + [vdw_att, vdw_rep, es_att, es_rep, dist_avg] + +B×B bead-level energies are aggregated to N×N by summing over all bead pairs +belonging to each residue pair. The distance matrix uses BB–BB bead distances. +Format is identical to AtomisticNonBondedForceModel and CGNonBondedForceModel so +ProteogramV2 can swap models without changes to the downstream pipeline. + +Typical speedup vs atomistic: 10–30× (coarse-grained explicit solvent, 5–10× larger timestep, +fewer degrees of freedom per residue than all-atom). +""" + +import gc +import io +import linecache +import sys +from collections import defaultdict +from pathlib import Path +from typing import Optional + +import numpy as np + +from openmm.app import ( + PDBFile, Topology, Simulation, Element, StateDataReporter, +) +from openmm import ( + LangevinMiddleIntegrator, MonteCarloBarostat, Platform, System, + HarmonicBondForce, HarmonicAngleForce, CustomNonbondedForce, Vec3, + Discrete2DFunction, +) +from openmm.unit import ( + kelvin, picosecond, femtoseconds, nanometer, + kilojoules_per_mole, dalton, radian, bar, +) +from Bio.PDB.PDBParser import PDBParser + +from ..common.constants import ( + RESIDUE_LIST, + MODIFIED_RESIDUES_TO_STANDARD, + MARTINI_COULOMB_K as _COULOMB_K, + MARTINI_SOLVENT_CUTOFF_NM as _SOLVENT_CUTOFF_NM, + MARTINI_RF_K as _RF_K, + MARTINI_RF_C as _RF_C, + MARTINI_SOLVENT_BUFFER_NM as _SOLVENT_BUFFER_NM, + MARTINI_SALT_CONC_M as _SALT_CONC_M, + MARTINI_BEAD_TYPE_PARAMS as _BEAD_TYPE_PARAMS, + MARTINI_BEAD_TYPE_INDEX as _BEAD_TYPE_INDEX, + MARTINI_PAIR_SIGMA_NM as _PAIR_SIGMA, + MARTINI_PAIR_EPS_KJ as _PAIR_EPS, + MARTINI_BB_BB_K as _BB_BB_K, 
+ MARTINI_BB_BB_R0 as _BB_BB_R0, + MARTINI_BB_SC_K as _BB_SC_K, + MARTINI_BB_SC_R0 as _BB_SC_R0, + MARTINI_SC_SC_K as _SC_SC_K, + MARTINI_SC_SC_R0 as _SC_SC_R0, + MARTINI_BB_ANGLE_K as _BB_ANGLE_K, + MARTINI_BB_ANGLE_THETA as _BB_ANGLE_THETA, + MARTINI_BB_ANGLE_HELIX as _BB_ANGLE_HELIX, + MARTINI_BB_ANGLE_SHEET as _BB_ANGLE_SHEET, + MARTINI_BEAD_MASS_DA as _BEAD_MASS_DA, + MARTINI_WATER_MASS as _WATER_MASS, + MARTINI_ION_MASS_NA as _ION_MASS_NA, + MARTINI_ION_MASS_CL as _ION_MASS_CL, + MARTINI_RESIDUE_BEADS as _RESIDUE_BEADS, +) + + +class MartiniNonBondedForceModel: + """Martini 3-inspired multi-bead CG MD model for residue-residue interaction maps. + + Each residue is represented by 1–4 beads depending on sidechain complexity. + Bead positions are computed as centroids of the corresponding heavy atoms in + the PDB (backbone atoms for BB; sidechain atom subsets for SC beads). + + Bead-level B×B energies are aggregated to residue-level N×N matrices so + the output format matches AtomisticNonBondedForceModel and CGNonBondedForceModel. + + Attributes: + pdb_path (str): Path to input PDB file. + temperature (Quantity): Simulation temperature. + timestep (Quantity): Integration timestep (default 10 fs). + use_gpu (bool): Use CUDA platform if True. + topology: OpenMM Topology. + positions: Bead positions (list of Vec3). + system (System): OpenMM System. + simulation (Simulation): OpenMM Simulation. 
+ """ + + DEFAULT_TEMPERATURE = 310.15 # K (37 °C, physiological) + DEFAULT_FRICTION = 1.0 # 1/ps + DEFAULT_TIMESTEP = 10.0 # fs — equilibration (conservative while system settles) + DEFAULT_PRODUCTION_TIMESTEP = 20.0 # fs — production (standard Martini 3 W-model timestep) + DEFAULT_NVT_STEPS = 25000 # 250 ps at 10 fs/step + DEFAULT_PRODUCTION_STEPS = 250000 # 5 ns at 20 fs/step + DEFAULT_REPORTING_INTERVAL = 2500 + DEFAULT_ENERGY_INTERVAL = 5000 + + def __init__( + self, + pdb_path: str, + output_dir: Optional[str] = None, + temperature: float = DEFAULT_TEMPERATURE, + timestep: float = DEFAULT_TIMESTEP, + use_gpu: bool = False, + **kwargs, + ): + self.pdb_path = pdb_path + self.output_dir = Path(output_dir) if output_dir else Path(pdb_path).parent + self.temperature = temperature * kelvin + self.timestep = timestep * femtoseconds + self.use_gpu = use_gpu + self.debug = False + + self.topology = None + self.positions = None + self.system = None + self.simulation = None + + # Per-bead parameter arrays — populated by setup_system() + self._residue_names: list[str] = [] + self._n_protein_beads: int = 0 # B_prot; solvent beads follow + self._bead_to_residue: np.ndarray | None = None # (B_prot,) residue index per bead + self._bb_bead_indices: np.ndarray | None = None # (N,) BB bead index per residue + self._bead_type_idx: np.ndarray | None = None # (B_prot,) integer type index + self._bead_sigmas: np.ndarray | None = None # (B_prot,) LJ sigma (nm) [kept for energy extraction] + self._bead_epsilons: np.ndarray | None = None # (B_prot,) LJ epsilon (kJ/mol) + self._bead_charges: np.ndarray | None = None # (B_prot,) charge (e) + self._indicator: np.ndarray | None = None # (N, B_prot) aggregation matrix + self._bead_valid_mask: np.ndarray | None = None # (B_prot, B_prot) non-excluded upper triangle + self._box_lengths: np.ndarray | None = None # (3,) box edge lengths in nm + self._openmm_exclusions: list[tuple[int, int]] = [] + + self.energy_log = { + 'nvt': {'time_ps': [], 
'energy_kj': [], 'stage': 'NVT Equilibration'}, + 'production': {'time_ps': [], 'energy_kj': [], 'stage': 'Production'}, + } + + # ── context manager / cleanup ───────────────────────────────────────────── + + def __del__(self): + try: + self.cleanup_all_resources() + except Exception: + pass + + def __enter__(self): + return self + + def __exit__(self, *_): + self.cleanup_all_resources() + return False + + def cleanup_all_resources(self, final_run: bool = False) -> None: + try: + linecache.clearcache() + except Exception: + pass + if self.simulation is not None: + try: + self.simulation.reporters.clear() + except Exception: + pass + sim, self.simulation = self.simulation, None + del sim + for _ in range(3): + gc.collect() + if final_run: + self.system = None + self.topology = None + self.positions = None + if self.use_gpu: + self._clear_cuda_cache() + + def _clear_cuda_cache(self) -> None: + try: + import torch + if torch.cuda.is_available(): + torch.cuda.empty_cache() + except ImportError: + pass + + # ── secondary-structure backbone angles ─────────────────────────────────── + + @staticmethod + def _get_dssp_angles(model, pdb_path: str, residue_seq_info: list) -> list[float]: + """Return per-residue BB–BB–BB equilibrium angle (radians) from DSSP. + + Falls back to the generic coil angle for all residues if DSSP is + unavailable (binary not installed or Bio.PDB.DSSP import fails). 
+ + DSSP code mapping (Martini 3 protein FF): + H/G/I → helix 96° + E/B → sheet 134° + everything else → coil 127° + """ + default = [_BB_ANGLE_THETA] * len(residue_seq_info) + try: + from Bio.PDB.DSSP import DSSP + dssp = DSSP(model, pdb_path) + angles = [] + for chain_id, seq_num in residue_seq_info: + key = (chain_id, (' ', seq_num, ' ')) + if key in dssp: + ss = dssp[key][2] + if ss in ('H', 'G', 'I'): + angles.append(_BB_ANGLE_HELIX) + elif ss in ('E', 'B'): + angles.append(_BB_ANGLE_SHEET) + else: + angles.append(_BB_ANGLE_THETA) + else: + angles.append(_BB_ANGLE_THETA) + return angles + except Exception: + return default + + # ── bead position helper ────────────────────────────────────────────────── + + @staticmethod + def _compute_bead_position(residue, atom_names: list[str]) -> np.ndarray: + """Return centroid (Å) of atom_names in residue, falling back to CA/N/C.""" + coords = [residue[name].get_vector().get_array() + for name in atom_names if name in residue] + if coords: + return np.mean(coords, axis=0) + for fallback in ('CA', 'N', 'C'): + if fallback in residue: + return residue[fallback].get_vector().get_array() + raise KeyError( + f"Residue {residue.resname} {residue.id} has no usable atoms " + f"(tried {atom_names} + CA/N/C fallbacks)" + ) + + # ── system setup ────────────────────────────────────────────────────────── + + def setup_system(self) -> None: + """Build multi-bead Martini topology and force field from the PDB file. + + Steps: + 1. Extract residue heavy-atom positions via Biopython; compute bead + centroids for each bead definition in _RESIDUE_BEADS. + 2. Build protein-only numpy parameter arrays (sigma, epsilon, charge), + the (N, B_prot) indicator matrix, and the 1-2/1-3 exclusion mask. + 3. Solvate: place W beads on a cubic grid (_build_solvent_box), remove + clashes with protein beads. + 4. Add ions: neutralize net protein charge and add 0.15 M NaCl + (_add_ions), replacing randomly selected water beads. + 5. 
Build OpenMM Topology (protein chain P, solvent chain S, ion chain I) + with periodic box vectors set. + 6. Build OpenMM System — protein particles, water particles, ion particles; + HarmonicBondForce and HarmonicAngleForce for protein bonds/angles; + LJ-12-6 and reaction-field Coulomb as CutoffPeriodic CustomNonbondedForces + covering all particles; MonteCarloBarostat at 1 bar. + """ + parser = PDBParser(QUIET=True) + structure = parser.get_structure("protein", self.pdb_path) + model_struct = structure[0] + + # ── Step 1: collect protein beads from PDB residues ─────────────────── + beads_data: list[dict] = [] + residue_names: list[str] = [] + residue_bead_ranges: list[tuple[int, int]] = [] + bb_bead_indices: list[int] = [] + residue_seq_info: list[tuple[str, int]] = [] # (chain_id, seq_num) for gap detection + + for chain in model_struct.get_chains(): + for res in chain: + raw = res.resname.strip() + resname = MODIFIED_RESIDUES_TO_STANDARD.get(raw, raw) + bead_defs = _RESIDUE_BEADS.get(resname) + if bead_defs is None: + continue # HETATM, solvent, unknown residue — skip + if 'CA' not in res: + continue # no alpha carbon — BB bead position unreliable, skip + + ri = len(residue_names) + residue_names.append(resname) + residue_seq_info.append((chain.id, res.id[1])) + first_bead = len(beads_data) + + for bead_type, bead_label, charge, atom_names in bead_defs: + pos_ang = self._compute_bead_position(res, atom_names) + beads_data.append({ + 'residue_idx': ri, + 'bead_type': bead_type, + 'bead_label': bead_label, + 'charge': charge, + 'pos_nm': pos_ang / 10.0, # Å → nm + }) + + bb_bead_indices.append(first_bead) # BB is always first bead + residue_bead_ranges.append((first_bead, len(beads_data))) + + if not beads_data: + raise ValueError(f"No recognised residues found in {self.pdb_path}") + + # ── Gap filling: insert dummy GLY beads at backbone gaps ────────────── + # When consecutive residues within the same chain are not sequence- + # adjacent (missing loop in crystal 
structure), insert dummy GLY BB + # beads at linearly interpolated positions to restore chain connectivity. + # Dummy beads: zero charge, BB LJ parameters → minimal energy contribution. + filled_beads_data: list[dict] = [] + filled_residue_names: list[str] = [] + filled_residue_bead_ranges: list[tuple[int, int]] = [] + filled_bb_bead_indices: list[int] = [] + filled_residue_seq_info: list[tuple[str, int]] = [] + filled_is_dummy: list[bool] = [] + n_dummy = 0 + + for ri in range(len(residue_names)): + if ri > 0: + cid_prev, seq_prev = residue_seq_info[ri - 1] + cid_curr, seq_curr = residue_seq_info[ri] + if cid_prev == cid_curr and seq_curr > seq_prev + 1: + n_missing = seq_curr - seq_prev - 1 + pos_a = beads_data[residue_bead_ranges[ri - 1][0]]['pos_nm'] + pos_b = beads_data[residue_bead_ranges[ri][0]]['pos_nm'] + for k in range(1, n_missing + 1): + t = k / (n_missing + 1) + new_ri = len(filled_residue_names) + first_bead = len(filled_beads_data) + filled_beads_data.append({ + 'residue_idx': new_ri, + 'bead_type': 'BB', + 'bead_label': 'BB', + 'charge': 0.0, + 'pos_nm': pos_a + t * (pos_b - pos_a), + }) + filled_residue_names.append('GLY') + filled_residue_seq_info.append((cid_prev, seq_prev + k)) + filled_bb_bead_indices.append(first_bead) + filled_residue_bead_ranges.append((first_bead, first_bead + 1)) + filled_is_dummy.append(True) + n_dummy += 1 + + new_ri = len(filled_residue_names) + lo, hi = residue_bead_ranges[ri] + first_bead = len(filled_beads_data) + for b in beads_data[lo:hi]: + b_copy = dict(b) + b_copy['residue_idx'] = new_ri + filled_beads_data.append(b_copy) + filled_residue_names.append(residue_names[ri]) + filled_residue_seq_info.append(residue_seq_info[ri]) + filled_bb_bead_indices.append(first_bead + (bb_bead_indices[ri] - lo)) + filled_residue_bead_ranges.append((first_bead, len(filled_beads_data))) + filled_is_dummy.append(False) + + if n_dummy: + print(f" Inserted {n_dummy} dummy GLY bead(s) to bridge backbone gaps") + + beads_data = 
filled_beads_data + residue_names = filled_residue_names + residue_bead_ranges = filled_residue_bead_ranges + bb_bead_indices = filled_bb_bead_indices + residue_seq_info = filled_residue_seq_info + + N = len(residue_names) + B_prot = len(beads_data) + + # ── Step 2: protein-only parameter arrays and exclusion mask ────────── + self._residue_names = residue_names + self._dummy_residue_mask = np.array(filled_is_dummy, dtype=bool) + self._n_protein_beads = B_prot + self._bead_to_residue = np.array([b['residue_idx'] for b in beads_data], dtype=int) + self._bb_bead_indices = np.array(bb_bead_indices, dtype=int) + self._bead_type_idx = np.array([_BEAD_TYPE_INDEX[b['bead_type']] for b in beads_data], dtype=np.int32) + self._bead_sigmas = np.array([_BEAD_TYPE_PARAMS[b['bead_type']][0] for b in beads_data]) + self._bead_epsilons = np.array([_BEAD_TYPE_PARAMS[b['bead_type']][1] for b in beads_data]) + self._bead_charges = np.array([b['charge'] for b in beads_data]) + + # indicator[i, b] = 1 iff bead b belongs to residue i + self._indicator = np.zeros((N, B_prot)) + for b_idx, b in enumerate(beads_data): + self._indicator[b['residue_idx'], b_idx] = 1.0 + + # Bond list (protein only) + bond_list: list[tuple[int, int]] = [] + for ri in range(N): + lo_r, hi_r = residue_bead_ranges[ri] + bead_idxs = list(range(lo_r, hi_r)) + for k in range(len(bead_idxs) - 1): + bond_list.append((bead_idxs[k], bead_idxs[k + 1])) + # BB–BB inter-residue bonds: same chain only. Intra-chain gaps are + # already bridged by dummy GLY above; inter-chain breaks are intentional. 
+ for ri in range(N - 1): + if residue_seq_info[ri][0] == residue_seq_info[ri + 1][0]: + bond_list.append((bb_bead_indices[ri], bb_bead_indices[ri + 1])) + + # 1-2 / 1-3 exclusion pairs + neighbors: dict[int, set[int]] = defaultdict(set) + for i, j in bond_list: + neighbors[i].add(j) + neighbors[j].add(i) + + exclude_set: set[tuple[int, int]] = set() + for b in range(B_prot): + for n1 in list(neighbors[b]): + exclude_set.add((min(b, n1), max(b, n1))) + for n2 in list(neighbors[n1]): + if n2 != b: + exclude_set.add((min(b, n2), max(b, n2))) + + excl_matrix = np.zeros((B_prot, B_prot), dtype=bool) + for i, j in exclude_set: + excl_matrix[i, j] = True + excl_matrix[j, i] = True + np.fill_diagonal(excl_matrix, True) + self._bead_valid_mask = np.triu(~excl_matrix, k=1) + self._openmm_exclusions = list(exclude_set) + + # ── Step 3: solvate ─────────────────────────────────────────────────── + protein_pos_nm = np.array([b['pos_nm'] for b in beads_data]) # (B_prot, 3) + lo = protein_pos_nm.min(axis=0) - _SOLVENT_BUFFER_NM + box_lengths = protein_pos_nm.max(axis=0) + _SOLVENT_BUFFER_NM - lo + protein_pos_box = protein_pos_nm - lo # translate to box frame [0, L) + self._box_lengths = box_lengths + + print(f" Solvation box: {box_lengths[0]:.2f} × {box_lengths[1]:.2f} × " + f"{box_lengths[2]:.2f} nm ({np.prod(box_lengths):.1f} nm³)") + water_pos = self._build_solvent_box(protein_pos_box, box_lengths) + + # ── Step 4: add ions ────────────────────────────────────────────────── + net_charge = float(self._bead_charges.sum()) + water_pos, na_pos, cl_pos = self._add_ions(water_pos, net_charge, box_lengths) + + n_water = len(water_pos) + n_na = len(na_pos) + n_cl = len(cl_pos) + B_total = B_prot + n_water + n_na + n_cl + + # ── Step 5: OpenMM Topology ─────────────────────────────────────────── + topology = Topology() + carbon = Element.getBySymbol('C') + openmm_atoms: list = [] + + # Protein chain + prot_chain = topology.addChain('P') + for ri in range(N): + res = 
topology.addResidue(residue_names[ri], prot_chain) + lo_r, hi_r = residue_bead_ranges[ri] + for b_idx in range(lo_r, hi_r): + openmm_atoms.append( + topology.addAtom(beads_data[b_idx]['bead_label'], carbon, res)) + for i, j in bond_list: + topology.addBond(openmm_atoms[i], openmm_atoms[j]) + + # Solvent chain (one residue per W bead) + if n_water > 0: + sol_chain = topology.addChain('S') + for _ in range(n_water): + res = topology.addResidue('SOL', sol_chain) + topology.addAtom('W', carbon, res) + + # Ion chain + if n_na > 0 or n_cl > 0: + ion_chain = topology.addChain('I') + for _ in range(n_na): + res = topology.addResidue('NA', ion_chain) + topology.addAtom('NA', carbon, res) + for _ in range(n_cl): + res = topology.addResidue('CL', ion_chain) + topology.addAtom('CL', carbon, res) + + bvx = Vec3(box_lengths[0], 0, 0) * nanometer + bvy = Vec3(0, box_lengths[1], 0) * nanometer + bvz = Vec3(0, 0, box_lengths[2]) * nanometer + topology.setPeriodicBoxVectors((bvx, bvy, bvz)) + self.topology = topology + + # All positions in box frame + pieces = [protein_pos_box, water_pos] + if n_na > 0: + pieces.append(na_pos) + if n_cl > 0: + pieces.append(cl_pos) + all_pos_nm = np.vstack(pieces) + self.positions = [Vec3(*p) * nanometer for p in all_pos_nm] + + # ── Step 6: OpenMM System ───────────────────────────────────────────── + system = System() + system.setDefaultPeriodicBoxVectors(bvx, bvy, bvz) + + for _ in range(B_prot): + system.addParticle(_BEAD_MASS_DA * dalton) + for _ in range(n_water): + system.addParticle(_WATER_MASS * dalton) + for _ in range(n_na): + system.addParticle(_ION_MASS_NA * dalton) + for _ in range(n_cl): + system.addParticle(_ION_MASS_CL * dalton) + + # Harmonic bonds (protein beads only) + bond_force = HarmonicBondForce() + for i, j in bond_list: + li = beads_data[i]['bead_label'] + lj = beads_data[j]['bead_label'] + is_bb_bb = (li == 'BB' and lj == 'BB') + is_bb_sc = (li == 'BB') != (lj == 'BB') + if is_bb_bb: + k_b, r0 = _BB_BB_K, _BB_BB_R0 + elif 
is_bb_sc: + k_b, r0 = _BB_SC_K, _BB_SC_R0 + else: + k_b, r0 = _SC_SC_K, _SC_SC_R0 + bond_force.addBond(i, j, + r0 * nanometer, + k_b * kilojoules_per_mole / nanometer**2) + system.addForce(bond_force) + + # Backbone angles: BB_{i-1}–BB_i–BB_{i+1}, same chain only. + # Per-residue equilibrium angles from DSSP (helix/sheet/coil); falls + # back to coil 127° for all if DSSP binary is unavailable. + bb_angles = self._get_dssp_angles(model_struct, self.pdb_path, residue_seq_info) + angle_force = HarmonicAngleForce() + for ri in range(1, N - 1): + if not (residue_seq_info[ri - 1][0] + == residue_seq_info[ri][0] + == residue_seq_info[ri + 1][0]): + continue + angle_force.addAngle( + bb_bead_indices[ri - 1], + bb_bead_indices[ri], + bb_bead_indices[ri + 1], + bb_angles[ri] * radian, + _BB_ANGLE_K * kilojoules_per_mole / radian**2, + ) + system.addForce(angle_force) + + # LJ-12-6, CutoffPeriodic with explicit pairwise σ/ε from Martini 3 ITP. + # Each particle carries an integer type index; σ_ij and ε_ij are looked + # up from Discrete2DFunction tables rather than using Lorentz-Berthelot. + # OpenMM stores Discrete2DFunction as f(x,y) = values[x + xsize*y]. 
+ n_types = _PAIR_SIGMA.shape[0] + sig_flat = _PAIR_SIGMA.T.flatten().tolist() # column-major → f(x,y)=values[x+n*y] + eps_flat = _PAIR_EPS.T.flatten().tolist() + lj = CustomNonbondedForce( + "4*eps_tab(itype1,itype2)*((sig_tab(itype1,itype2)/r)^12" + " - (sig_tab(itype1,itype2)/r)^6);" + ) + lj.addTabulatedFunction('sig_tab', Discrete2DFunction(n_types, n_types, sig_flat)) + lj.addTabulatedFunction('eps_tab', Discrete2DFunction(n_types, n_types, eps_flat)) + lj.addPerParticleParameter('itype') + lj.setNonbondedMethod(CustomNonbondedForce.CutoffPeriodic) + lj.setCutoffDistance(_SOLVENT_CUTOFF_NM * nanometer) + + # Coulomb with reaction-field correction, CutoffPeriodic + # E = K * q1*q2 * (1/r + k_rf*r² − c_rf) [kJ/mol, r in nm] + es = CustomNonbondedForce( + f"{_COULOMB_K:.8f}*charge1*charge2*(1/r + {_RF_K:.8f}*r^2 - {_RF_C:.8f});" + ) + es.addPerParticleParameter('charge') + es.setNonbondedMethod(CustomNonbondedForce.CutoffPeriodic) + es.setCutoffDistance(_SOLVENT_CUTOFF_NM * nanometer) + + # Add all particles to force objects + w_idx = float(_BEAD_TYPE_INDEX['W']) + ion_idx = float(_BEAD_TYPE_INDEX['ION']) + for b_idx in range(B_prot): + lj.addParticle([float(self._bead_type_idx[b_idx])]) + es.addParticle([self._bead_charges[b_idx]]) + for _ in range(n_water): + lj.addParticle([w_idx]) + es.addParticle([0.0]) + for _ in range(n_na): + lj.addParticle([ion_idx]) + es.addParticle([+1.0]) + for _ in range(n_cl): + lj.addParticle([ion_idx]) + es.addParticle([-1.0]) + + # Protein-only 1-2/1-3 exclusions + for i, j in self._openmm_exclusions: + lj.addExclusion(i, j) + es.addExclusion(i, j) + + system.addForce(lj) + system.addForce(es) + + # MonteCarloBarostat at 1 bar, update every 25 steps + system.addForce(MonteCarloBarostat(1.0 * bar, self.temperature, 25)) + + self.system = system + + charged = int(np.count_nonzero(self._bead_charges)) + print(f"Martini CG system: {N} residues, {B_prot} protein beads " + f"+ {n_water} W + {n_na} Na⁺ + {n_cl} Cl⁻ = {B_total} total; " 
+ f"{len(bond_list)} bonds, {charged} charged protein beads") + + # ── solvation helpers ───────────────────────────────────────────────────── + + def _build_solvent_box( + self, + protein_pos_box: np.ndarray, + box_lengths: np.ndarray, + min_clash_nm: float = 0.53, + ) -> np.ndarray: + """Place W beads on a cubic grid and return clash-free positions. + + Grid spacing is set to the W–W LJ minimum distance (2^(1/6)·σ_W ≈ + 0.527 nm) so that adjacent water beads start at zero force — placing + them at σ (0.47 nm) puts them in the repulsive region and causes the + minimiser to launch beads. Density is ~6.8 W/nm³, slightly below the + Martini liquid target of ~8.4; NPT equilibration shrinks the box to + reach the correct density. W beads within min_clash_nm of any protein + bead are removed. Clash detection is chunked to cap peak memory. + + Args: + protein_pos_box: Protein bead positions in box-frame nm coords. + box_lengths: Box edge lengths (nm) in x, y, z. + min_clash_nm: Minimum allowed protein–water bead distance (nm). + Default 0.53 nm places water at or beyond the W–protein LJ + minimum, preventing high-energy repulsive starting forces. + + Returns: + water_pos: (W, 3) float64 array of W positions in box frame (nm). 
+ """ + sigma_w = _BEAD_TYPE_PARAMS['W'][0] # 0.47 nm + spacing = 2 ** (1.0 / 6.0) * sigma_w # LJ minimum ≈ 0.527 nm + axes = [np.arange(0.0, box_lengths[d], spacing) for d in range(3)] + gx, gy, gz = np.meshgrid(*axes, indexing='ij') + grid = np.stack([gx.ravel(), gy.ravel(), gz.ravel()], axis=1) # (G, 3) + + valid = np.ones(len(grid), dtype=bool) + chunk = 1000 + for start in range(0, len(grid), chunk): + slc = grid[start:start + chunk] # (c, 3) + diff = slc[:, np.newaxis, :] - protein_pos_box[np.newaxis, :, :] # (c, B, 3) + min_d = np.sqrt((diff ** 2).sum(axis=2)).min(axis=1) # (c,) + valid[start:start + chunk] = min_d >= min_clash_nm + + water_pos = grid[valid] + print(f" Placed {len(water_pos)} W beads " + f"({len(grid) - len(water_pos)} removed for clashes with protein)") + return water_pos + + def _add_ions( + self, + water_pos: np.ndarray, + net_charge: float, + box_lengths: np.ndarray, + salt_conc_M: float = _SALT_CONC_M, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Replace water beads with Na⁺/Cl⁻ for neutralization and 0.15 M NaCl. + + Neutralization ions are placed first to bring the system to zero net + charge, then NaCl pairs are added to reach the target concentration. + Ion positions are chosen by randomly sampling water bead slots (seed=42 + for reproducibility). His residues with charge +0.5 are rounded to the + nearest integer when determining neutralization count. + + Args: + water_pos: W bead positions in box frame (nm). + net_charge: Protein net charge (e); rounded to nearest integer. + box_lengths: Box edge lengths (nm), used to compute box volume. + salt_conc_M: Target NaCl concentration in mol/L. + + Returns: + (final_water, na_pos, cl_pos) — position arrays, dtype float64. 
+ """ + rng = np.random.default_rng(42) + n_water = len(water_pos) + order = rng.permutation(n_water) + used = 0 + + na_indices: list[int] = [] + cl_indices: list[int] = [] + + # Neutralization: positive protein → add Cl⁻; negative → add Na⁺ + net_int = round(net_charge) + n_neutralize_na = max(0, -net_int) + n_neutralize_cl = max(0, net_int) + for _ in range(n_neutralize_na): + na_indices.append(int(order[used])); used += 1 + for _ in range(n_neutralize_cl): + cl_indices.append(int(order[used])); used += 1 + + # Additional NaCl for physiological salt concentration + vol_l = float(np.prod(box_lengths)) * 1e-24 # nm³ → litres + n_pairs = int(round(salt_conc_M * vol_l * 6.02214076e23)) + for _ in range(n_pairs): + if used < n_water: + na_indices.append(int(order[used])); used += 1 + if used < n_water: + cl_indices.append(int(order[used])); used += 1 + + ion_set = set(na_indices + cl_indices) + keep = np.array([i for i in range(n_water) if i not in ion_set]) + final_water = water_pos[keep] if len(keep) else np.empty((0, 3)) + na_pos = water_pos[na_indices] if na_indices else np.empty((0, 3)) + cl_pos = water_pos[cl_indices] if cl_indices else np.empty((0, 3)) + + system_charge = net_int + len(na_indices) - len(cl_indices) + print(f" Net protein charge = {net_int:+d} → " + f"neutralization: {n_neutralize_na} Na⁺, {n_neutralize_cl} Cl⁻") + print(f" NaCl ({salt_conc_M} M, {n_pairs} pairs): " + f"{len(na_indices)} Na⁺ total, {len(cl_indices)} Cl⁻ total " + f"[system net charge = {system_charge:+d}]") + return final_water, na_pos, cl_pos + + # ── simulation helpers ──────────────────────────────────────────────────── + + def _get_barostat(self) -> Optional[MonteCarloBarostat]: + """Return the MonteCarloBarostat from self.system, or None.""" + if self.system is None: + return None + for i in range(self.system.getNumForces()): + force = self.system.getForce(i) + if isinstance(force, MonteCarloBarostat): + return force + return None + + def _get_platform(self) -> Platform: + 
if self.use_gpu: + try: + return Platform.getPlatformByName('CUDA') + except Exception: + print("CUDA not available, falling back to CPU") + return Platform.getPlatformByName('CPU') + + def _create_new_simulation(self) -> None: + self.cleanup_all_resources(final_run=False) + integrator = LangevinMiddleIntegrator( + self.temperature, + self.DEFAULT_FRICTION / picosecond, + self.timestep, + ) + platform = self._get_platform() + props = {'CudaPrecision': 'mixed'} if platform.getName() == 'CUDA' else {} + self.simulation = Simulation( + self.topology, self.system, integrator, platform, props + ) + self.simulation.context.setPositions(self.positions) + bvs = self.topology.getPeriodicBoxVectors() + if bvs is not None: + self.simulation.context.setPeriodicBoxVectors(*bvs) + + def _resolve_bead_clashes(self, min_sep_nm: float = 0.20) -> int: + """Push apart protein bead pairs closer than min_sep_nm. + + Prevents the L-BFGS minimizer from encountering LJ singularities caused + by sidechain beads falling back to CA (landing on top of the BB bead), + by dummy GLY beads interpolated into a crowded region, or by unusual + crystallographic geometry. Only protein beads are checked — solvent is + on a regular grid and cannot produce coincident pairs. + + Returns the number of pairs resolved. 
+ """ + B_prot = self._n_protein_beads + if B_prot == 0: + return 0 + + full_state = self.simulation.context.getState(getPositions=True) + full_pos = np.asarray( + full_state.getPositions(asNumpy=True).value_in_unit(nanometer), + dtype=np.float64, + ) + del full_state + + pos = full_pos[:B_prot].copy() + rng = np.random.default_rng(99) + n_fixed = 0 + + for i in range(B_prot): + diffs = pos[i + 1:] - pos[i] + dists = np.linalg.norm(diffs, axis=1) + for ci in np.where(dists < min_sep_nm)[0]: + j = i + 1 + ci + dist = dists[ci] + if dist > 1e-10: + direction = diffs[ci] / dist + else: + direction = rng.uniform(-1, 1, 3) + direction /= np.linalg.norm(direction) + pos[j] += (min_sep_nm - dist + 0.02) * direction + n_fixed += 1 + + if n_fixed: + full_pos[:B_prot] = pos + self.positions = [Vec3(*p) * nanometer for p in full_pos] + self.simulation.context.setPositions(self.positions) + + return n_fixed + + def minimize_energy(self, max_iterations: int = 2000) -> None: + """Minimize energy with up to 3 attempts, jittering positions on NaN. + + OpenMM raises an exception (not a return value) when minimization + produces NaN coordinates, so each attempt is wrapped in try/except. + After any failure the simulation is recreated (the context is invalid + after an OpenMM exception) and positions are jittered with increasing + amplitude before retrying. + """ + print("Performing energy minimization...") + + # Pre-minimization: resolve overlapping protein beads. + n_fixed = self._resolve_bead_clashes() + if n_fixed: + print(f" Resolved {n_fixed} close bead pair(s) before minimization") + + # Snapshot positions after clash resolution — this is the restart point. 
+ init_state = self.simulation.context.getState(getPositions=True) + init_pos_nm = np.asarray( + init_state.getPositions(asNumpy=True).value_in_unit(nanometer), + dtype=np.float64, + ) + del init_state + + rng = np.random.default_rng(42) + for attempt in range(3): + failed = False + try: + self.simulation.minimizeEnergy(maxIterations=max_iterations) + state = self.simulation.context.getState( + getPositions=True, enforcePeriodicBox=True + ) + pos_nm = np.asarray( + state.getPositions(asNumpy=True).value_in_unit(nanometer), + dtype=np.float64, + ) + del state + if np.any(np.isnan(pos_nm)): + failed = True + else: + self.positions = [Vec3(*p) * nanometer for p in pos_nm] + except Exception as exc: + if 'NaN' not in str(exc) and 'nan' not in str(exc).lower(): + raise + failed = True + + if not failed: + break + + if attempt < 2: + jitter_amp = 0.05 * (attempt + 1) # 0.05 nm → 0.10 nm + print(f" Minimization attempt {attempt + 1} produced NaN; " + f"jittering positions by ±{jitter_amp:.2f} nm and retrying...") + jittered = init_pos_nm + rng.uniform( + -jitter_amp, jitter_amp, init_pos_nm.shape + ) + self.positions = [Vec3(*p) * nanometer for p in jittered] + self.cleanup_all_resources(final_run=False) + self._create_new_simulation() + else: + raise RuntimeError( + f"Energy minimization failed with NaN after 3 attempts: {self.pdb_path}" + ) + + self.cleanup_all_resources(final_run=False) + print("Energy minimization complete.") + + # ── NVT equilibration ───────────────────────────────────────────────────── + + def equilibrate_nvt( + self, + steps: int = DEFAULT_NVT_STEPS, + report_interval: int = DEFAULT_REPORTING_INTERVAL, + ) -> None: + # Disable barostat so this is a true constant-volume phase. + # The system is under-dense immediately after solvation; allowing + # box rescaling before the temperature is established causes the + # integrator to produce NaN when it encounters the resulting + # large forces. 
+ barostat = self._get_barostat() + if barostat is not None: + barostat.setFrequency(0) + print(f"Running NVT equilibration for {steps} steps...") + self._create_new_simulation() + self.simulation.context.setVelocitiesToTemperature(self.temperature) + self.simulation.reporters.append( + StateDataReporter(sys.stdout, report_interval, + step=True, potentialEnergy=True, + temperature=True, separator='\t') + ) + self.simulation.step(steps) + state = self.simulation.context.getState( + getPositions=True, enforcePeriodicBox=True + ) + self.positions = state.getPositions() + del state + self.cleanup_all_resources(final_run=False) + # Re-enable barostat for the NPT and production phases that follow. + if barostat is not None: + barostat.setFrequency(25) + print("NVT equilibration complete.") + + def equilibrate_npt( + self, + steps: int = DEFAULT_NVT_STEPS, + report_interval: int = DEFAULT_REPORTING_INTERVAL, + ) -> None: + """NPT equilibration — allow box volume to relax under barostat control. + + The MonteCarloBarostat added during setup_system() is active in every + phase except equilibrate_nvt(), which disables it temporarily; this step + gives the box volume explicit time to converge + before production data collection begins. + + Args: + steps: Number of integration steps. + report_interval: Steps between StateDataReporter outputs. + """ + print(f"Running NPT equilibration for {steps} steps...") + self._create_new_simulation() + self.simulation.context.setVelocitiesToTemperature(self.temperature) + self.simulation.reporters.append( + StateDataReporter(sys.stdout, report_interval, + step=True, potentialEnergy=True, + temperature=True, volume=True, separator='\t') + ) + self.simulation.step(steps) + state = self.simulation.context.getState( + getPositions=True, getParameterDerivatives=False, + enforcePeriodicBox=True, + ) + self.positions = state.getPositions() + + # Propagate the NPT-equilibrated box back to topology, system defaults, + # and _box_lengths. 
Without this, _create_new_simulation for production + # resets to the original (over-large) box, placing compressed positions + # in a corner and creating extreme local density on the first MD step. + new_bvs = self.simulation.context.getState( + getPositions=False + ).getPeriodicBoxVectors() + self.topology.setPeriodicBoxVectors((new_bvs[0], new_bvs[1], new_bvs[2])) + self.system.setDefaultPeriodicBoxVectors(*new_bvs) + self._box_lengths = np.array([ + new_bvs[0][0].value_in_unit(nanometer), + new_bvs[1][1].value_in_unit(nanometer), + new_bvs[2][2].value_in_unit(nanometer), + ]) + del state + self.cleanup_all_resources(final_run=False) + print("NPT equilibration complete.") + + # ── production MD ───────────────────────────────────────────────────────── + + def run_production( + self, + steps: int = DEFAULT_PRODUCTION_STEPS, + energy_calc_interval: int = DEFAULT_ENERGY_INTERVAL, + subtract_solvent: bool = False, # no-op; solvent beads are never included in the energy maps + ) -> list: + """Run production MD and accumulate frame-averaged residue-level energies. + + At each energy snapshot bead positions are extracted, pairwise bead + energies are computed in numpy, aggregated from B×B to N×N via the + indicator matrix, and accumulated for averaging. + + Args: + steps: Total production steps. + energy_calc_interval: Steps between energy-snapshot frames. + subtract_solvent: No-op; present for API compatibility. + + Returns: + [vdw_att, vdw_rep, es_att, es_rep, dist_avg] — each NxN ndarray, + upper-triangle populated, matching NonBondedForceModel convention. 
+ """ + print(f"Running Martini production MD for {steps} steps...") + self._create_new_simulation() + self.simulation.context.setVelocitiesToTemperature(self.temperature) + + report_interval = max(steps // 10, self.DEFAULT_REPORTING_INTERVAL) + self.simulation.reporters.append( + StateDataReporter(sys.stdout, report_interval, + step=True, potentialEnergy=True, + temperature=True, separator='\t') + ) + + N = len(self._residue_names) + vdw_att_sum = np.zeros((N, N)) + vdw_rep_sum = np.zeros((N, N)) + es_att_sum = np.zeros((N, N)) + es_rep_sum = np.zeros((N, N)) + dist_sum = np.zeros((N, N)) + frame_count = 0 + + timestep_ps = self.timestep.value_in_unit(picosecond) + + for step_start in range(0, steps, energy_calc_interval): + self.simulation.step(energy_calc_interval) + state = self.simulation.context.getState( + getPositions=True, getEnergy=True + ) + pos_nm = np.asarray( + state.getPositions(asNumpy=True).value_in_unit(nanometer), + dtype=np.float64, + ) + current_energy = state.getPotentialEnergy().value_in_unit( + kilojoules_per_mole + ) + del state + + if self.debug: + time_ps = (step_start + energy_calc_interval) * timestep_ps + self.energy_log['production']['time_ps'].append(time_ps) + self.energy_log['production']['energy_kj'].append(current_energy) + + vdw_a, vdw_r, es_a, es_r, dist = self._calc_pairwise_martini_energies(pos_nm) + vdw_att_sum += vdw_a + vdw_rep_sum += vdw_r + es_att_sum += es_a + es_rep_sum += es_r + dist_sum += dist + frame_count += 1 + + if frame_count % 5 == 0: + gc.collect() + + avg = 1.0 / frame_count + # Wrap all bead positions back into the primary periodic box before + # storing them — without enforcePeriodicBox, coordinates accumulate + # unbounded drift across image boundaries and cannot be written to PDB. 
+ state = self.simulation.context.getState( + getPositions=True, enforcePeriodicBox=True + ) + self.positions = state.getPositions() + del state + print("Martini production MD complete.") + + results = [ + vdw_att_sum * avg, vdw_rep_sum * avg, + es_att_sum * avg, es_rep_sum * avg, + dist_sum * avg, + ] + + # Strip dummy-residue rows/columns so the returned N×N matrices match + # the real sequence length expected by ProteogramV2. + if np.any(self._dummy_residue_mask): + real = ~self._dummy_residue_mask + results = [m[np.ix_(real, real)] for m in results] + + return results + + # ── pairwise energy calculation ─────────────────────────────────────────── + + def _calc_pairwise_martini_energies( + self, positions_nm: np.ndarray + ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Compute pairwise bead energies and aggregate to residue-level N×N. + + Computes all B×B bead-pair interactions using vectorised numpy, then + collapses to N×N via the indicator matrix: + E_residue[i, j] = sum of E_bead[b1, b2] + for all b1 in residue i, b2 in residue j. + + Only non-excluded upper-triangle pairs are populated (see + _bead_valid_mask built in setup_system). The BB–BB distance matrix is + computed separately using only backbone bead positions. + + Args: + positions_nm: (B, 3) float64 array of all bead positions in nm. + + Returns: + (vdw_att, vdw_rep, es_att, es_rep, dist_angstrom) — each NxN ndarray, + upper triangle populated, lower triangle zero. + """ + N = len(self._residue_names) + + # Slice to protein beads only — solvent beads drive dynamics but are + # excluded from the residue-level energy maps. 
+ pos_prot = positions_nm[:self._n_protein_beads] # (B_prot, 3) + + # ── B_prot×B_prot pairwise distances with minimum image convention ──── + diff = pos_prot[:, np.newaxis, :] - pos_prot[np.newaxis, :, :] + if self._box_lengths is not None: + diff -= np.round(diff / self._box_lengths) * self._box_lengths + dist_nm = np.sqrt(np.einsum('ijk,ijk->ij', diff, diff)) + del diff + + # Apply the same 1.1 nm cutoff as the OpenMM CutoffPeriodic forces. + # Without this, the RF term k_rf·r² grows unboundedly beyond the + # cutoff, producing spuriously large energies for distant residue pairs + # and inflating map intensities relative to atomistic. + cutoff_mask = self._bead_valid_mask & (dist_nm < _SOLVENT_CUTOFF_NM) + r_safe = np.where(cutoff_mask, dist_nm, 1.0) + + # ── LJ-12-6 (Lorentz-Berthelot combining rules) ─────────────────────── + sig_ij = 0.5 * (self._bead_sigmas[:, np.newaxis] + self._bead_sigmas[np.newaxis, :]) + eps_ij = np.sqrt(self._bead_epsilons[:, np.newaxis] * self._bead_epsilons[np.newaxis, :]) + sr6 = (sig_ij / r_safe) ** 6 + lj_rep_b = np.where(cutoff_mask, 4.0 * eps_ij * sr6 * sr6, 0.0) + lj_att_b = np.where(cutoff_mask, -4.0 * eps_ij * sr6, 0.0) + del sr6, sig_ij, eps_ij + + # ── Coulomb with reaction-field correction ──────────────────────────── + # Matches the OpenMM CustomNonbondedForce expression exactly: + # E = K * q1*q2 * (1/r + k_rf*r² − c_rf) for r < cutoff + # E = 0 for r >= cutoff + q_ij = self._bead_charges[:, np.newaxis] * self._bead_charges[np.newaxis, :] + es_raw = np.where( + cutoff_mask, + _COULOMB_K * q_ij * (1.0 / r_safe + _RF_K * r_safe ** 2 - _RF_C), + 0.0, + ) + es_att_b = np.where(es_raw < 0, es_raw, 0.0) + es_rep_b = np.where(es_raw > 0, es_raw, 0.0) + del q_ij, es_raw, dist_nm, r_safe + + # ── Aggregate B×B → N×N via indicator matrix ───────────────────────── + # Because beads are ordered by residue and lj_att_b is upper-triangle, + # the aggregation naturally preserves upper-triangle structure: all + # bead pairs b1 < b2 with 
residue[b1] < residue[b2] land in [i, j] iij', bb_diff, bb_diff)) + del bb_diff + + i_idx, j_idx = np.triu_indices(N, k=3) + dist_mask = np.zeros((N, N), dtype=bool) + dist_mask[i_idx, j_idx] = True + dist_ang = np.where(dist_mask, bb_dist * 10.0, 0.0) + + return vdw_att, vdw_rep, es_att, es_rep, dist_ang + + # ── full pipeline (mirrors NonBondedForceModel.run_full_pipeline) ───────── + + def run_full_pipeline( + self, + npt_steps: int = DEFAULT_NVT_STEPS, + nvt_steps: int = DEFAULT_NVT_STEPS, + production_steps: int = DEFAULT_PRODUCTION_STEPS, + energy_calc_interval: int = DEFAULT_ENERGY_INTERVAL, + return_simulated_pdb: bool = False, + subtract_solvent_energies: bool = False, # no-op for CG + debug: bool = False, + ) -> list: + """Run the complete Martini CG MD pipeline with explicit solvent. + + Mirrors NonBondedForceModel.run_full_pipeline() in signature and return + format. The pipeline includes an NPT equilibration phase to allow the + solvation box volume to relax before production data collection. + + Pipeline: + 1. setup_system — build multi-bead topology, solvate, add ions + 2. minimize_energy — relax initial bead geometry + 3. equilibrate_nvt — thermal equilibration (NVT, barostat disabled) + 4. equilibrate_npt — box volume equilibration (NPT) + 5. run_production — accumulate frame-averaged residue-level energies + + Args: + npt_steps: NPT equilibration steps (default 25 000 = 250 ps). + nvt_steps: NVT equilibration steps (default 25 000 = 250 ps). + production_steps: Production steps (default 250 000 = 5 ns at 20 fs/step). + energy_calc_interval: Steps between energy-snapshot frames. + return_simulated_pdb: If True, append a BB-bead PDB stream as the + last element of the returned list. + subtract_solvent_energies: Ignored; present for API compatibility. + debug: Log per-frame energies to self.energy_log if True. + + Returns: + [vdw_att, vdw_rep, es_att, es_rep, dist_avg] (5 NxN ndarrays) + or [..., pdb_stream] if return_simulated_pdb=True. 
+ """ + print("=" * 60) + print("Starting Martini CG MD pipeline (explicit solvent)") + print("=" * 60) + self.debug = debug + + print("\n[Step 1/5] Setting up Martini CG system...") + self.setup_system() + + print("\n[Step 2/5] Energy minimization...") + self._create_new_simulation() + self.minimize_energy() + + print("\n[Step 3/5] NVT equilibration...") + self.equilibrate_nvt(steps=nvt_steps) + + print("\n[Step 4/5] NPT equilibration...") + n_real_residues = int((~self._dummy_residue_mask).sum()) + if n_real_residues < 50: + print(f" Skipping NPT for small protein ({n_real_residues} residues < 50): " + "pressure coupling is unstable at this scale.") + else: + self.equilibrate_npt(steps=npt_steps) + + print("\n[Step 5/5] Production MD...") + # Switch to the larger production timestep now that the system is + # thermally and mechanically equilibrated. + self.timestep = self.DEFAULT_PRODUCTION_TIMESTEP * femtoseconds + results = self.run_production( + steps=production_steps, + energy_calc_interval=energy_calc_interval, + ) + + if return_simulated_pdb: + results.append(self._get_bb_pdb_stream()) + + print("\n" + "=" * 60, flush=True) + print("Martini CG pipeline complete!", flush=True) + print("=" * 60, flush=True) + + self.cleanup_all_resources(final_run=True) + return results + + # ── PDB output ──────────────────────────────────────────────────────────── + + def _get_bb_pdb_stream(self) -> io.StringIO: + """Return current BB bead positions as an in-memory CA-only PDB stream.""" + topo = Topology() + chain = topo.addChain() + carbon = Element.getBySymbol('C') + for resname in self._residue_names: + res = topo.addResidue(resname, chain) + topo.addAtom('CA', carbon, res) + + # Extract raw Vec3 values (no units) then re-wrap as ONE Quantity. + # A plain Python list of individual Quantity(Vec3, nm) objects would + # cause PDBFile.writeFile to call np.array() on the list, producing an + # object-dtype array, and the subsequent np.isnan() call raises TypeError. 
+ pos_values = self.positions.value_in_unit(nanometer) + bb_vecs = [pos_values[int(idx)] for idx in self._bb_bead_indices] + + # Wrap backbone bead coordinates into the primary box using Python + # float64 arithmetic. OpenMM's enforcePeriodicBox uses single-precision + # internally, which loses all decimal precision when coordinates have + # drifted by hundreds of millions of box widths, leaving coordinates + # unchanged or incorrectly wrapped. Python's % operator is exact in + # float64 even for values at 10^9 nm scale. + if self._box_lengths is not None: + Lx, Ly, Lz = float(self._box_lengths[0]), float(self._box_lengths[1]), float(self._box_lengths[2]) + bb_vecs = [Vec3(v.x % Lx, v.y % Ly, v.z % Lz) for v in bb_vecs] + + bb_positions = bb_vecs * nanometer # single Quantity(list[Vec3], nm) + + stream = io.StringIO() + PDBFile.writeFile(topo, bb_positions, stream) + stream.seek(0) + return stream diff --git a/proteogram/v2/proteogram.py b/proteogram/v2/proteogram.py index 701c565..67591c0 100644 --- a/proteogram/v2/proteogram.py +++ b/proteogram/v2/proteogram.py @@ -11,7 +11,8 @@ from Bio.PDB.Polypeptide import PPBuilder from ..common.constants import HYDROPHOBICITY_LIST, RESIDUE_LIST, MODIFIED_RESIDUES_TO_STANDARD -from .nonbonded_forces import NonBondedForceModel +from .atomistic_nonbonded_forces import AtomisticNonBondedForceModel +from .martini_nonbonded_forces import MartiniNonBondedForceModel # Ignore PDB construction warnings @@ -42,7 +43,9 @@ def __init__(self, calpha_atom_distance_cutoff=10, sequence_len_lower_cutoff=20, sequence_len_upper_cutoff=1e9, - use_gpu=False): + use_gpu=False, + cg_method=None, + sidechain_completeness_cutoff=0.5): """Initialize the ProteogramV2 instance. Args: @@ -55,9 +58,17 @@ def __init__(self, sequence_len_upper_cutoff (float, optional): Maximum sequence length for valid chains. Defaults to 1e9. use_gpu (bool, optional): Whether to use GPU acceleration. Defaults to False. 
+ cg_method (str | None, optional): Coarse-grained MD method to use. + 'martini' — Martini 3-inspired multi-bead model. + None — full atomistic simulation (default). + sidechain_completeness_cutoff (float, optional): Minimum fraction of + non-GLY residues that must have a CB atom. Structures below this + threshold (e.g. CA-only or backbone-only PDBs) are rejected by + is_valid_chain. Defaults to 0.5. Raises: KeyError: If the specified chain_id is not found in the PDB file. + ValueError: If cg_method is not one of 'martini' or None. """ self.pdb_path = pdb_path self.output_dir = output_dir @@ -82,27 +93,44 @@ def __init__(self, self.sequence = ''.join( [self.allowed_amino_acids[res.resname] for res in self.chain - if res.resname in self.allowed_amino_acids]) + if res.resname in self.allowed_amino_acids and "CA" in res]) self.calpha_atom_distance_cutoff = calpha_atom_distance_cutoff self.sequence_len_lower_cutoff = sequence_len_lower_cutoff self.sequence_len_upper_cutoff = sequence_len_upper_cutoff + _valid = {'martini', None} + if cg_method not in _valid: + raise ValueError(f"cg_method must be one of {_valid}, got {cg_method!r}") self.use_gpu = use_gpu + self.cg_method = cg_method + self.sidechain_completeness_cutoff = sidechain_completeness_cutoff def is_valid_chain(self): - """Check if the chain meets the sequence length criteria. + """Check if the chain meets sequence length and sidechain completeness criteria. Returns: - bool: True if the chain length is within the specified cutoffs, - False otherwise. + bool: True if the chain length is within the specified cutoffs and + at least sidechain_completeness_cutoff fraction of non-GLY + residues have a CB atom, False otherwise. 
""" seq_len = len(self.sequence) - return (self.sequence_len_lower_cutoff <= seq_len <= self.sequence_len_upper_cutoff) + if not (self.sequence_len_lower_cutoff <= seq_len <= self.sequence_len_upper_cutoff): + return False + if self.sidechain_completeness_cutoff > 0: + non_gly = [res for res in self.chain + if res.resname in self.allowed_amino_acids + and self.allowed_amino_acids[res.resname] != 'G'] + if non_gly: + cb_present = sum(1 for res in non_gly if 'CB' in res) + if cb_present / len(non_gly) < self.sidechain_completeness_cutoff: + return False + return True def calculate_proteogram(self, return_simulated_pdb: bool = False, debug: bool = False, subtract_solvent_energies: bool = True, - memory_efficient: bool = False): + memory_efficient: bool = False, + cg_method: str | None = 'use_instance'): """Calculate the proteogram maps. Computes distance, hydrophobicity, Van der Waals, and electrostatic maps @@ -114,9 +142,14 @@ def calculate_proteogram(self, Defaults to False. debug (bool): If True, print debug information during calculations. Defaults to False. - subtract_solvent_energies (bool): If True, subtract solvent-only - energies from the protein+solvent energies to isolate the protein - contributions. Defaults to False. + subtract_solvent_energies (bool): If True, subtract solvent-only + energies from the protein+solvent energies to isolate the protein + contributions. Ignored for CG methods (no solvent). Defaults to True. + memory_efficient (bool): Lower memory footprint at cost of speed. + Ignored for CG methods. Defaults to False. + cg_method (str | None): Override the instance-level cg_method for + this call. 'martini' or None (atomistic). The sentinel + 'use_instance' (default) falls back to self.cg_method. Returns: tuple: A tuple containing: @@ -127,30 +160,44 @@ def calculate_proteogram(self, - io.StringIO | None: Production PDB structure stream (only if return_simulated_pdb=True). 
""" - # Initialize the model - model = NonBondedForceModel( - pdb_path=self.pdb_path, - output_dir=self.output_dir, - temperature=310.15, # Kelvin (37 C) - timestep=2.0, # Femtoseconds - use_gpu=self.use_gpu, # Set True for GPU acceleration - memory_efficient=memory_efficient - ) - - energy_calc_interval = 10000 # Default: Calculate energies every 20 ps - if memory_efficient: - energy_calc_interval = 50000 # Calculate energies every 100 ps - - # Run the full pipeline (recommended) - pipeline_result = model.run_full_pipeline( - npt_steps=50000, # 100 ps NPT equilibration - nvt_steps=50000, # 100 ps NVT equilibration - production_steps=500000, # 1 ns production - energy_calc_interval=energy_calc_interval, - return_simulated_pdb=return_simulated_pdb, - debug=debug, - subtract_solvent_energies=subtract_solvent_energies # Subtract solvent-only energies - ) + method = self.cg_method if cg_method == 'use_instance' else cg_method + + if method == 'martini': + model = MartiniNonBondedForceModel( + pdb_path=self.pdb_path, + output_dir=self.output_dir, + temperature=310.15, + use_gpu=self.use_gpu, + ) + pipeline_result = model.run_full_pipeline( + nvt_steps=25000, # 250 ps NVT equilibration + npt_steps=25000, # 250 ps NPT equilibration (box volume) + production_steps=250000, # 5 ns production + energy_calc_interval=5000, + return_simulated_pdb=return_simulated_pdb, + debug=debug, + ) + else: + model = AtomisticNonBondedForceModel( + pdb_path=self.pdb_path, + output_dir=self.output_dir, + temperature=310.15, + timestep=2.0, + use_gpu=self.use_gpu, + memory_efficient=memory_efficient, + ) + energy_calc_interval = 10000 + if memory_efficient: + energy_calc_interval = 50000 + pipeline_result = model.run_full_pipeline( + nvt_steps=50000, + npt_steps=50000, + production_steps=500000, + energy_calc_interval=energy_calc_interval, + return_simulated_pdb=return_simulated_pdb, + debug=debug, + subtract_solvent_energies=subtract_solvent_energies, + ) # Explicit clean-up of OpenMM 
resources after pipeline completion model.cleanup_all_resources(final_run=True) @@ -165,16 +212,30 @@ def calculate_proteogram(self, vdw_e_att, vdw_e_rep, es_e_att, es_e_rep, disto_map = pipeline_result simulated_pdb = None - # Hydrophobicity map depends on the MD-derived distance matrix - hydro_map = self.calc_hydrophobicity_map(self.sequence, disto_map) + # Hydrophobicity map: always use crystal-structure Cα distances so the + # pattern is consistent across atomistic and CG runs. The MD-derived + # disto_map uses BB bead centroids in CG (shifted ~0.5–1 Å from Cα), + # which changes which pairs fall within the distance cutoff. + hydro_map = self.calc_hydrophobicity_map(self.sequence, self.calc_dist_matrix()) - # Normalize all maps to [0-255] - norm_disto_map, disto_err = self.normalize_map(disto_map) - norm_hydro_map, hydro_err = self.normalize_map(hydro_map) - norm_vdw_att_map, vdw_att_err = self.normalize_map(vdw_e_att) - norm_vdw_rep_map, vdw_rep_err = self.normalize_map(vdw_e_rep) - norm_es_att_map, es_att_err = self.normalize_map(es_e_att) - norm_es_rep_map, es_rep_err = self.normalize_map(es_e_rep) + # Normalize all maps to [0-255]. + # Attractive energy maps (vdw_att, es_att) have values ≤ 0; zero means no + # interaction and would otherwise normalize to 255 (brightest), flooding the + # image with spurious signal. Taking abs() first makes zero → 0 (dark = no + # interaction) and large magnitude → bright, which is the correct convention. + # Repulsive maps (vdw_rep, es_rep) and hydro_map are already ≥ 0 so zero + # naturally normalizes to 0 — no transformation needed for those. + # For CG (Martini) the hard 1.1 nm cutoff creates many exact zeros; clipping + # at the 99th percentile of non-zero values before normalizing spreads the + # dynamic range across the actual interaction region rather than letting a few + # outlier pairs compress everything else toward black. 
+ _pct = 99 if method == 'martini' else None + norm_disto_map, disto_err = self.normalize_map(disto_map, percentile=_pct) + norm_hydro_map, hydro_err = self.normalize_map(hydro_map, percentile=_pct) + norm_vdw_att_map, vdw_att_err = self.normalize_map(np.abs(vdw_e_att), percentile=_pct) + norm_vdw_rep_map, vdw_rep_err = self.normalize_map(vdw_e_rep, percentile=_pct) + norm_es_att_map, es_att_err = self.normalize_map(np.abs(es_e_att), percentile=_pct) + norm_es_rep_map, es_rep_err = self.normalize_map(es_e_rep, percentile=_pct) # Clear the original energy maps to save memory del disto_map, hydro_map, vdw_e_att, vdw_e_rep, es_e_att, es_e_rep @@ -216,11 +277,15 @@ def calculate_proteogram(self, return None, {'Error stacking maps': str(e)} @staticmethod - def normalize_map(arr): + def normalize_map(arr, percentile=None): """Normalize any numpy array to [0-255] using Min-Max linear scaling. Args: arr (numpy.ndarray): Input array to normalize. + percentile (float | None): If set, clip the array at this percentile + of non-zero values before normalizing. Useful for CG maps where + the hard interaction cutoff produces many exact zeros that would + otherwise compress the dynamic range. Defaults to None (no clip). 
Returns: tuple: A tuple containing: @@ -229,9 +294,20 @@ def normalize_map(arr): """ err = '' try: - arr = ((arr - arr.min()) * (1/(arr.max() - arr.min()) * 255)).astype('uint8') + arr = arr.astype(np.float64) + if percentile is not None: + nonzero = arr[arr > 0] + if len(nonzero) > 0: + clip_val = np.percentile(nonzero, percentile) + arr = np.clip(arr, 0, clip_val) + lo, hi = arr.min(), arr.max() + if lo == hi: + arr = np.zeros_like(arr, dtype=np.uint8) + else: + arr = ((arr - lo) * (255.0 / (hi - lo))).clip(0, 255).astype(np.uint8) except Exception as e: err = f'Problem normalizing map: {e}' + arr = np.zeros_like(arr, dtype=np.uint8) return arr, err def calc_dist_matrix(self): @@ -247,16 +323,10 @@ def calc_dist_matrix(self): """ ca_atoms = [res["CA"] for res in self.chain if "CA" in res] n_residues = len(ca_atoms) - # Initialize a results matrix with zeros distogram = np.zeros((n_residues, n_residues), dtype=np.float64) - - # Assign upper triangle the c-alpha distances (lower triangle remains all 0) for i in range(n_residues): - for j in range(i + 1, n_residues): # Only iterate over unique pairs (upper triangle) - # Use the distance operator overload for Atom objects - distance = ca_atoms[i] - ca_atoms[j] - distogram[i, j] = distance - + for j in range(i + 1, n_residues): + distogram[i, j] = ca_atoms[i] - ca_atoms[j] return distogram def calc_hydrophobicity_map(self, sequence, disto_map): diff --git a/pyproject.toml b/pyproject.toml index 04420f5..1ae40f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "proteogram" -version = "0.0.3" +version = "0.0.4" description = "Protein structure to image manipulation and analysis." 
requires-python = ">=3.11" dependencies = [ @@ -12,28 +12,28 @@ dependencies = [ "pandas>=3.0", "pdbfixer @ git+https://github.com/openmm/pdbfixer@94cfa4c", "pillow>=12.1", - "torch>=2.2", - "torchvision>=0.17", + "torch>=2.2,<2.11", + "torchvision>=0.17,<0.27", "torchsummary>=1.5", "kmeans-pytorch>=0.3", "tqdm>=4.67", "psutil>=7.2.2", "objgraph>=3.6.2", "PyYAML>=6.0", + "goatools>=1.4", "pyrotein @ git+https://github.com/carbonscott/pyrotein.git@main", + "rcsb-api>=1.7.3", ] [project.optional-dependencies] cuda12 = [ "openmm==8.4.0", - "nvidia-cuda-nvcc-cu12==12.9.86", - "openmm-cuda-12==8.4.0.post2" + "openmm-cuda-12==8.4.0.post2", ] test = ["pytest", "pytest-cov"] notebook = [ "jupyterlab>=4.2.5", "nglview>=3.1.4", - "rcsbsearchapi>=1.5.1" ] [tool.setuptools.packages.find] diff --git a/scripts/utilities/download_pdb_before_date.py b/scripts/utilities/download_pdb_before_date.py new file mode 100644 index 0000000..a343a3b --- /dev/null +++ b/scripts/utilities/download_pdb_before_date.py @@ -0,0 +1,176 @@ +""" +Download all PDB protein structures deposited before a given date from RCSB PDB. 
+ +Usage: + python download_pdb_before_date.py --before 2020-01-01 --output-dir ./pdb_structures + python download_pdb_before_date.py --before 2020-01-01 --output-dir ./pdb_structures --format cif --workers 16 +""" +import argparse +import os +import sys +import time +import requests +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path +from tqdm import tqdm + + +RCSB_SEARCH_URL = "https://search.rcsb.org/rcsbsearch/v2/query" +RCSB_DOWNLOAD_URL = "https://files.rcsb.org/download/{pdb_id}.{fmt}" +MAX_ROWS_PER_REQUEST = 10000 + + +def fetch_pdb_ids_before_date(cutoff_date: str) -> list[str]: + """Return PDB entry IDs for experimental protein-only structures deposited before cutoff_date.""" + query = { + "query": { + "type": "group", + "logical_operator": "and", + "nodes": [ + { + "type": "terminal", + "service": "text", + "parameters": { + "attribute": "rcsb_accession_info.deposit_date", + "operator": "less", + "value": f"{cutoff_date}T00:00:00Z", + "negation": False, + }, + }, + { + "type": "terminal", + "service": "text", + "parameters": { + "attribute": "rcsb_entry_info.selected_polymer_entity_types", + "operator": "in", + "value": ["Protein (only)", "Protein/NA"], + "negation": False, + }, + }, + { + "type": "terminal", + "service": "text", + "parameters": { + "attribute": "rcsb_entry_info.structure_determination_methodology", + "operator": "exact_match", + "value": "experimental", + "negation": False, + }, + }, + ], + }, + "return_type": "entry", + "request_options": { + "paginate": {"start": 0, "rows": MAX_ROWS_PER_REQUEST}, + "results_verbosity": "minimal", + }, + } + + pdb_ids = [] + start = 0 + total = None + + print(f"Querying RCSB PDB for entries deposited before {cutoff_date} ...") + with tqdm(unit=" entries", desc="Fetching IDs") as pbar: + while True: + query["request_options"]["paginate"]["start"] = start + resp = requests.post(RCSB_SEARCH_URL, json=query, timeout=60) + resp.raise_for_status() + data = resp.json() 
+ + if total is None: + total = data.get("total_count", 0) + pbar.total = total + + batch = [hit["identifier"] for hit in data.get("result_set", [])] + pdb_ids.extend(batch) + pbar.update(len(batch)) + + start += len(batch) + if not batch or start >= total: + break + + return pdb_ids + + +def download_structure(pdb_id: str, output_dir: Path, fmt: str, session: requests.Session) -> tuple[str, bool, str]: + """Download a single structure file. Returns (pdb_id, success, error_msg).""" + url = RCSB_DOWNLOAD_URL.format(pdb_id=pdb_id.lower(), fmt=fmt) + dest = output_dir / f"{pdb_id.lower()}.{fmt}" + + if dest.exists(): + return pdb_id, True, "already exists" + + try: + resp = session.get(url, timeout=30) + if resp.status_code == 404: + return pdb_id, False, "404 not found" + resp.raise_for_status() + dest.write_bytes(resp.content) + return pdb_id, True, "" + except Exception as e: + return pdb_id, False, str(e) + + +def download_all(pdb_ids: list[str], output_dir: Path, fmt: str, workers: int) -> None: + output_dir.mkdir(parents=True, exist_ok=True) + + failed = [] + with requests.Session() as session, \ + ThreadPoolExecutor(max_workers=workers) as executor, \ + tqdm(total=len(pdb_ids), unit=" files", desc="Downloading") as pbar: + + futures = { + executor.submit(download_structure, pid, output_dir, fmt, session): pid + for pid in pdb_ids + } + for future in as_completed(futures): + pdb_id, success, msg = future.result() + if not success and msg != "already exists": + failed.append((pdb_id, msg)) + pbar.update(1) + + print(f"\nDownloaded {len(pdb_ids) - len(failed)} / {len(pdb_ids)} structures to {output_dir}") + if failed: + fail_log = output_dir / "failed_downloads.txt" + fail_log.write_text("\n".join(f"{pid}\t{msg}" for pid, msg in failed)) + print(f" {len(failed)} failures logged to {fail_log}") + + +def parse_args(): + parser = argparse.ArgumentParser(description="Download PDB structures deposited before a given date.") + parser.add_argument("--before", 
required=True, metavar="YYYY-MM-DD", + help="Download entries deposited strictly before this date.") + parser.add_argument("--output-dir", required=True, metavar="DIR", + help="Directory to save downloaded structure files.") + parser.add_argument("--format", dest="fmt", choices=["pdb", "cif"], default="pdb", + help="File format to download (default: pdb).") + parser.add_argument("--workers", type=int, default=8, + help="Number of parallel download threads (default: 8).") + parser.add_argument("--ids-file", metavar="FILE", + help="Optional: path to a file with one PDB ID per line. " + "Skips the RCSB query and downloads only these IDs.") + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + output_dir = Path(args.output_dir) + + if args.ids_file: + pdb_ids = Path(args.ids_file).read_text().split() + print(f"Loaded {len(pdb_ids)} IDs from {args.ids_file}") + else: + pdb_ids = fetch_pdb_ids_before_date(args.before) + print(f"Found {len(pdb_ids)} entries deposited before {args.before}") + + ids_cache = output_dir / f"pdb_ids_before_{args.before}.txt" + output_dir.mkdir(parents=True, exist_ok=True) + ids_cache.write_text("\n".join(pdb_ids)) + print(f"ID list saved to {ids_cache}") + + if not pdb_ids: + print("No entries to download. Exiting.") + sys.exit(0) + + download_all(pdb_ids, output_dir, args.fmt, args.workers) diff --git a/scripts/utilities/split_pdb_into_chains.py b/scripts/utilities/split_pdb_into_chains.py new file mode 100644 index 0000000..0d9b77d --- /dev/null +++ b/scripts/utilities/split_pdb_into_chains.py @@ -0,0 +1,97 @@ +""" +Split downloaded PDB files into per-chain monomer files. +Outputs {PDBID}_{CHAIN}.pdb files to pdb_monomers_dir. 
+ +The chain ID encoded in the filename flows downstream to: + - create_pdb_annotation_file.py (chain-specific GO lookup) + - create_v2_proteograms.py (chain_id = bname[5] still works since + positions 0-3 = PDB ID, 4 = '_', 5 = chain) +""" +import glob +import os +import warnings + +from Bio.PDB import PDBIO, MMCIFParser, PDBParser, Select +from Bio.PDB.PDBParser import PDBConstructionWarning +from Bio.PDB.Polypeptide import is_aa + +from proteogram.common import read_yaml + + +warnings.filterwarnings("ignore", category=PDBConstructionWarning) + + +class ProteinChainSelect(Select): + """Keep only standard amino acid residues from one chain.""" + def __init__(self, chain_id): + self.chain_id = chain_id + + def accept_chain(self, chain): + return chain.get_id() == self.chain_id + + def accept_residue(self, residue): + return is_aa(residue, standard=True) + + +def is_protein_chain(chain): + return any(is_aa(res, standard=True) for res in chain) + + +def pdb_id_from_filename(basename): + """Handle RCSB naming conventions: pdb5wsu.ent → 5WSU, 5wsu.pdb → 5WSU.""" + noext = basename.rsplit('.', 1)[0].lower() + if noext.startswith('pdb') and len(noext) == 7: + return noext[3:].upper() + return noext[:4].upper() + + +if __name__ == '__main__': + config = read_yaml('config.yml') + pdb_download_dir = config['pdb_download_dir'] + pdb_monomers_dir = config['pdb_monomers_dir'] + + os.makedirs(pdb_monomers_dir, exist_ok=True) + + pdb_parser = PDBParser(PERMISSIVE=1) + cif_parser = MMCIFParser(QUIET=True) + io = PDBIO() + + pdb_files = [] + for ext in ('*.pdb', '*.ent', '*.cif', '*.mmcif'): + pdb_files.extend( + glob.glob(os.path.join(pdb_download_dir, '**', ext), recursive=True)) + + print(f'Found {len(pdb_files)} structure files to split.') + written, skipped = 0, 0 + + for pdb_file in pdb_files: + basename = os.path.basename(pdb_file) + ext = basename.rsplit('.', 1)[-1].lower() + pdb_id = pdb_id_from_filename(basename) + + try: + parser = cif_parser if ext in ('cif', 'mmcif') 
else pdb_parser + structure = parser.get_structure(pdb_id, pdb_file) + except Exception as e: + print(f'Failed to parse {pdb_file}: {e}') + skipped += 1 + continue + + model = next(iter(structure)) + io.set_structure(structure) + + for chain in model: + chain_id = chain.get_id().strip() + if not chain_id or not is_protein_chain(chain): + continue + out_file = os.path.join(pdb_monomers_dir, f'{pdb_id}_{chain_id}.pdb') + if os.path.exists(out_file): + continue + try: + io.save(out_file, ProteinChainSelect(chain_id)) + written += 1 + except Exception as e: + print(f'Failed to write {out_file}: {e}') + skipped += 1 + + print(f'Written {written} monomer files, skipped {skipped}.') diff --git a/scripts/v2/config.example.yml b/scripts/v2/config.example.yml index 06c5a9f..d580f42 100644 --- a/scripts/v2/config.example.yml +++ b/scripts/v2/config.example.yml @@ -1,16 +1,23 @@ # for create_proteograms.py and create_v2_proteograms.py limit_file: '' # optional, limit to these structures, .ent file base name one per line scope_structures_dir: '' # input structures, .ent/.pdb file format -eval_proteograms_dir: '' # input proteogram dir for eval set all_proteograms_dir: '' # input proteogram dir for all structures -eval_structures_dir: '' # input structures for eval set, .ent/.pdb file format +cg_method: 'martini' # coarse-grained (CG) method: 'martini' or null for full atomistic # for measure_similarity_single_domain.py proteograms_dir_single_search: '' # dir of proteograms to search against -# for create_annotation_file.py -annot_file: './/ProteogramData_SCOP_RCSB_PDBe_AnnotationsLookup.tsv' +# for split_pdb_into_chains.py +pdb_download_dir: '../data/pdb_download' # downloaded full PDB files +pdb_monomers_dir: '../data/pdb_monomers' # output: {PDBID}_{CHAIN}.pdb per chain + +# for create_pdb_annotation_file.py (PDB-based pipeline, no SCOPe required) +# for create_scope_annotation_file.py / create_annotation_file.py (SCOPe pipeline) +annot_file: 
'.//ProteogramData_RCSB_PDBe_AnnotationsLookup.tsv' fasta_style_file: './/structure_based_seqs.fa' +# GO OBO files — downloaded automatically from geneontology.org if absent +go_obo_file: '../data/go/go-basic.obo' +goslim_obo_file: '../data/go/goslim_generic.obo' # for make_training_data.py scope_level: '' # family, superfamily, fold or class diff --git a/scripts/v2/create_annotation_file.py b/scripts/v2/create_annotation_file.py deleted file mode 100644 index 41ffe1f..0000000 --- a/scripts/v2/create_annotation_file.py +++ /dev/null @@ -1,213 +0,0 @@ -""" -Using a list of SCOPe-named structure files, look up key annotations and -create a tab-delimited file to hold those annotations for lookup. Uses -local SCOPe database files and RCSB PDB API as well as PDBe API calls. -""" -import warnings -import glob -import os -import requests -import json -import pandas as pd - -from Bio.PDB.PDBParser import PDBParser, PDBConstructionWarning -from Bio.PDB.Polypeptide import PPBuilder -from Bio.SCOP import Scop - -from proteogram.common import read_yaml - - -# Ignore PDB construction warnings -warnings.filterwarnings("ignore", category=PDBConstructionWarning) - -def get_sequence(pdb_path): - """Get a protein sequence from a PDB file""" - seq = '' - try: - p = PDBParser(PERMISSIVE=0) - structure = p.get_structure('xyz', pdb_path) - ppb = PPBuilder() - for pp in ppb.build_peptides(structure): - seq += pp.get_sequence() - except Exception: - seq = '' - return seq - -if __name__ == '__main__': - - config = read_yaml('config.yml') - limit_file = config['limit_file'] - structures_dir = config['scope_structures_dir'] - annot_file = config['annot_file'] - fasta_style_file = config['fasta_style_file'] - scope_cla_handle = config['scope_cla_file'] - scope_des_handle = config['scope_des_file'] - scope_hie_handle = config['scope_hie_file'] - - limit_to_these_structs = [] - if limit_file: - with open(limit_file, 'r') as f: - for line in f: - limit_to_these_structs.append( - 
os.path.basename(line.strip()).replace('.ent','')) - - # Create a Scop object to get fold, superfamily and family - scop = Scop(cla_handle=open(scope_cla_handle, 'r'), - des_handle=open(scope_des_handle, 'r'), - hie_handle=open(scope_hie_handle, 'r')) - - pdb_files = glob.glob(os.path.join(structures_dir, '**', '*'), recursive=True) - - annot_data = [] - for_fasta = {} - for pdb_file in pdb_files: - if not os.path.isfile(pdb_file): - continue - - # Get the scope id from structure filename - bname = os.path.basename(pdb_file).split('.') - bname = '.'.join(bname[:-1]) - - # Check if we have a restricted list and if pdb file is in it - if limit_file: - if bname not in limit_to_these_structs: - continue - - # Sequence info - seq = get_sequence(pdb_file) - - # SCOPe info - try: - pdb_id = bname[1:5].upper() - chain = bname[5].upper() - pdb_id_chain = pdb_id + '_' + chain - prot_file = f'{pdb_id}_{chain}.jpg' - except: - print(f'Problem with filename {os.path.basename(pdb_file)}') - continue - try: - scop_entry = scop.getDomainBySid(bname) - # Parse out info for our dataframe - sccs = scop_entry.sccs - sccs_spl = sccs.split('.') - cls, fold, sfam, fam = sccs_spl[0], '.'.join(sccs_spl[:2]), '.'.join(sccs_spl[:3]), sccs - except: - cls, fold, sfam, fam = '', '', '', '' - - # Gene Ontology - try: - go_response = requests.get(f'https://www.ebi.ac.uk/pdbe/graph-api/mappings/go/{pdb_id}') - go_response = json.dumps(go_response.json()) - except: - go_response = '' - - # RCSB Data API annotations for entry - deposit_date = '' - experimental_method = '' - molecular_weight = '' - disulfide_bond_count = '' - entity_count = '' - try: - rcsb_response = requests.get(f'https://data.rcsb.org/rest/v1/core/entry/{pdb_id}') - rcsb_response = rcsb_response.json() - deposit_date = rcsb_response['rcsb_accession_info']['deposit_date'] - experimental_method = rcsb_response['rcsb_entry_info']['experimental_method'] - molecular_weight = rcsb_response['rcsb_entry_info']['molecular_weight'] - 
disulfide_bond_count = rcsb_response['rcsb_entry_info']['disulfide_bond_count'] - protein_entity_count = rcsb_response['rcsb_entry_info']['polymer_entity_count_protein'] - except: - deposit_date = '' - experimental_method = '' - molecular_weight = '' - disulfide_bond_count = '' - protein_entity_count = '' - - # RCSB Data API annotations for uniprot - get if is transmembrane - is_tm = False - tm_cnts = 0 - try: - rcsb_response = requests.get(f'https://data.rcsb.org/rest/v1/core/uniprot/{pdb_id}') - rcsb_response = rcsb_response.json() - # Get first entry and uniprot feature info (locations and sequence indices) - rcsb_uniprot_features = rcsb_response[0]['rcsb_uniprot_feature'] - except Exception as excp: - rcsb_uniprot_features = [] - - for entry in rcsb_uniprot_features: - # Get TM regions - if 'type' in entry and entry['type'] == 'TRANSMEMBRANE_REGION': - is_tm = True - tm_cnts+=1 - - row = [bname, - os.path.basename(pdb_file), - prot_file, - pdb_id, - chain, - pdb_id_chain, - cls, - fold, - sfam, - fam, - len(seq), - deposit_date, - experimental_method, - molecular_weight, - disulfide_bond_count, - protein_entity_count, - is_tm, - tm_cnts, - go_response, - seq] - - annot_data.append(row) - - fasta_style_id = f'>{pdb_id_chain}|{bname}|{fam}' - for_fasta[fasta_style_id] = seq - - annot_df = pd.DataFrame(annot_data) - annot_df.columns = ['SCOPeID', - 'PDBFileName', - 'ProteogramFileName', - 'PDBId', - 'ChainId', - 'PDBAndChainId', - 'SCOPeClass', - 'SCOPeFold', - 'SCOPeSuperfamily', - 'SCOPeFamily', - 'PDBSequenceLength', - 'PDBDepositDate', - 'PDBExperimentalMethod', - 'PDBMolecularWeight', - 'PDBDisulfideBond', - 'PDBProteinEntityCount', - 'PDBIsTransmembrane', - 'PDBTransmembraneRegionCounts', - 'PDBeGOAnnotation', - 'PDBSequence'] - - # Save the results - try: - annot_df.to_csv( - os.path.join(annot_file), - sep='\t', - index=False) - except Exception as e: - print(f'Problem saving to specific location: {e}, so saving in cwd.') - annot_df.to_csv( - 
os.path.join('.', os.path.basename(annot_file)), - sep='\t', - index=False) - - try: - f = open(fasta_style_file, 'w') - except Exception as e: - print(f'Problem saving to specific location: {e}, so saving in cwd.') - f = open(os.path.join('.', os.path.basename(fasta_style_file), 'w')) - - # Write a fasta-style file with the protein sequences from all structures processed - for entry in for_fasta: - f.write(entry + '\n' + for_fasta[entry] + '\n') - f.close() - diff --git a/scripts/v2/create_balanced_scope_train_eval_lists.py b/scripts/v2/create_balanced_scope_train_eval_lists.py index 2fae77b..bf0650d 100644 --- a/scripts/v2/create_balanced_scope_train_eval_lists.py +++ b/scripts/v2/create_balanced_scope_train_eval_lists.py @@ -19,6 +19,13 @@ [--seed 42] \ [--exclude-classes e,f,g] + Or without an eval split: + python create_balanced_scope_train_eval_lists.py \ + --lst-file path/to/cdhits_result.lst \ + --lookup-tsv path/to/annotations.tsv \ + --train-output train_set.txt \ + --no-eval + Class column options: - SCOPeClass: Major structural class (e.g., a, b, c, d) - SCOPeFold: Fold level classification @@ -269,19 +276,26 @@ def main(): '--eval-fraction', '-e', type=float, default=0.2, - help='Fraction of data to use for evaluation (default: 0.2)' + help='Fraction of data to use for evaluation (default: 0.2); ignored with --no-eval' ) - + + parser.add_argument( + '--no-eval', + action='store_true', + help='Skip eval split; all sampled proteins go to the training set' + ) + parser.add_argument( '--train-output', required=True, help='Path to output file for training set (one identifier per line)' ) - + parser.add_argument( '--eval-output', - required=True, - help='Path to output file for evaluation set (one identifier per line)' + default=None, + help='Path to output file for evaluation set (one identifier per line); ' + 'required unless --no-eval is set' ) parser.add_argument( @@ -321,6 +335,9 @@ def main(): args = parser.parse_args() + if not args.no_eval and
args.eval_output is None: + parser.error('--eval-output is required unless --no-eval is set') + exclude_classes = set() if args.exclude_classes: exclude_classes = {c.strip() for c in args.exclude_classes.split(',')} @@ -351,56 +368,65 @@ def main(): seed=args.seed ) - # Split into train and eval sets - print(f"\nSplitting into train ({1-args.eval_fraction:.0%}) and eval ({args.eval_fraction:.0%}) sets...") - train_ids, eval_ids, split_counts = split_train_eval( - sampled_by_class=sampled_by_class, - eval_fraction=args.eval_fraction, - seed=args.seed - ) - - # Print summary - print(f"\nClass distribution (total / train / eval):") - for scope_class in sorted(class_counts.keys()): - total = class_counts[scope_class] - train_n, eval_n = split_counts[scope_class] - print(f" {scope_class}: {total} / {train_n} / {eval_n}") - - print(f"\nTotal selected: {sum(class_counts.values())}") - print(f" Training set: {len(train_ids)}") - print(f" Evaluation set: {len(eval_ids)}") - + if args.no_eval: + # All sampled proteins go to training + train_ids = [sid for ids in sampled_by_class.values() for sid in ids] + + print(f"\nClass distribution (total / train):") + for scope_class in sorted(class_counts.keys()): + print(f" {scope_class}: {class_counts[scope_class]} / {class_counts[scope_class]}") + + print(f"\nTotal selected: {sum(class_counts.values())}") + print(f" Training set: {len(train_ids)} (no eval split)") + else: + # Split into train and eval sets + print(f"\nSplitting into train ({1-args.eval_fraction:.0%}) and eval ({args.eval_fraction:.0%}) sets...") + train_ids, eval_ids, split_counts = split_train_eval( + sampled_by_class=sampled_by_class, + eval_fraction=args.eval_fraction, + seed=args.seed + ) + + print(f"\nClass distribution (total / train / eval):") + for scope_class in sorted(class_counts.keys()): + total = class_counts[scope_class] + train_n, eval_n = split_counts[scope_class] + print(f" {scope_class}: {total} / {train_n} / {eval_n}") + + print(f"\nTotal 
selected: {sum(class_counts.values())}") + print(f" Training set: {len(train_ids)}") + print(f" Evaluation set: {len(eval_ids)}") + # Save training set train_path = Path(args.train_output) train_path.parent.mkdir(parents=True, exist_ok=True) - + if args.split_train and args.split_train > 1: - # Split training set into multiple files created_files = split_into_files(train_ids, train_path, args.split_train) print(f"\nSplit training set into {len(created_files)} files:") for f in created_files: print(f" {f}") else: - # Save as single file with open(train_path, 'w') as f: for identifier in train_ids: f.write(f"{identifier}\n") print(f"\nSaved training set to: {args.train_output}") - - # Save evaluation set - eval_path = Path(args.eval_output) - eval_path.parent.mkdir(parents=True, exist_ok=True) - if args.split_eval and args.split_eval > 1: - created_files = split_into_files(eval_ids, eval_path, args.split_eval) - print(f"\nSplit evaluation set into {len(created_files)} files:") - for f in created_files: - print(f" {f}") - else: - with open(eval_path, 'w') as f: - for identifier in eval_ids: - f.write(f"{identifier}\n") - print(f"Saved evaluation set to: {args.eval_output}") + if not args.no_eval: + # Save evaluation set + eval_path = Path(args.eval_output) + eval_path.parent.mkdir(parents=True, exist_ok=True) + + if args.split_eval and args.split_eval > 1: + created_files = split_into_files(eval_ids, eval_path, args.split_eval) + print(f"\nSplit evaluation set into {len(created_files)} files:") + for f in created_files: + print(f" {f}") + else: + with open(eval_path, 'w') as f: + for identifier in eval_ids: + f.write(f"{identifier}\n") + print(f"Saved evaluation set to: {args.eval_output}") if __name__ == '__main__': diff --git a/scripts/v2/create_pdb_annotation_file.py b/scripts/v2/create_pdb_annotation_file.py new file mode 100644 index 0000000..1e0a41f --- /dev/null +++ b/scripts/v2/create_pdb_annotation_file.py @@ -0,0 +1,294 @@ +""" +Annotate per-chain monomer 
PDB files ({PDBID}_{CHAIN}.pdb) with: + - GO Molecular_function, Biological_process, Cellular_component (chain-specific) + - GO slim annotations via goslim_generic (comparable high-level labels) + - RCSB entry-level metadata (deposit date, method, molecular weight, etc.) + - UniProt transmembrane annotation (entry-level) + +Input files must follow {PDBID}_{CHAIN}.pdb naming from split_pdb_into_chains.py. +GO data is fetched once per PDB entry and cached, so chains from the same +structure share one API call. A small delay is inserted between unique PDB +entries to respect API rate limits. +""" +import glob +import os +import time +import warnings + +import pandas as pd +import requests +from Bio.PDB.PDBParser import PDBConstructionWarning, PDBParser +from Bio.PDB.Polypeptide import PPBuilder +from goatools.mapslim import mapslim as _mapslim +from goatools.obo_parser import GODag + +from proteogram.common import read_yaml + + +warnings.filterwarnings("ignore", category=PDBConstructionWarning) + +_API_DELAY = 0.1 # seconds between unique PDB ID API calls +_GO_OBO_URL = 'http://current.geneontology.org/ontology/go-basic.obo' +_GOSLIM_OBO_URL = 'http://current.geneontology.org/ontology/subsets/goslim_generic.obo' + + +def get_sequence(pdb_path): + seq = '' + try: + p = PDBParser(PERMISSIVE=0) + structure = p.get_structure('xyz', pdb_path) + ppb = PPBuilder() + for pp in ppb.build_peptides(structure): + seq += str(pp.get_sequence()) + except Exception: + pass + return seq + + +def parse_pdb_id_and_chain(basename): + """Extract PDB ID and chain from filenames like '5WSU_A.pdb'.""" + noext = basename.rsplit('.', 1)[0] # '5WSU_A' + parts = noext.split('_', 1) + pdb_id = parts[0].upper() + chain_id = parts[1] if len(parts) > 1 else '' + return pdb_id, chain_id + + +def download_if_missing(url, path): + if os.path.exists(path): + return + os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True) + print(f'Downloading {os.path.basename(path)} from Gene Ontology...') 
+ r = requests.get(url, timeout=120) + r.raise_for_status() + with open(path, 'wb') as fh: + fh.write(r.content) + + +def load_go_dags(go_obo_path, goslim_obo_path): + """Load full GO and GO slim DAGs, auto-downloading OBO files if absent.""" + download_if_missing(_GO_OBO_URL, go_obo_path) + download_if_missing(_GOSLIM_OBO_URL, goslim_obo_path) + go_dag = GODag(go_obo_path) + goslim_dag = GODag(goslim_obo_path) + return go_dag, goslim_dag + + +def map_to_go_slim(go_ids_str, go_dag, goslim_dag): + """Map pipe-delimited specific GO IDs to their direct GO slim ancestors. + + Uses goslim_generic, which sits at a depth that is comparable across + proteins without being too specific — suitable as a classification label. + Returns a pipe-delimited string of slim GO IDs. + """ + if not go_ids_str: + return '' + slim_terms = set() + for go_id in go_ids_str.split('|'): + go_id = go_id.strip() + if not go_id or go_id not in go_dag: + continue + try: + direct_anc, _ = _mapslim(go_id, go_dag, goslim_dag) + slim_terms.update(direct_anc) + except Exception: + pass + return '|'.join(sorted(slim_terms)) + + +def fetch_go(pdb_id, cache): + """Fetch GO annotations from PDBe for a PDB entry, cached by PDB ID.""" + if pdb_id in cache: + return cache[pdb_id] + try: + r = requests.get( + f'https://www.ebi.ac.uk/pdbe/graph-api/mappings/go/{pdb_id.lower()}', + timeout=30) + cache[pdb_id] = r.json() + except Exception: + cache[pdb_id] = {} + time.sleep(_API_DELAY) + return cache[pdb_id] + + +def extract_go_terms_for_chain(go_data, pdb_id, chain_id, category): + """Return pipe-delimited GO IDs mapped to chain_id under the given category.""" + terms = [] + # PDBe keys the response by lowercase PDB ID + pdb_entry = go_data.get(pdb_id.lower(), go_data.get(pdb_id.upper(), {})) + for go_id, info in pdb_entry.get('GO', {}).items(): + if info['category'] == category: + if any(m['chain_id'] == chain_id for m in info.get('mappings', [])): + terms.append(go_id) + return '|'.join(sorted(terms)) + + 
+def fetch_rcsb_entry(pdb_id, cache): + if pdb_id in cache: + return cache[pdb_id] + try: + r = requests.get( + f'https://data.rcsb.org/rest/v1/core/entry/{pdb_id}', timeout=30) + cache[pdb_id] = r.json() + except Exception: + cache[pdb_id] = {} + time.sleep(_API_DELAY) + return cache[pdb_id] + + +def fetch_uniprot_features(pdb_id, cache): + """Entry-level UniProt features for TM region annotation.""" + if pdb_id in cache: + return cache[pdb_id] + try: + r = requests.get( + f'https://data.rcsb.org/rest/v1/core/uniprot/{pdb_id}', timeout=30) + features = r.json()[0].get('rcsb_uniprot_feature', []) + except Exception: + features = [] + cache[pdb_id] = features + time.sleep(_API_DELAY) + return features + + +if __name__ == '__main__': + config = read_yaml('config.yml') + monomers_dir = config['pdb_monomers_dir'] + annot_file = config['annot_file'] + fasta_style_file = config['fasta_style_file'] + limit_file = config.get('limit_file', '') + go_obo_path = config.get('go_obo_file', '../data/go/go-basic.obo') + goslim_obo_path = config.get('goslim_obo_file', '../data/go/goslim_generic.obo') + + go_dag, goslim_dag = load_go_dags(go_obo_path, goslim_obo_path) + + limit_to_these = set() + if limit_file: + with open(limit_file, 'r') as f: + for line in f: + limit_to_these.add(os.path.basename(line.strip()).rsplit('.', 1)[0]) + + pdb_files = [] + for ext in ('*.pdb', '*.ent'): + pdb_files.extend( + glob.glob(os.path.join(monomers_dir, '**', ext), recursive=True)) + + go_cache = {} + rcsb_cache = {} + uniprot_cache = {} + annot_data = [] + for_fasta = {} + + for pdb_file in pdb_files: + if not os.path.isfile(pdb_file): + continue + + basename = os.path.basename(pdb_file) + noext = basename.rsplit('.', 1)[0] + + if limit_to_these and noext not in limit_to_these: + continue + + pdb_id, chain_id = parse_pdb_id_and_chain(basename) + if not pdb_id or not chain_id: + print(f'Skipping {basename}: cannot parse PDB ID / chain ID') + continue + + pdb_id_chain = f'{pdb_id}_{chain_id}' + 
proteogram_file = f'{pdb_id_chain}.jpg' + seq = get_sequence(pdb_file) + + # GO annotations — one API call per PDB entry, shared across chains + go_data = fetch_go(pdb_id, go_cache) + go_mf = extract_go_terms_for_chain(go_data, pdb_id, chain_id, 'Molecular_function') + go_bp = extract_go_terms_for_chain(go_data, pdb_id, chain_id, 'Biological_process') + go_cc = extract_go_terms_for_chain(go_data, pdb_id, chain_id, 'Cellular_component') + + # GO slim — map specific terms to goslim_generic for comparable class labels + go_slim_mf = map_to_go_slim(go_mf, go_dag, goslim_dag) + go_slim_bp = map_to_go_slim(go_bp, go_dag, goslim_dag) + go_slim_cc = map_to_go_slim(go_cc, go_dag, goslim_dag) + + # RCSB entry metadata — cached per PDB ID + rcsb = fetch_rcsb_entry(pdb_id, rcsb_cache) + deposit_date = rcsb.get('rcsb_accession_info', {}).get('deposit_date', '') + exp_method = rcsb.get('rcsb_entry_info', {}).get('experimental_method', '') + mol_weight = rcsb.get('rcsb_entry_info', {}).get('molecular_weight', '') + disulfide_cnt = rcsb.get('rcsb_entry_info', {}).get('disulfide_bond_count', '') + protein_entity_cnt = rcsb.get('rcsb_entry_info', {}).get('polymer_entity_count_protein', '') + + # UniProt TM regions — entry-level, cached per PDB ID + uniprot_features = fetch_uniprot_features(pdb_id, uniprot_cache) + # TODO: could be more precise by checking if the TM region maps to this chain, but many entries lack that detail, so we'll just flag the whole chain as TM if any TM region is present in the entry + tm_regions = [f for f in uniprot_features if f.get('type') == 'TRANSMEMBRANE_REGION'] + is_tm = len(tm_regions) > 0 + tm_cnt = len(tm_regions) + + annot_data.append([ + noext, + basename, + proteogram_file, + pdb_id, + chain_id, + pdb_id_chain, + len(seq), + deposit_date, + exp_method, + mol_weight, + disulfide_cnt, + protein_entity_cnt, + is_tm, + tm_cnt, + go_mf, + go_bp, + go_cc, + go_slim_mf, + go_slim_bp, + go_slim_cc, + seq, + ]) + for_fasta[f'>{pdb_id_chain}'] = seq + + 
annot_df = pd.DataFrame(annot_data, columns=[ + 'MonomerID', + 'PDBFileName', + 'ProteogramFileName', + 'PDBId', + 'ChainId', + 'PDBAndChainId', + 'PDBSequenceLength', + 'PDBDepositDate', + 'PDBExperimentalMethod', + 'PDBMolecularWeight', + 'PDBDisulfideBond', + 'PDBProteinEntityCount', + 'PDBIsTransmembrane', + 'PDBTransmembraneRegionCounts', + 'GOTerms_MF', + 'GOTerms_BP', + 'GOTerms_CC', + 'GOSlim_MF', + 'GOSlim_BP', + 'GOSlim_CC', + 'PDBSequence', + ]) + + os.makedirs(os.path.dirname(os.path.abspath(annot_file)), exist_ok=True) + try: + annot_df.to_csv(annot_file, sep='\t', index=False) + except Exception as e: + out = os.path.join('.', os.path.basename(annot_file)) + print(f'Could not save to {annot_file}: {e}, saving to {out}') + annot_df.to_csv(out, sep='\t', index=False) + + os.makedirs(os.path.dirname(os.path.abspath(fasta_style_file)), exist_ok=True) + try: + fasta_out = open(fasta_style_file, 'w') + except Exception as e: + fasta_out = open(os.path.join('.', os.path.basename(fasta_style_file)), 'w') + print(f'Could not save fasta to {fasta_style_file}: {e}') + for header, seq in for_fasta.items(): + fasta_out.write(header + '\n' + seq + '\n') + fasta_out.close() + + print(f'Annotated {len(annot_df)} chains. Saved to {annot_file}') diff --git a/scripts/v2/create_scope_annotation_file.py b/scripts/v2/create_scope_annotation_file.py new file mode 100644 index 0000000..25f9293 --- /dev/null +++ b/scripts/v2/create_scope_annotation_file.py @@ -0,0 +1,309 @@ +""" +Using a list of SCOPe-named structure files, look up key annotations and +create a tab-delimited file to hold those annotations for lookup. Uses +local SCOPe database files and RCSB PDB API as well as PDBe API calls. + +GO data is fetched once per PDB entry and cached across SCOPe domains from +the same structure. Chain-specific GO terms (MF/BP/CC) are extracted by +filtering the PDBe mappings list by chain_id. 
Specific GO terms are also +mapped to goslim_generic for comparable high-level class labels. +""" +import glob +import os +import time +import warnings + +import pandas as pd +import requests +from Bio.PDB.PDBParser import PDBConstructionWarning, PDBParser +from Bio.PDB.Polypeptide import PPBuilder +from Bio.SCOP import Scop +from goatools.mapslim import mapslim as _mapslim +from goatools.obo_parser import GODag + +from proteogram.common import read_yaml + + +warnings.filterwarnings("ignore", category=PDBConstructionWarning) + +_API_DELAY = 0.1 # seconds between unique PDB ID API calls +_GO_OBO_URL = 'http://current.geneontology.org/ontology/go-basic.obo' +_GOSLIM_OBO_URL = 'http://current.geneontology.org/ontology/subsets/goslim_generic.obo' + + +def get_sequence(pdb_path): + seq = '' + try: + p = PDBParser(PERMISSIVE=0) + structure = p.get_structure('xyz', pdb_path) + ppb = PPBuilder() + for pp in ppb.build_peptides(structure): + seq += str(pp.get_sequence()) + except Exception: + pass + return seq + + +def download_if_missing(url, path): + if os.path.exists(path): + return + os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True) + print(f'Downloading {os.path.basename(path)} from Gene Ontology...') + r = requests.get(url, timeout=120) + r.raise_for_status() + with open(path, 'wb') as fh: + fh.write(r.content) + + +def load_go_dags(go_obo_path, goslim_obo_path): + """Load full GO and GO slim DAGs, auto-downloading OBO files if absent.""" + download_if_missing(_GO_OBO_URL, go_obo_path) + download_if_missing(_GOSLIM_OBO_URL, goslim_obo_path) + go_dag = GODag(go_obo_path) + goslim_dag = GODag(goslim_obo_path) + return go_dag, goslim_dag + + +def map_to_go_slim(go_ids_str, go_dag, goslim_dag): + """Map pipe-delimited specific GO IDs to their direct GO slim ancestors. + + Uses goslim_generic, which sits at a depth that is comparable across + proteins without being too specific — suitable as a classification label. 
+ Returns a pipe-delimited string of slim GO IDs. + """ + if not go_ids_str: + return '' + slim_terms = set() + for go_id in go_ids_str.split('|'): + go_id = go_id.strip() + if not go_id or go_id not in go_dag: + continue + try: + direct_anc, _ = _mapslim(go_id, go_dag, goslim_dag) + slim_terms.update(direct_anc) + except Exception: + pass + return '|'.join(sorted(slim_terms)) + + +def fetch_go(pdb_id, cache): + """Fetch GO annotations from PDBe for a PDB entry, cached by PDB ID.""" + if pdb_id in cache: + return cache[pdb_id] + try: + r = requests.get( + f'https://www.ebi.ac.uk/pdbe/graph-api/mappings/go/{pdb_id.lower()}', + timeout=30) + cache[pdb_id] = r.json() + except Exception: + cache[pdb_id] = {} + time.sleep(_API_DELAY) + return cache[pdb_id] + + +def extract_go_terms_for_chain(go_data, pdb_id, chain_id, category): + """Return pipe-delimited GO IDs mapped to chain_id under the given category.""" + terms = [] + pdb_entry = go_data.get(pdb_id.lower(), go_data.get(pdb_id.upper(), {})) + for go_id, info in pdb_entry.get('GO', {}).items(): + if info['category'] == category: + if any(m['chain_id'] == chain_id for m in info.get('mappings', [])): + terms.append(go_id) + return '|'.join(sorted(terms)) + + +def fetch_rcsb_entry(pdb_id, cache): + if pdb_id in cache: + return cache[pdb_id] + try: + r = requests.get( + f'https://data.rcsb.org/rest/v1/core/entry/{pdb_id}', timeout=30) + cache[pdb_id] = r.json() + except Exception: + cache[pdb_id] = {} + time.sleep(_API_DELAY) + return cache[pdb_id] + + +def fetch_uniprot_features(pdb_id, cache): + """Entry-level UniProt features for TM region annotation.""" + if pdb_id in cache: + return cache[pdb_id] + try: + r = requests.get( + f'https://data.rcsb.org/rest/v1/core/uniprot/{pdb_id}', timeout=30) + features = r.json()[0].get('rcsb_uniprot_feature', []) + except Exception: + features = [] + cache[pdb_id] = features + time.sleep(_API_DELAY) + return features + + +if __name__ == '__main__': + config = 
read_yaml('config.yml') + limit_file = config['limit_file'] + go_obo_path = config.get('go_obo_file', '../data/go/go-basic.obo') + goslim_obo_path = config.get('goslim_obo_file', '../data/go/goslim_generic.obo') + structures_dir = config['scope_structures_dir'] + annot_file = config['annot_file'] + fasta_style_file = config['fasta_style_file'] + scope_cla_handle = config['scope_cla_file'] + scope_des_handle = config['scope_des_file'] + scope_hie_handle = config['scope_hie_file'] + + limit_to_these_structs = set() + if limit_file: + with open(limit_file, 'r') as f: + for line in f: + limit_to_these_structs.add( + os.path.basename(line.strip()).replace('.ent', '')) + + go_dag, goslim_dag = load_go_dags(go_obo_path, goslim_obo_path) + + scop = Scop(cla_handle=open(scope_cla_handle, 'r'), + des_handle=open(scope_des_handle, 'r'), + hie_handle=open(scope_hie_handle, 'r')) + + pdb_files = glob.glob(os.path.join(structures_dir, '**', '*'), recursive=True) + + go_cache = {} + rcsb_cache = {} + uniprot_cache = {} + annot_data = [] + for_fasta = {} + + for pdb_file in pdb_files: + if not os.path.isfile(pdb_file): + continue + + # SCOPe domain ID from filename: d{PDBID}{CHAIN}{domnum} + bname = os.path.basename(pdb_file).rsplit('.', 1)[0] + + if limit_to_these_structs and bname not in limit_to_these_structs: + continue + + seq = get_sequence(pdb_file) + + try: + pdb_id = bname[1:5].upper() + chain = bname[5].upper() + pdb_id_chain = f'{pdb_id}_{chain}' + prot_file = f'{pdb_id}_{chain}.jpg' + except Exception: + print(f'Problem with filename {os.path.basename(pdb_file)}') + continue + + try: + scop_entry = scop.getDomainBySid(bname) + sccs = scop_entry.sccs + sccs_spl = sccs.split('.') + cls = sccs_spl[0] + fold = '.'.join(sccs_spl[:2]) + sfam = '.'.join(sccs_spl[:3]) + fam = sccs + except Exception: + cls, fold, sfam, fam = '', '', '', '' + + # GO annotations — one API call per PDB entry, cached across SCOPe domains + go_data = fetch_go(pdb_id, go_cache) + go_mf = 
extract_go_terms_for_chain(go_data, pdb_id, chain, 'Molecular_function') + go_bp = extract_go_terms_for_chain(go_data, pdb_id, chain, 'Biological_process') + go_cc = extract_go_terms_for_chain(go_data, pdb_id, chain, 'Cellular_component') + + # GO slim — map specific terms to goslim_generic for comparable class labels + go_slim_mf = map_to_go_slim(go_mf, go_dag, goslim_dag) + go_slim_bp = map_to_go_slim(go_bp, go_dag, goslim_dag) + go_slim_cc = map_to_go_slim(go_cc, go_dag, goslim_dag) + + # RCSB entry metadata — cached per PDB ID + rcsb = fetch_rcsb_entry(pdb_id, rcsb_cache) + deposit_date = rcsb.get('rcsb_accession_info', {}).get('deposit_date', '') + exp_method = rcsb.get('rcsb_entry_info', {}).get('experimental_method', '') + mol_weight = rcsb.get('rcsb_entry_info', {}).get('molecular_weight', '') + disulfide_cnt = rcsb.get('rcsb_entry_info', {}).get('disulfide_bond_count', '') + protein_entity_cnt = rcsb.get('rcsb_entry_info', {}).get('polymer_entity_count_protein', '') + + # UniProt TM regions — entry-level, cached per PDB ID + uniprot_features = fetch_uniprot_features(pdb_id, uniprot_cache) + tm_regions = [f for f in uniprot_features if f.get('type') == 'TRANSMEMBRANE_REGION'] + is_tm = len(tm_regions) > 0 + tm_cnt = len(tm_regions) + + annot_data.append([ + bname, + os.path.basename(pdb_file), + prot_file, + pdb_id, + chain, + pdb_id_chain, + cls, + fold, + sfam, + fam, + len(seq), + deposit_date, + exp_method, + mol_weight, + disulfide_cnt, + protein_entity_cnt, + is_tm, + tm_cnt, + go_mf, + go_bp, + go_cc, + go_slim_mf, + go_slim_bp, + go_slim_cc, + seq, + ]) + + fasta_style_id = f'>{pdb_id_chain}|{bname}|{fam}' + for_fasta[fasta_style_id] = seq + + annot_df = pd.DataFrame(annot_data, columns=[ + 'SCOPeID', + 'PDBFileName', + 'ProteogramFileName', + 'PDBId', + 'ChainId', + 'PDBAndChainId', + 'SCOPeClass', + 'SCOPeFold', + 'SCOPeSuperfamily', + 'SCOPeFamily', + 'PDBSequenceLength', + 'PDBDepositDate', + 'PDBExperimentalMethod', + 'PDBMolecularWeight', + 
'PDBDisulfideBond', + 'PDBProteinEntityCount', + 'PDBIsTransmembrane', + 'PDBTransmembraneRegionCounts', + 'GOTerms_MF', + 'GOTerms_BP', + 'GOTerms_CC', + 'GOSlim_MF', + 'GOSlim_BP', + 'GOSlim_CC', + 'PDBSequence', + ]) + + try: + annot_df.to_csv(annot_file, sep='\t', index=False) + except Exception as e: + out = os.path.join('.', os.path.basename(annot_file)) + print(f'Problem saving to {annot_file}: {e}, saving to {out}') + annot_df.to_csv(out, sep='\t', index=False) + + try: + fasta_out = open(fasta_style_file, 'w') + except Exception as e: + fasta_out = open(os.path.join('.', os.path.basename(fasta_style_file)), 'w') + print(f'Problem saving fasta to {fasta_style_file}: {e}') + + for header, seq in for_fasta.items(): + fasta_out.write(header + '\n' + seq + '\n') + fasta_out.close() + + print(f'Annotated {len(annot_df)} SCOPe domains. Saved to {annot_file}') diff --git a/scripts/v2/create_v2_proteograms.py b/scripts/v2/create_v2_proteograms.py index 1d0d513..ef7eee8 100644 --- a/scripts/v2/create_v2_proteograms.py +++ b/scripts/v2/create_v2_proteograms.py @@ -75,6 +75,7 @@ limit_file = config['limit_file'] structures_dir = config['scope_structures_dir'] proteograms_output_dir = config['all_proteograms_dir'] + cg_method = config.get('cg_method', None) or None # Only create proteograms for these structures in the input limit file limit_to_these_structs = [] @@ -169,11 +170,12 @@ calpha_atom_distance_cutoff=10, sequence_len_lower_cutoff=20, sequence_len_upper_cutoff=200, - use_gpu=use_gpu) + use_gpu=use_gpu, + cg_method=cg_method) # Skip chains that don't meet the sequence length cutoffs if not proteogram.is_valid_chain(): - print(f'Skipping {pdb_file}: sequence length {len(proteogram.sequence)} outside [{proteogram.sequence_len_lower_cutoff}, {proteogram.sequence_len_upper_cutoff}]') + print(f'Skipping {pdb_file}: sidechain completeness criteria not met or sequence length {len(proteogram.sequence)} outside [{proteogram.sequence_len_lower_cutoff}, 
{proteogram.sequence_len_upper_cutoff}]') del proteogram continue @@ -191,6 +193,8 @@ memory_efficient=args.memory_efficient) simulated_pdb_stream = None + print(f'Calculated Proteogram for {pdb_file}' + (f' with error: {err}' if err is not None else '')) + if err is not None and args.verbose: print(f'Error calculating Proteogram for {pdb_file}: {err}') @@ -200,6 +204,8 @@ plt.imsave(image_file, final_data.astype('uint8')) plt.close('all') # Clear matplotlib figures from memory plt.clf() # Clear current figure + + print(f'Saved Proteogram image to {image_file}') # Save production simulation PDB structure if requested if simulated_pdb_stream is not None and production_pdb_output_dir is not None: diff --git a/scripts/v2/train_multiple_models_randomized_eval.py b/scripts/v2/train_multiple_models_randomized_eval.py new file mode 100644 index 0000000..f7ec577 --- /dev/null +++ b/scripts/v2/train_multiple_models_randomized_eval.py @@ -0,0 +1,815 @@ +""" +This script trains a CNN on the Proteogram dataset, with options for architecture (from-scratch ConvNet or pretrained ResNet18), hyperparameters (epochs, batch size, learning rate), and logging. It includes early stopping based on validation loss, and saves the best model weights to disk. The training and validation loss curves are plotted and saved to a file. After training, the model is evaluated on a held-out test set, with per-class accuracies and a classification report printed to the console. + +Requires a GPU for reasonable training time, especially for ResNet18. For reproducibility, a random seed controls the train/val/test split and all stochastic operations. + +Data location and other parameters can be configured via command-line arguments or a config.yml file. Command-line arguments take precedence over config.yml. The Proteogram dataset should be prepared in advance using the create_v2_proteograms.py script. 
The root directory ("training_data_dir" in config.yml) should contain all proteogram images in a flat layout (no train/eval subdirectories). A random seed (--seed) is used to reproducibly split images into train, validation, and held-out test sets. An annotation TSV file is also required, which should be specified via the --tsv_file argument or included in config.yml as "tsv_file". The model weights will be saved to the path specified by "cnn_model_file_prefix" in config.yml, with a suffix indicating the architecture and hyperparameters. + +Here is more information about the SCOPe dataset: https://scop.berkeley.edu + + Usage example: + python train_multiple_models_randomized_eval.py --model resnet18 --epochs 50 --batch_size 32 --lr 1e-4 --seed 42 +""" +import copy +from sched import scheduler +import pandas as pd +import numpy as np +import glob +import os +import random +import matplotlib +import matplotlib.pyplot as plt +from PIL import Image +import argparse + +from sklearn.metrics import classification_report, roc_auc_score +from sklearn.model_selection import train_test_split + +import torch +from torch.utils.data import random_split, Dataset, DataLoader, Subset +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import torchvision.models as tv_models +import torchvision.transforms as transforms + +from proteogram.common import read_yaml + + +matplotlib.use('agg') + +class ProteogramDataset(Dataset): + """SCOPe-based Proteograms dataset.""" + + def __init__(self, tsv_file, root_dir, pad=True, level='class', new_size=128, transform=None, + names_to_labels=None, min_class_size=0, min_image_size=0, max_image_size=0, + exclude_classes=None): + """ + Arguments + --------- + tsv_file : string + Path to the tsv file with annotations. + root_dir : string + Directory with all the images. 
+ level : string + The scope category level [class, fold, superfamily, family] + transform : callable, optional + Optional transform to be applied on a sample. + names_to_labels : dict, optional + Pre-built name→integer mapping from the train dataset. When supplied + (eval mode), samples whose class is absent from this mapping are + dropped rather than creating a new mapping. + min_class_size : int + Classes with fewer than this many samples are excluded (train mode + only, ignored when names_to_labels is supplied). Default: 0 (keep all). + min_image_size : int + Images whose width or height is below this threshold (in pixels) are + excluded. Since proteograms are NxN where N = residue count, this is + equivalent to filtering by sequence length. Default: 0 (keep all). + max_image_size : int + Images whose width or height exceeds this threshold (in pixels) are + excluded. Prevents silent cropping of large proteins when padding mode + is used. Default: 0 (keep all). + exclude_classes : list of str, optional + Class names to exclude entirely (train mode only; eval samples for + excluded classes are automatically dropped via the names_to_labels + mechanism). Default: None (keep all). + """ + self.annot_frame = pd.read_csv(tsv_file, sep='\t') + self.root_dir = root_dir + self.files = glob.glob(os.path.join(self.root_dir, '*.jpg')) + + if min_image_size > 0: + small = [] + valid_files = [] + for f in self.files: + w, h = Image.open(f).size + if w < min_image_size or h < min_image_size: + small.append(os.path.basename(f)) + else: + valid_files.append(f) + self.files = valid_files + if small: + print(f'WARNING: {len(small)} image(s) smaller than ' + f'{min_image_size}x{min_image_size} px excluded. 
' + f'First few: {small[:5]}') + + if max_image_size > 0: + large = [] + valid_files = [] + for f in self.files: + w, h = Image.open(f).size + if w > max_image_size or h > max_image_size: + large.append(os.path.basename(f)) + else: + valid_files.append(f) + self.files = valid_files + if large: + print(f'WARNING: {len(large)} image(s) larger than ' + f'{max_image_size}x{max_image_size} pixels excluded. ' + f'First few: {large[:5]}') + self.transform = transform + self.pad = pad + self.level = level # class, fold, superfamily or family + self.new_size = new_size + + level_col = {'class': 'SCOPeClass', 'fold': 'SCOPeFold', + 'superfamily': 'SCOPeSuperfamily', 'family': 'SCOPeFamily'} + col = level_col.get(self.level, 'SCOPeFamily') + + # Look up annotation label for each image file + self.label_names = [] + missing = [] + for file in self.files: + bname = os.path.basename(file).replace('.jpg', '') + row = self.annot_frame[self.annot_frame['SCOPeID'] == bname] + if len(row) == 0: + missing.append(bname) + self.label_names.append(None) + else: + self.label_names.append(row.iloc[0][col]) + + if missing: + print(f'WARNING: {len(missing)} image(s) not found in TSV annotations ' + f'and will be excluded. First few: {missing[:5]}') + + # Drop files with no annotation + paired = [(f, l) for f, l in zip(self.files, self.label_names) if l is not None] + if not paired: + raise RuntimeError( + f'No images matched any entry in the TSV SCOPeID column. 
' + f'Check that image basenames (without .jpg) match the SCOPeID values.') + self.files, self.label_names = zip(*paired) + self.files, self.label_names = list(self.files), list(self.label_names) + + if names_to_labels is None: + # Train mode: optionally exclude named classes + if exclude_classes: + excluded_set = set(exclude_classes) + unknown = excluded_set - set(self.label_names) + if unknown: + print(f'WARNING: --exclude_classes named class(es) not found in data: ' + + ', '.join(sorted(unknown))) + paired = [(f, l) for f, l in zip(self.files, self.label_names) + if l not in excluded_set] + if not paired: + raise RuntimeError('All classes were excluded — check --exclude_classes.') + print(f'Excluding {len(excluded_set - unknown)} named class(es): ' + + ', '.join(sorted(excluded_set - unknown))) + self.files, self.label_names = zip(*paired) + self.files, self.label_names = list(self.files), list(self.label_names) + + # Train mode: optionally exclude classes below the minimum size threshold + if min_class_size > 0: + label_counts = {name: self.label_names.count(name) + for name in set(self.label_names)} + excluded = {n for n, c in label_counts.items() if c < min_class_size} + if excluded: + print(f'Excluding {len(excluded)} class(es) with < {min_class_size} samples: ' + + ', '.join(f'{n} ({label_counts[n]})' for n in sorted(excluded))) + paired = [(f, l) for f, l in zip(self.files, self.label_names) + if l not in excluded] + if not paired: + raise RuntimeError('All classes were excluded — lower min_class_size.') + self.files, self.label_names = zip(*paired) + self.files, self.label_names = list(self.files), list(self.label_names) + + self.label_names_unique = set(self.label_names) + self.names_to_labels = {name: i for i, name in enumerate(sorted(self.label_names_unique))} + self.labels_to_names = {i: name for name, i in self.names_to_labels.items()} + else: + # Eval mode: reuse the train mapping; drop samples for unseen classes + unknown = set(self.label_names) - 
set(names_to_labels.keys()) + if unknown: + print(f'Dropping {len(unknown)} eval class(es) not in training set ' + f'(excluded during training): {sorted(unknown)}') + paired = [(f, l) for f, l in zip(self.files, self.label_names) + if l in names_to_labels] + self.files, self.label_names = zip(*paired) + self.files, self.label_names = list(self.files), list(self.label_names) + self.label_names_unique = set(self.label_names) + self.names_to_labels = names_to_labels + self.labels_to_names = {v: k for k, v in names_to_labels.items()} + + self.labels = [self.names_to_labels[n] for n in self.label_names] + + def get_pad(self, curr_size: int, target_size: int): + d = target_size - curr_size + if d <= 0: return (0, 0) # no need to pad + p1 = d // 2 + p2 = d - p1 + return (p1, p2) + + def __len__(self): + return len(self.files) + + def __getitem__(self, idx): + if torch.is_tensor(idx): + idx = idx.tolist() + img_name = self.files[idx] + label = torch.tensor(self.labels[idx]) + if self.pad: + image = plt.imread(img_name) + H, W = image.shape[0], image.shape[1] + DH, DW = self.new_size, self.new_size # desired height / width + padding = (self.get_pad(H, DH), self.get_pad(W, DW), (0, 0)) + image = np.pad(image, padding, constant_values=128) # pad with gray + else: + image = Image.open(img_name).convert('RGB') + image = image.resize((self.new_size, self.new_size)) + image = np.array(image) + if self.transform: + image = self.transform(image) + return image, label + + +class ConvNet(nn.Module): + """From-scratch CNN: 4 conv blocks (3→64→128→256→256) + GAP + FC. + + Uses BatchNorm after every conv layer for stable small-dataset training. + Global Average Pooling makes the architecture input-size agnostic. + Dropout (p=0.5) is applied in the FC layers only. 
+ """ + def __init__(self, num_classes): + super().__init__() + + def _block(in_ch, out_ch): + return nn.Sequential( + nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1), + nn.BatchNorm2d(out_ch), + nn.ReLU(inplace=True), + nn.MaxPool2d(2, 2), + ) + + self.block1 = _block(3, 64) + self.block2 = _block(64, 128) + self.block3 = _block(128, 256) + self.block4 = _block(256, 256) + self.gap = nn.AdaptiveAvgPool2d(1) # (batch, 256, 1, 1) + self.fc1 = nn.Linear(256, 128) + self.fc2 = nn.Linear(128, num_classes) + + def forward(self, x): + x = self.block1(x) + x = self.block2(x) + x = self.block3(x) + x = self.block4(x) + x = self.gap(x).view(x.size(0), -1) # flatten: (batch, 256) + x = F.dropout(F.relu(self.fc1(x)), p=0.5, training=self.training) + x = self.fc2(x) + return x + + +def build_resnet18(num_classes, freeze_layers=('layer1',)): + """Pretrained ResNet18 with the classification head replaced. + + Only the very first residual block (layer1) is frozen — proteograms encode + distance-matrix geometry that looks nothing like ImageNet, so the backbone + needs freedom to adapt. Regularisation comes from AdamW weight decay + rather than aggressive layer freezing. A Dropout is inserted before the + final linear layer for additional regularisation. + """ + model = tv_models.resnet18(weights=tv_models.ResNet18_Weights.IMAGENET1K_V1) + for name, param in model.named_parameters(): + if any(name.startswith(layer) for layer in freeze_layers): + param.requires_grad = False + in_features = model.fc.in_features + model.fc = nn.Sequential( + nn.Dropout(0.5), + nn.Linear(in_features, num_classes), + ) + return model + + +def train_model(model, train_loader, val_loader, optimizer, epochs, + patience=None, device=torch.device('cpu')): + """Train the ConvNet, tracking train and val loss each epoch. + + If `patience` is set, applies early stopping: training halts when val loss + has not improved for that many consecutive epochs, and the best weights are + restored. 
If `patience` is None, all epochs run and no weight restoration + is performed. + """ + model.to(device) + training_loss = [] + val_loss_history = [] + + best_val_loss = float('inf') + best_weights = None + best_epoch = -1 + epochs_no_improve = 0 + + # LR scheduler uses its own patience (independent of early stopping patience) + scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5) + + loaders = {'train': train_loader, 'val': val_loader} + + for epoch in range(epochs): + epoch_losses = {} + for phase in ['train', 'val']: + model.train() if phase == 'train' else model.eval() + running_loss = 0.0 + n_batches = 0 + with torch.set_grad_enabled(phase == 'train'): + for data, target in loaders[phase]: + data, target = data.to(device), target.to(device) + optimizer.zero_grad() + output = model(data) + loss = loss_criteria(output, target) + if phase == 'train': + loss.backward() + optimizer.step() + running_loss += loss.item() + n_batches += 1 + epoch_losses[phase] = running_loss / n_batches + lr_before = optimizer.param_groups[0]['lr'] + scheduler.step(epoch_losses['val']) + lr_after = optimizer.param_groups[0]['lr'] + + + training_loss.append(epoch_losses['train']) + val_loss_history.append(epoch_losses['val']) + + improved = epoch_losses['val'] < best_val_loss + if improved: + best_val_loss = epoch_losses['val'] + if patience is not None: + best_weights = copy.deepcopy(model.state_dict()) + best_epoch = epoch + epochs_no_improve = 0 + else: + epochs_no_improve += 1 + + suffix = '' + if patience is not None: + suffix = ' | *' if improved else f' | (no improvement {epochs_no_improve}/{patience})' + if lr_after < lr_before: + suffix += f' | LR reduced: {lr_before:.2e} → {lr_after:.2e}' + print(f'Epoch {epoch:>4d}: train loss: {epoch_losses["train"]:.6f} ' + f'val loss: {epoch_losses["val"]:.6f}' + suffix) + + if patience is not None and epochs_no_improve >= patience: + print(f'Early stopping at epoch {epoch} — no val loss improvement for 
{patience} epochs.') + break + + if patience is not None and best_weights is not None: + print(f'Restoring best weights (val loss: {best_val_loss:.6f})') + model.load_state_dict(best_weights) + return model, training_loss, val_loss_history, best_epoch + 1 + +def split_train_test(full_dataset, generator): + """Split a PyTorch Dataset object into train and test sets""" + total_size = len(full_dataset) + train_size = int(total_size * 0.7) + test_size = total_size - train_size + train_dataset, test_dataset = random_split( + full_dataset, + [train_size, test_size], + generator=generator + ) + return train_dataset, test_dataset + +def get_accuracies(model, test_loader, class_names, labels_to_names, device=torch.device('cpu')): + """Accuracies per class, classification report, and AUC-ROC.""" + correct_pred = {classname: 0 for classname in class_names} + total_pred = {classname: 0 for classname in class_names} + total_correct = 0 + len_data = len(test_loader) + y_pred = [] + y_test = [] + y_scores = [] + model.eval() + model.to(device) + with torch.no_grad(): + for (data, targets) in test_loader: + data, targets = data.to(device), targets.to(device) + outputs = model(data) + probs = torch.softmax(outputs, dim=1) + _, predictions = torch.max(outputs, 1) + y_scores.append(probs.cpu().numpy()) + # collect the correct predictions for each class + for label, prediction in zip(targets, predictions): + if label == prediction: + total_correct += 1 + correct_pred[labels_to_names[int(label)]] += 1 + total_pred[labels_to_names[int(label)]] += 1 + y_pred.append(labels_to_names[int(prediction)]) + y_test.append(labels_to_names[int(label)]) + + # Print accuracy for each class + for classname, correct_count in correct_pred.items(): + try: + accuracy = 100 * float(correct_count) / total_pred[classname] + print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %') + except ZeroDivisionError: + print(f'No samples in test set for class: {classname}') + overall_accuracy_str = f"{100 * 
float(total_correct) / len_data:.1f}" + print(f'Overall accuracy: {overall_accuracy_str} %') + + print('\nAdditional Classification Report:') + print(classification_report(y_test, y_pred)) + + name_to_int = {v: k for k, v in labels_to_names.items()} + y_test_int = [name_to_int[n] for n in y_test] + y_scores_arr = np.vstack(y_scores) + auc_macro = roc_auc_score(y_test_int, y_scores_arr, multi_class='ovr', average='macro') + auc_weighted = roc_auc_score(y_test_int, y_scores_arr, multi_class='ovr', average='weighted') + print(f'\nAUC-ROC (macro): {auc_macro:.4f}') + print(f'AUC-ROC (weighted): {auc_weighted:.4f}') + + return overall_accuracy_str + + +def view_pred_set(model, test_loader, num_preds, labels_to_names, fig_path): + """Graph a set of predictions with labels and save plot.""" + predictions = [] + images = [] + labels = [] + with torch.no_grad(): + cnt = 1 + for (data, target) in test_loader: + images.append(data) + labels.append(target[0]) + outputs = model(data) + _, predicted = torch.max(outputs, 1) + predictions.append(predicted[0]) + + if cnt == num_preds: + break + cnt += 1 + + for i in range(num_preds): + plot_row = max(1, int(num_preds/2)) + plt.subplot(2, plot_row, i + 1) + img = images[i] + npimg = img.numpy() + npimg = np.squeeze(npimg, axis=0) + npimg = npimg / 2 + 0.5 + plt.imshow(np.transpose(npimg, (1, 2, 0))) + plt.axis('off') + + color = "green" + pred = int(predictions[i].numpy()) + label = int(labels[i].numpy()) + name = labels_to_names[label] + if label != pred: + color = "red" + plt.title(name, color=color) + + plt.suptitle('Objects Found by Model', size=20) + plt.savefig(fig_path) + +def load_model(model_path, classes, image_size): + """Load the ConvNet model from disk (image_size is unused; ConvNet is input-size agnostic).""" + model = ConvNet(classes) + model.load_state_dict(torch.load(model_path)) + return model + +def plot_losses(training_loss, val_loss, fig_path): + """Plot training and validation loss curves on the same axes and save to file.""" + epochs = range(1, 
len(training_loss) + 1) + plt.figure() + plt.plot(epochs, training_loss, label='Train loss') + plt.plot(epochs, val_loss, label='Val loss') + plt.xlabel('Epoch') + plt.ylabel('Loss') + plt.title('Training and Validation Loss') + plt.legend() + plt.tight_layout() + plt.savefig(fig_path, dpi=150) + plt.close() + print(f'Loss curve saved to {fig_path}') + +def seed_worker(worker_id): + worker_seed = torch.initial_seed() % 2**32 + np.random.seed(worker_seed) + random.seed(worker_seed) + + +class FocalLoss(nn.Module): + """Multi-class focal loss: FL = -(1 - p_t)^gamma * log(p_t). + + gamma=0 reduces to standard cross-entropy. gamma=2 is the value used in + the original RetinaNet paper and is a reasonable starting point. + label_smoothing is applied to the underlying cross-entropy term. + """ + def __init__(self, gamma=2.0, label_smoothing=0.0): + super().__init__() + self.gamma = gamma + self.label_smoothing = label_smoothing + + def forward(self, inputs, targets): + ce = F.cross_entropy(inputs, targets, label_smoothing=self.label_smoothing, reduction='none') + pt = torch.exp(-ce) + return ((1 - pt) ** self.gamma * ce).mean() + + +class TransformedSubset(Dataset): + """Wraps a Subset and applies a transform, allowing train/eval subsets to use different augmentations.""" + def __init__(self, subset, transform): + self.subset = subset + self.transform = transform + + def __len__(self): + return len(self.subset) + + def __getitem__(self, idx): + image, label = self.subset[idx] + if self.transform: + image = self.transform(image) + return image, label + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser( + description="Train CNN on Proteograms.") + parser.add_argument('--data_dir', '-d', + type=str, + default=None, + help="Root directory containing all proteogram images (flat layout, " + "no train/eval subdirectories). 
Overrides training_data_dir in config.yml.") + parser.add_argument("--epochs", "-e", + type=int, + help="Number of training epochs.") + parser.add_argument("--batch_size", "-b", + type=int, + help="Training batch size.") + parser.add_argument("--lr", "-l", + type=float, + help="Training learning rate.") + parser.add_argument('--model', '-m', + choices=['cnn', 'resnet18'], + default='cnn', + help="Model architecture: 'cnn' (from-scratch 4-block ConvNet) " + "or 'resnet18' (pretrained ResNet18 fine-tuning). Default: cnn.") + parser.add_argument('--overwrite', '-o', + action='store_true', + help="Recreate / overwrite model") + parser.add_argument('--resize', + action='store_true', + help="Resize images to new_size instead of padding. " + "Default is to pad with gray, which preserves the " + "1-pixel-per-residue-pair semantic of proteograms.") + parser.add_argument('--verbose', '-v', + action='store_true', + help="Verbose output and logging.") + parser.add_argument('--tsv_file', '-t', + type=str, + default=None, + help="Path to the TSV annotations file. Defaults to " + "ProteogramData_SCOP_RCSB_PDBe_AnnotationsLookup_AllSCOPe208.tsv " + "in the parent of the data directory.") + parser.add_argument('--patience', + type=int, + default=None, + help="Early stopping patience: stop after this many epochs " + "with no improvement in val loss. Omit to disable early stopping.") + parser.add_argument('--seed', + type=int, + default=0, + help="Random seed used for the train/val/test split and all stochastic " + "operations. 
Change to produce a different reproducible split (default: 0).") + parser.add_argument('--val_size', + type=float, + default=0.15, + help="Fraction of images to hold out as validation " + "set for early stopping (default: 0.15).") + parser.add_argument('--test_size', + type=float, + default=0.15, + help="Fraction of images to hold out as the final held-out test set (default: 0.15).") + parser.add_argument('--save_test_list', + action='store_true', + help="Flag to save the list of test set image filenames to a text file for later " "reference.") + parser.add_argument('--max_image_size', + type=int, + default=200, + help="Pad proteogram images to this square size (pixels) and exclude any " + "images larger than this value. Equivalent to filtering by max sequence " + "length. Default: 200.") + parser.add_argument('--exclude_classes', '-x', + type=str, + default=None, + help="Comma-separated list of SCOPe class names to exclude " + "from training and evaluation (e.g. 'j,h'). " + "Useful for removing very small or low-quality classes.") + parser.add_argument('--level', + choices=['class', 'fold', 'superfamily', 'family'], + default='class', + help="SCOPe hierarchy level to use as the classification target. " + "'class' is the highest (broadest) level; 'family' is the lowest " + "(finest). Default: class.") + parser.add_argument('--min_class_size', + type=int, + default=20, + help="Exclude classes with fewer than this many samples to avoid extreme " + "class imbalance. Default: 20.") + parser.add_argument('--loss', + choices=['ce', 'focal'], + default='ce', + help="Loss function: 'ce' (cross-entropy with label smoothing) or " + "'focal' (focal loss, good for hard/imbalanced examples). Default: ce.") + parser.add_argument('--focal_gamma', + type=float, + default=2.0, + help="Focusing parameter gamma for focal loss (ignored when --loss=ce). " + "gamma=0 reduces to cross-entropy; gamma=2 is the RetinaNet default. 
" + "Default: 2.0.") + args = parser.parse_args() + + torch.manual_seed(args.seed) + random.seed(args.seed) + np.random.seed(args.seed) + + config = read_yaml('config.yml') + root_dir = args.data_dir or config['training_data_dir'] + # Get level from command line or config, with command line taking precedence. Default to 'class' if neither is provided. + level = args.level or config.get('scope_level', 'class') + + if args.epochs: + epochs = args.epochs + elif 'num_epochs' in config: + epochs = config['num_epochs'] + else: + raise ValueError("Number of epochs must be specified via command line or config.yml") + if args.lr: + lr = args.lr + elif 'learning_rate' in config: + lr = config['learning_rate'] + else: + raise ValueError("Learning rate must be specified via command line or config.yml") + if args.batch_size: + batch_size = args.batch_size + elif 'batch_size' in config: + batch_size = config['batch_size'] + else: + raise ValueError("Batch size must be specified via command line or config.yml") + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + image_resize = args.max_image_size + + # ResNet18 was trained with ImageNet normalisation (standardize input images to the same distribution as the data the model was pre-trained on) + # so the pretrained feature detectors remain valid. The ConvNet was not pretrained, so it doesn't strictly require ImageNet normalisation, but applying the same normalisation to both models allows for a more controlled comparison. 
+ _imagenet_norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + _augment = [ + transforms.RandomApply([transforms.ColorJitter(brightness=0.1, contrast=0.1)], p=0.1), + transforms.RandomApply([transforms.RandomAdjustSharpness(sharpness_factor=2)], p=0.1), + ] + + if args.model == 'resnet18': + transform_train = transforms.Compose([transforms.ToTensor()] + _augment + [_imagenet_norm]) + transform_eval = transforms.Compose([transforms.ToTensor(), _imagenet_norm]) + else: + transform_train = transforms.Compose([transforms.ToTensor()] + _augment + [_imagenet_norm]) + transform_eval = transforms.Compose([transforms.ToTensor(), _imagenet_norm]) + + tsv_file = args.tsv_file or os.path.join( + root_dir, '..', 'ProteogramData_SCOP_RCSB_PDBe_AnnotationsLookup_AllSCOPe208.tsv') + + exclude_classes = [c.strip() for c in args.exclude_classes.split(',')] \ + if args.exclude_classes else None + + # Load the full dataset without transforms so each split can use its own augmentation policy. 
+ # https://docs.pytorch.org/docs/stable/notes/randomness.html + g = torch.Generator() + g.manual_seed(args.seed) + + full_dataset = ProteogramDataset( + tsv_file=tsv_file, + root_dir=root_dir, + level=args.level, + new_size=image_resize, + pad=not args.resize, + transform=None, + # exclude classes with fewer than this many samples to avoid extreme class imbalance + min_class_size=args.min_class_size, + min_image_size=20, + max_image_size=args.max_image_size, + exclude_classes=exclude_classes) + + all_indices = list(range(len(full_dataset))) + labels = full_dataset.labels + + train_val_indices, test_indices = train_test_split( + all_indices, + test_size=args.test_size, + random_state=args.seed, + stratify=[labels[i] for i in all_indices]) + + # val fraction relative to the train+val pool + val_fraction = args.val_size / (1.0 - args.test_size) + train_indices, val_indices = train_test_split( + train_val_indices, + test_size=val_fraction, + random_state=args.seed, + stratify=[labels[i] for i in train_val_indices]) + + train_split = TransformedSubset(Subset(full_dataset, train_indices), transform_train) + val_split = TransformedSubset(Subset(full_dataset, val_indices), transform_eval) + test_split = TransformedSubset(Subset(full_dataset, test_indices), transform_eval) + print(f'Stratified split (seed={args.seed}): ' + f'{len(train_indices)} train / {len(val_indices)} val / {len(test_indices)} test (held-out)') + + # WeightedRandomSampler: oversample minority classes so each epoch sees + # a balanced class distribution regardless of raw class frequencies. 
+ train_labels = [full_dataset.labels[i] for i in train_indices] + label_counts = torch.bincount(torch.tensor(train_labels)) + class_weights = torch.where( + label_counts > 0, 1.0 / label_counts.float(), torch.zeros_like(label_counts.float())) + sample_weights = [class_weights[lbl].item() for lbl in train_labels] + sampler = torch.utils.data.WeightedRandomSampler( + weights=sample_weights, num_samples=len(sample_weights), replacement=True) + + class_names = full_dataset.label_names_unique + + train_loader = DataLoader(train_split, + batch_size=batch_size, + # Note, shuffle is mutually exclusive with sampler + shuffle=True, + # sampler=sampler, + worker_init_fn=seed_worker) + val_loader = DataLoader(val_split, + batch_size=batch_size, + shuffle=False, + worker_init_fn=seed_worker, + generator=g) + test_loader = DataLoader(test_split, + batch_size=1, + shuffle=False, + worker_init_fn=seed_worker, + generator=g) + + num_classes = len(class_names) + print(f'Number of classes: {num_classes}') + + if args.model == 'resnet18': + model = build_resnet18(num_classes) + # Differential LR: lower rate for pretrained backbone, full rate for new head + backbone_params = [p for n, p in model.named_parameters() if 'fc' not in n and p.requires_grad] + head_params = list(model.fc.parameters()) + optimizer = optim.AdamW([ + {'params': backbone_params, 'lr': lr * 0.1}, + {'params': head_params, 'lr': lr}, + ], weight_decay=1e-3) + print(f'ResNet18: backbone LR={lr * 0.1:.2e}, head LR={lr:.2e}') + else: + model = ConvNet(num_classes) + optimizer = optim.Adam(model.parameters(), lr=lr) + print(f'ConvNet (from scratch): LR={lr:.2e}') + + if args.loss == 'focal': + loss_criteria = FocalLoss(gamma=args.focal_gamma, label_smoothing=0.1) + print(f'Loss: FocalLoss(gamma={args.focal_gamma}, label_smoothing=0.1)') + else: + loss_criteria = nn.CrossEntropyLoss(label_smoothing=0.1) + print('Loss: CrossEntropyLoss(label_smoothing=0.1)') + model, 
training_loss, val_loss, epochs_trained = train_model(model, + train_loader=train_loader, + val_loader=val_loader, + optimizer=optimizer, + epochs=epochs, + patience=args.patience, + device=device) + + output_dir = os.path.dirname(os.path.abspath(root_dir)) + + overall_accuracy = get_accuracies(model, + test_loader, + class_names, + full_dataset.labels_to_names) + + loss_tag = f'focal_g{args.focal_gamma}' if args.loss == 'focal' else 'ce' + suffix = f'_{args.model}_lr{lr}_bs{batch_size}_e{epochs_trained}_seed{args.seed}_max_image_size{args.max_image_size}_min_class_size{args.min_class_size}_level-{level}_loss{loss_tag}_acc{overall_accuracy}' + + plot_losses(training_loss, val_loss, + fig_path=os.path.join(output_dir, f'loss_curves{suffix}.png')) + + model_file = config.get('model_file_prefix', 'scope_proteogram_model') + suffix + '.pt' + model_path = os.path.join(output_dir, model_file) + + # Save model + if os.path.exists(model_path) and not args.overwrite: + print(f'Model file {model_path} exists and overwrite not set, not saving model.') + else: + # Save only the model weights + torch.save(model.state_dict(), model_path) + print(f'Saved model to {model_path}') + + if args.save_test_list: + save_list_name = f"test_list_for_model_{suffix}.lst" + save_list_path = os.path.join(output_dir, save_list_name) + os.makedirs(os.path.dirname(os.path.abspath(save_list_path)), exist_ok=True) + with open(save_list_path, 'w') as f: + for i in test_indices: + f.write(os.path.splitext(os.path.basename(full_dataset.files[i]))[0] + '\n') + print(f'Test set file prefixes written to {save_list_path}') + + view_pred_set(model, + test_loader, + num_preds=10, + labels_to_names=full_dataset.labels_to_names, + fig_path=os.path.join(output_dir, f'sample_preds{suffix}.png')) + diff --git a/uv.lock b/uv.lock index 1aa5ba3..39ac717 100644 --- a/uv.lock +++ b/uv.lock @@ -746,6 +746,15 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/07/6c/aa3f2f849e01cb6a001cd8554a88d4c77c5c1a31c95bdf1cf9301e6d9ef4/defusedxml-0.7.1-py2.py3-none-any.whl", hash = "sha256:a352e7e428770286cc899e2542b6cdaedb2b4953ff269a210103ec58f6198a61", size = 25604, upload-time = "2021-03-08T10:59:24.45Z" }, ] +[[package]] +name = "et-xmlfile" +version = "2.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d3/38/af70d7ab1ae9d4da450eeec1fa3918940a5fafb9055e934af8d6eb0c2313/et_xmlfile-2.0.0.tar.gz", hash = "sha256:dab3f4764309081ce75662649be815c4c9081e88f0837825f90fd28317d4da54", size = 17234, upload-time = "2024-10-25T17:25:40.039Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c1/8b/5fe2cc11fee489817272089c4203e679c63b570a5aaeb18d852ae3cbba6a/et_xmlfile-2.0.0-py3-none-any.whl", hash = "sha256:7a91720bc756843502c3b7504c77b8fe44217c85c537d85037f0f536151b2caa", size = 18059, upload-time = "2024-10-25T17:25:39.051Z" }, +] + [[package]] name = "executing" version = "2.2.1" @@ -840,6 +849,49 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e6/ab/fb21f4c939bb440104cc2b396d3be1d9b7a9fd3c6c2a53d98c45b3d7c954/fsspec-2026.2.0-py3-none-any.whl", hash = "sha256:98de475b5cb3bd66bedd5c4679e87b4fdfe1a3bf4d707b151b3c07e58c9a2437", size = 202505, upload-time = "2026-02-05T21:50:51.819Z" }, ] +[[package]] +name = "ftpretty" +version = "0.4.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "python-dateutil" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a2/c5/f5ed409312f67697bf7967cd52ed74794ccb860fec29d2105470c987de09/ftpretty-0.4.0.tar.gz", hash = "sha256:61233b9212f2cceec96ee2c972738fa31cae7248e92d0874c99c04ee739bb5a9", size = 9068, upload-time = "2021-06-12T17:45:20.038Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9b/d1/cbd3736da8d6253da85838b105b97bc794965584243711fa2ef0bb585df3/ftpretty-0.4.0-py2.py3-none-any.whl", hash = 
"sha256:1c0c45bacf69b82827838ae9b77a66e48064d2686649628e647965a85aa7367a", size = 8180, upload-time = "2021-06-12T17:45:18.349Z" }, +] + +[[package]] +name = "goatools" +version = "1.6.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "ftpretty" }, + { name = "numpy" }, + { name = "openpyxl" }, + { name = "pandas" }, + { name = "pydot" }, + { name = "requests" }, + { name = "rich" }, + { name = "scipy" }, + { name = "setuptools" }, + { name = "statsmodels" }, + { name = "xlsxwriter" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e6/87/07e80f31af72d8e2c35936d1e5670ac05ff6009cc4a9412eb99e18efd24a/goatools-1.6.5.tar.gz", hash = "sha256:0d799706dc3ae4480feda25f411f8e9b2741c0d8ea7ad73af0b05730198a4be1", size = 17760712, upload-time = "2026-05-12T06:08:38.592Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/20/53/1020a7b32651289a68ec853d43fb412c8982d6d1de2289e6cb6f55bc4b7b/goatools-1.6.5-py3-none-any.whl", hash = "sha256:2b093ded287178b494e28350bce6beb60b920ff5f923fc5aaef227dd6ffb935a", size = 15768926, upload-time = "2026-05-12T06:08:35.183Z" }, +] + +[[package]] +name = "graphql-core" +version = "3.2.11" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/4d/90/f2aff026ab4aebd80eb71905106a0885f4cfde85dcf965543f45bed0d9ee/graphql_core-3.2.11.tar.gz", hash = "sha256:e7e156d10beb127cab5c89ff0da71416fc73d27c484a4757d3b2d35633774802", size = 528407, upload-time = "2026-06-05T13:45:22.915Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/00/15/b92b4e1d88d02c6eff9733c9eea21846ab435cc4d813d84ccc5d335955df/graphql_core-3.2.11-py3-none-any.whl", hash = "sha256:0b3e35ff41e9adba53021ab0cef475eb18f57c7f53f0f2ca55567fbf3c537ea0", size = 214879, upload-time = "2026-06-05T13:45:21.245Z" }, +] + [[package]] name = "griddataformats" version = "1.1.0" @@ -1466,6 +1518,18 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/db/bc/83e112abc66cd466c6b83f99118035867cecd41802f8d044638aa78a106e/locket-1.0.0-py2.py3-none-any.whl", hash = "sha256:b6c819a722f7b6bd955b80781788e4a66a55628b858d347536b7e81325a3a5e3", size = 4398, upload-time = "2022-04-20T22:04:42.23Z" }, ] +[[package]] +name = "markdown-it-py" +version = "4.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mdurl" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/06/ff/7841249c247aa650a76b9ee4bbaeae59370dc8bfd2f6c01f3630c35eb134/markdown_it_py-4.2.0.tar.gz", hash = "sha256:04a21681d6fbb623de53f6f364d352309d4094dd4194040a10fd51833e418d49", size = 82454, upload-time = "2026-05-07T12:08:28.36Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b3/81/4da04ced5a082363ecfa159c010d200ecbd959ae410c10c0264a38cac0f5/markdown_it_py-4.2.0-py3-none-any.whl", hash = "sha256:9f7ebbcd14fe59494226453aed97c1070d83f8d24b6fc3a3bcf9a38092641c4a", size = 91687, upload-time = "2026-05-07T12:08:27.182Z" }, +] + [[package]] name = "markupsafe" version = "3.0.3" @@ -1700,6 +1764,15 @@ parallel = [ { name = "dask" }, ] +[[package]] +name = "mdurl" +version = "0.1.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d6/54/cfe61301667036ec958cb99bd3efefba235e65cdeb9c84d24a8293ba1d90/mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba", size = 8729, upload-time = "2022-08-14T12:40:10.846Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979, upload-time = "2022-08-14T12:40:09.779Z" }, +] + [[package]] name = "mistune" version = "3.2.0" @@ -2364,6 +2437,18 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/e5/7b/03941691df425c4405a6729cba4db7168de55e11d08031d9f594aba052e8/openmm_cuda_12-8.4.0.post2-py3-none-win_amd64.whl", hash = "sha256:b9cafc353990f36700d5b6d3b523610590470f37dce06b0c23f4c9959de96df1", size = 1941053, upload-time = "2025-11-24T21:40:20.991Z" }, ] +[[package]] +name = "openpyxl" +version = "3.1.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "et-xmlfile" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3d/f9/88d94a75de065ea32619465d2f77b29a0469500e99012523b91cc4141cd1/openpyxl-3.1.5.tar.gz", hash = "sha256:cf0e3cf56142039133628b5acffe8ef0c12bc902d2aadd3e0fe5878dc08d1050", size = 186464, upload-time = "2024-06-28T14:03:44.161Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c0/da/977ded879c29cbd04de313843e76868e6e13408a94ed6b987245dc7c8506/openpyxl-3.1.5-py2.py3-none-any.whl", hash = "sha256:5282c12b107bffeef825f4617dc029afaf41d0ea60823bbb665ef3079dc79de2", size = 250910, upload-time = "2024-06-28T14:03:41.161Z" }, +] + [[package]] name = "overrides" version = "7.7.0" @@ -2494,6 +2579,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e8/a4/7a75d96559b9f093be50e2b5fdc5617e478d163076b9165295323ab0095c/pathsimanalysis-1.2.0-py3-none-any.whl", hash = "sha256:0edce2fd1a55a788ef80b3666f9704455c01f06719f516fec20215fb9e16e291", size = 48816, upload-time = "2024-11-23T00:19:28.878Z" }, ] +[[package]] +name = "patsy" +version = "1.0.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/be/44/ed13eccdd0519eff265f44b670d46fbb0ec813e2274932dc1c0e48520f7d/patsy-1.0.2.tar.gz", hash = "sha256:cdc995455f6233e90e22de72c37fcadb344e7586fb83f06696f54d92f8ce74c0", size = 399942, upload-time = "2025-10-20T16:17:37.535Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/f1/70/ba4b949bdc0490ab78d545459acd7702b211dfccf7eb89bbc1060f52818d/patsy-1.0.2-py2.py3-none-any.whl", hash = "sha256:37bfddbc58fcf0362febb5f54f10743f8b21dd2aa73dec7e7ef59d1b02ae668a", size = 233301, upload-time = "2025-10-20T16:17:36.563Z" }, +] + [[package]] name = "pdbfixer" version = "1.12.0" @@ -2644,10 +2741,11 @@ wheels = [ [[package]] name = "proteogram" -version = "0.0.3" +version = "0.0.4" source = { editable = "." } dependencies = [ { name = "biopython" }, + { name = "goatools" }, { name = "kmeans-pytorch" }, { name = "matplotlib" }, { name = "mdanalysis", extra = ["analysis", "extra-formats", "parallel"] }, @@ -2660,6 +2758,7 @@ dependencies = [ { name = "psutil" }, { name = "pyrotein" }, { name = "pyyaml" }, + { name = "rcsb-api" }, { name = "torch" }, { name = "torchsummary" }, { name = "torchvision" }, @@ -2668,14 +2767,12 @@ dependencies = [ [package.optional-dependencies] cuda12 = [ - { name = "nvidia-cuda-nvcc-cu12" }, { name = "openmm" }, { name = "openmm-cuda-12" }, ] notebook = [ { name = "jupyterlab" }, { name = "nglview" }, - { name = "rcsbsearchapi" }, ] test = [ { name = "pytest" }, @@ -2685,13 +2782,13 @@ test = [ [package.metadata] requires-dist = [ { name = "biopython", specifier = ">=1.8" }, + { name = "goatools", specifier = ">=1.4" }, { name = "jupyterlab", marker = "extra == 'notebook'", specifier = ">=4.2.5" }, { name = "kmeans-pytorch", specifier = ">=0.3" }, { name = "matplotlib", specifier = ">=3.6" }, { name = "mdanalysis", extras = ["analysis", "extra-formats", "parallel"], specifier = ">=2.10.0" }, { name = "nglview", marker = "extra == 'notebook'", specifier = ">=3.1.4" }, { name = "numpy", specifier = ">=1.26" }, - { name = "nvidia-cuda-nvcc-cu12", marker = "extra == 'cuda12'", specifier = "==12.9.86" }, { name = "objgraph", specifier = ">=3.6.2" }, { name = "openmm", specifier = ">=8.4" }, { name = "openmm", marker = "extra == 'cuda12'", specifier = "==8.4.0" }, @@ -2704,10 +2801,10 
@@ requires-dist = [ { name = "pytest", marker = "extra == 'test'" }, { name = "pytest-cov", marker = "extra == 'test'" }, { name = "pyyaml", specifier = ">=6.0" }, - { name = "rcsbsearchapi", marker = "extra == 'notebook'", specifier = ">=1.5.1" }, - { name = "torch", specifier = ">=2.2" }, + { name = "rcsb-api", specifier = ">=1.7.3" }, + { name = "torch", specifier = ">=2.2,<2.11" }, { name = "torchsummary", specifier = ">=1.5" }, - { name = "torchvision", specifier = ">=0.17" }, + { name = "torchvision", specifier = ">=0.17,<0.27" }, { name = "tqdm", specifier = ">=4.67" }, ] provides-extras = ["cuda12", "test", "notebook"] @@ -2767,6 +2864,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0c/c3/44f3fbbfa403ea2a7c779186dc20772604442dde72947e7d01069cbe98e3/pycparser-3.0-py3-none-any.whl", hash = "sha256:b727414169a36b7d524c1c3e31839a521725078d7b2ff038656844266160a992", size = 48172, upload-time = "2026-01-21T14:26:50.693Z" }, ] +[[package]] +name = "pydot" +version = "4.0.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyparsing" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/50/35/b17cb89ff865484c6a20ef46bf9d95a5f07328292578de0b295f4a6beec2/pydot-4.0.1.tar.gz", hash = "sha256:c2148f681c4a33e08bf0e26a9e5f8e4099a82e0e2a068098f32ce86577364ad5", size = 162594, upload-time = "2025-06-17T20:09:56.454Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7e/32/a7125fb28c4261a627f999d5fb4afff25b523800faed2c30979949d6facd/pydot-4.0.1-py3-none-any.whl", hash = "sha256:869c0efadd2708c0be1f916eb669f3d664ca684bc57ffb7ecc08e70d5e93fee6", size = 37087, upload-time = "2025-06-17T20:09:55.25Z" }, +] + [[package]] name = "pyedr" version = "0.8.0" @@ -3016,14 +3125,20 @@ wheels = [ ] [[package]] -name = "rcsbsearchapi" -version = "2.0.1" +name = "rcsb-api" +version = "1.7.3" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "requests" }, + { name = "graphql-core" }, + { 
name = "httpx" }, + { name = "nest-asyncio" }, + { name = "rustworkx" }, { name = "tqdm" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/e1/64/e5b4009eac83ce59eb60a1e6774a3cfe8b711355c6d48b42a6d326716367/rcsbsearchapi-2.0.1.tar.gz", hash = "sha256:50dac1e60f58cbaae93af304ceff1b3ae18611c27477c802a066eb2c79ff32cf", size = 182171, upload-time = "2025-03-26T15:20:04.735Z" } +sdist = { url = "https://files.pythonhosted.org/packages/43/7c/30686058d4fac954f7123e1417a991a1b60a4e36c6c3b268c759f141eb6e/rcsb_api-1.7.3.tar.gz", hash = "sha256:b9b2c147c098a4c482e34a2c5998b8680d3386c3550dce57fe9d57359eb59d0b", size = 601711, upload-time = "2026-05-05T14:38:53.751Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/75/8e/76484df93fa1dab7ce9ec32231b04bcd4b488a5932b56838059f2c692ee4/rcsb_api-1.7.3-py3-none-any.whl", hash = "sha256:674996362c648178bfbe2c4752ae72b37edccb1348a006174716fbcf78620103", size = 70861, upload-time = "2026-05-05T14:38:52.35Z" }, +] [[package]] name = "rdkit" @@ -3114,6 +3229,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7e/71/44ce230e1b7fadd372515a97e32a83011f906ddded8d03e3c6aafbdedbb7/rfc3987_syntax-1.1.0-py3-none-any.whl", hash = "sha256:6c3d97604e4c5ce9f714898e05401a0445a641cfa276432b0a648c80856f6a3f", size = 8046, upload-time = "2025-07-18T01:05:03.843Z" }, ] +[[package]] +name = "rich" +version = "15.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markdown-it-py" }, + { name = "pygments" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c0/8f/0722ca900cc807c13a6a0c696dacf35430f72e0ec571c4275d2371fca3e9/rich-15.0.0.tar.gz", hash = "sha256:edd07a4824c6b40189fb7ac9bc4c52536e9780fbbfbddf6f1e2502c31b068c36", size = 230680, upload-time = "2026-04-12T08:24:00.75Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/82/3b/64d4899d73f91ba49a8c18a8ff3f0ea8f1c1d75481760df8c68ef5235bf5/rich-15.0.0-py3-none-any.whl", hash = 
"sha256:33bd4ef74232fb73fe9279a257718407f169c09b78a87ad3d296f548e27de0bb", size = 310654, upload-time = "2026-04-12T08:24:02.83Z" }, +] + [[package]] name = "rpds-py" version = "0.30.0" @@ -3222,6 +3350,28 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d1/b7/b95708304cd49b7b6f82fdd039f1748b66ec2b21d6a45180910802f1abf1/rpds_py-0.30.0-pp311-pypy311_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:ac37f9f516c51e5753f27dfdef11a88330f04de2d564be3991384b2f3535d02e", size = 562191, upload-time = "2025-11-30T20:24:36.853Z" }, ] +[[package]] +name = "rustworkx" +version = "0.17.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e7/b0/66d96f02120f79eeed86b5c5be04029b6821155f31ed4907a4e9f1460671/rustworkx-0.17.1.tar.gz", hash = "sha256:59ea01b4e603daffa4e8827316c1641eef18ae9032f0b1b14aa0181687e3108e", size = 399407, upload-time = "2025-09-15T16:29:46.429Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/20/24/8972ed631fa05fdec05a7bb7f1fc0f8e78ee761ab37e8a93d1ed396ba060/rustworkx-0.17.1-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:c08fb8db041db052da404839b064ebfb47dcce04ba9a3e2eb79d0c65ab011da4", size = 2257491, upload-time = "2025-08-13T01:43:31.466Z" }, + { url = "https://files.pythonhosted.org/packages/23/ae/7b6bbae5e0487ee42072dc6a46edf5db9731a0701ed648db22121fb7490c/rustworkx-0.17.1-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:4ef8e327dadf6500edd76fedb83f6d888b9266c58bcdbffd5a40c33835c9dd26", size = 2040175, upload-time = "2025-08-13T01:43:33.762Z" }, + { url = "https://files.pythonhosted.org/packages/cd/ea/c17fb9428c8f0dcc605596f9561627a5b9ef629d356204ee5088cfcf52c6/rustworkx-0.17.1-cp39-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5b809e0aa2927c68574b196f993233e269980918101b0dd235289c4f3ddb2115", size = 2324771, upload-time = "2025-08-13T01:43:35.553Z" }, + { url = 
"https://files.pythonhosted.org/packages/d7/40/ec8b3b8b0f8c0b768690c454b8dcc2781b4f2c767f9f1215539c7909e35b/rustworkx-0.17.1-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c7e82c46a92fb0fd478b7372e15ca524c287485fdecaed37b8bb68f4df2720f2", size = 2068584, upload-time = "2025-08-13T01:43:37.261Z" }, + { url = "https://files.pythonhosted.org/packages/d9/22/713b900d320d06ce8677e71bba0ec5df0037f1d83270bff5db3b271c10d7/rustworkx-0.17.1-cp39-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:42170075d8a7319e89ff63062c2f1d1116ced37b6f044f3bf36d10b60a107aa4", size = 2380949, upload-time = "2025-08-13T01:52:17.435Z" }, + { url = "https://files.pythonhosted.org/packages/20/4b/54be84b3b41a19caf0718a2b6bb280dde98c8626c809c969f16aad17458f/rustworkx-0.17.1-cp39-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:65cba97fa95470239e2d65eb4db1613f78e4396af9f790ff771b0e5476bfd887", size = 2562069, upload-time = "2025-08-13T02:09:27.222Z" }, + { url = "https://files.pythonhosted.org/packages/39/5b/281bb21d091ab4e36cf377088366d55d0875fa2347b3189c580ec62b44c7/rustworkx-0.17.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:246cc252053f89e36209535b9c58755960197e6ae08d48d3973760141c62ac95", size = 2221186, upload-time = "2025-08-13T01:43:38.598Z" }, + { url = "https://files.pythonhosted.org/packages/cc/2d/30a941a21b81e9db50c4c3ef8a64c5ee1c8eea3a90506ca0326ce39d021f/rustworkx-0.17.1-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:c10d25e9f0e87d6a273d1ea390b636b4fb3fede2094bf0cb3fe565d696a91b48", size = 2123510, upload-time = "2025-08-13T01:43:40.288Z" }, + { url = "https://files.pythonhosted.org/packages/4f/ef/c9199e4b6336ee5a9f1979c11b5779c5cf9ab6f8386e0b9a96c8ffba7009/rustworkx-0.17.1-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:48784a673cf8d04f3cd246fa6b53fd1ccc4d83304503463bd561c153517bccc1", size = 2302783, upload-time = "2025-08-13T01:43:42.073Z" }, + { url = 
"https://files.pythonhosted.org/packages/30/3d/a49ab633e99fca4ccbb9c9f4bd41904186c175ebc25c530435529f71c480/rustworkx-0.17.1-cp39-abi3-win32.whl", hash = "sha256:5dbc567833ff0a8ad4580a4fe4bde92c186d36b4c45fca755fb1792e4fafe9b5", size = 1931541, upload-time = "2025-08-13T01:43:43.415Z" }, + { url = "https://files.pythonhosted.org/packages/a9/ec/cee878c1879b91ab8dc7d564535d011307839a2fea79d2a650413edf53be/rustworkx-0.17.1-cp39-abi3-win_amd64.whl", hash = "sha256:d0a48fb62adabd549f9f02927c3a159b51bf654c7388a12fc16d45452d5703ea", size = 2055049, upload-time = "2025-08-13T01:43:44.926Z" }, +] + [[package]] name = "scikit-learn" version = "1.8.0" @@ -3407,6 +3557,45 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f1/7b/ce1eafaf1a76852e2ec9b22edecf1daa58175c090266e9f6c64afcd81d91/stack_data-0.6.3-py3-none-any.whl", hash = "sha256:d5558e0c25a4cb0853cddad3d77da9891a08cb85dd9f9f91b9f8cd66e511e695", size = 24521, upload-time = "2023-09-30T13:58:03.53Z" }, ] +[[package]] +name = "statsmodels" +version = "0.14.6" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, + { name = "packaging" }, + { name = "pandas" }, + { name = "patsy" }, + { name = "scipy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0d/81/e8d74b34f85285f7335d30c5e3c2d7c0346997af9f3debf9a0a9a63de184/statsmodels-0.14.6.tar.gz", hash = "sha256:4d17873d3e607d398b85126cd4ed7aad89e4e9d89fc744cdab1af3189a996c2a", size = 20689085, upload-time = "2025-12-05T23:08:39.522Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a9/4d/df4dd089b406accfc3bb5ee53ba29bb3bdf5ae61643f86f8f604baa57656/statsmodels-0.14.6-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6ad5c2810fc6c684254a7792bf1cbaf1606cdee2a253f8bd259c43135d87cfb4", size = 10121514, upload-time = "2025-12-05T19:28:16.521Z" }, + { url = 
"https://files.pythonhosted.org/packages/82/af/ec48daa7f861f993b91a0dcc791d66e1cf56510a235c5cbd2ab991a31d5c/statsmodels-0.14.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:341fa68a7403e10a95c7b6e41134b0da3a7b835ecff1eb266294408535a06eb6", size = 10003346, upload-time = "2025-12-05T19:28:29.568Z" }, + { url = "https://files.pythonhosted.org/packages/a9/2c/c8f7aa24cd729970728f3f98822fb45149adc216f445a9301e441f7ac760/statsmodels-0.14.6-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:bdf1dfe2a3ca56f5529118baf33a13efed2783c528f4a36409b46bbd2d9d48eb", size = 10129872, upload-time = "2025-12-05T23:09:25.724Z" }, + { url = "https://files.pythonhosted.org/packages/40/c6/9ae8e9b0721e9b6eb5f340c3a0ce8cd7cce4f66e03dd81f80d60f111987f/statsmodels-0.14.6-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a3764ba8195c9baf0925a96da0743ff218067a269f01d155ca3558deed2658ca", size = 10381964, upload-time = "2025-12-05T23:09:41.326Z" }, + { url = "https://files.pythonhosted.org/packages/28/8c/cf3d30c8c2da78e2ad1f50ade8b7fabec3ff4cdfc56fbc02e097c4577f90/statsmodels-0.14.6-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:9e8d2e519852adb1b420e018f5ac6e6684b2b877478adf7fda2cfdb58f5acb5d", size = 10409611, upload-time = "2025-12-05T23:09:57.131Z" }, + { url = "https://files.pythonhosted.org/packages/bf/cc/018f14ecb58c6cb89de9d52695740b7d1f5a982aa9ea312483ea3c3d5f77/statsmodels-0.14.6-cp311-cp311-win_amd64.whl", hash = "sha256:2738a00fca51196f5a7d44b06970ace6b8b30289839e4808d656f8a98e35faa7", size = 9580385, upload-time = "2025-12-05T19:28:42.778Z" }, + { url = "https://files.pythonhosted.org/packages/25/ce/308e5e5da57515dd7cab3ec37ea2d5b8ff50bef1fcc8e6d31456f9fae08e/statsmodels-0.14.6-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:fe76140ae7adc5ff0e60a3f0d56f4fffef484efa803c3efebf2fcd734d72ecb5", size = 10091932, upload-time = "2025-12-05T19:28:55.446Z" }, + { url = 
"https://files.pythonhosted.org/packages/05/30/affbabf3c27fb501ec7b5808230c619d4d1a4525c07301074eb4bda92fa9/statsmodels-0.14.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:26d4f0ed3b31f3c86f83a92f5c1f5cbe63fc992cd8915daf28ca49be14463a1c", size = 9997345, upload-time = "2025-12-05T19:29:10.278Z" }, + { url = "https://files.pythonhosted.org/packages/48/f5/3a73b51e6450c31652c53a8e12e24eac64e3824be816c0c2316e7dbdcb7d/statsmodels-0.14.6-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d8c00a42863e4f4733ac9d078bbfad816249c01451740e6f5053ecc7db6d6368", size = 10058649, upload-time = "2025-12-05T23:10:12.775Z" }, + { url = "https://files.pythonhosted.org/packages/81/68/dddd76117df2ef14c943c6bbb6618be5c9401280046f4ddfc9fb4596a1b8/statsmodels-0.14.6-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:19b58cf7474aa9e7e3b0771a66537148b2df9b5884fbf156096c0e6c1ff0469d", size = 10339446, upload-time = "2025-12-05T23:10:28.503Z" }, + { url = "https://files.pythonhosted.org/packages/56/4a/dce451c74c4050535fac1ec0c14b80706d8fc134c9da22db3c8a0ec62c33/statsmodels-0.14.6-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:81e7dcc5e9587f2567e52deaff5220b175bf2f648951549eae5fc9383b62bc37", size = 10368705, upload-time = "2025-12-05T23:10:44.339Z" }, + { url = "https://files.pythonhosted.org/packages/60/15/3daba2df40be8b8a9a027d7f54c8dedf24f0d81b96e54b52293f5f7e3418/statsmodels-0.14.6-cp312-cp312-win_amd64.whl", hash = "sha256:b5eb07acd115aa6208b4058211138393a7e6c2cf12b6f213ede10f658f6a714f", size = 9543991, upload-time = "2025-12-05T23:10:58.536Z" }, + { url = "https://files.pythonhosted.org/packages/81/59/a5aad5b0cc266f5be013db8cde563ac5d2a025e7efc0c328d83b50c72992/statsmodels-0.14.6-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:47ee7af083623d2091954fa71c7549b8443168f41b7c5dce66510274c50fd73e", size = 10072009, upload-time = "2025-12-05T23:11:14.021Z" }, + { url = 
"https://files.pythonhosted.org/packages/53/dd/d8cfa7922fc6dc3c56fa6c59b348ea7de829a94cd73208c6f8202dd33f17/statsmodels-0.14.6-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:aa60d82e29fcd0a736e86feb63a11d2380322d77a9369a54be8b0965a3985f71", size = 9980018, upload-time = "2025-12-05T23:11:30.907Z" }, + { url = "https://files.pythonhosted.org/packages/ee/77/0ec96803eba444efd75dba32f2ef88765ae3e8f567d276805391ec2c98c6/statsmodels-0.14.6-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:89ee7d595f5939cc20bf946faedcb5137d975f03ae080f300ebb4398f16a5bd4", size = 10060269, upload-time = "2025-12-05T23:11:46.338Z" }, + { url = "https://files.pythonhosted.org/packages/10/b9/fd41f1f6af13a1a1212a06bb377b17762feaa6d656947bf666f76300fc05/statsmodels-0.14.6-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:730f3297b26749b216a06e4327fe0be59b8d05f7d594fb6caff4287b69654589", size = 10324155, upload-time = "2025-12-05T23:12:01.805Z" }, + { url = "https://files.pythonhosted.org/packages/ee/0f/a6900e220abd2c69cd0a07e3ad26c71984be6061415a60e0f17b152ecf08/statsmodels-0.14.6-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:f1c08befa85e93acc992b72a390ddb7bd876190f1360e61d10cf43833463bc9c", size = 10349765, upload-time = "2025-12-05T23:12:18.018Z" }, + { url = "https://files.pythonhosted.org/packages/98/08/b79f0c614f38e566eebbdcff90c0bcacf3c6ba7a5bbb12183c09c29ca400/statsmodels-0.14.6-cp313-cp313-win_amd64.whl", hash = "sha256:8021271a79f35b842c02a1794465a651a9d06ec2080f76ebc3b7adce77d08233", size = 9540043, upload-time = "2025-12-05T23:12:33.887Z" }, + { url = "https://files.pythonhosted.org/packages/71/de/09540e870318e0c7b58316561d417be45eff731263b4234fdd2eee3511a8/statsmodels-0.14.6-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:00781869991f8f02ad3610da6627fd26ebe262210287beb59761982a8fa88cae", size = 10069403, upload-time = "2025-12-05T23:12:48.424Z" }, + { url = 
"https://files.pythonhosted.org/packages/ab/f0/63c1bfda75dc53cee858006e1f46bd6d6f883853bea1b97949d0087766ca/statsmodels-0.14.6-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:73f305fbf31607b35ce919fae636ab8b80d175328ed38fdc6f354e813b86ee37", size = 9989253, upload-time = "2025-12-05T23:13:05.274Z" }, + { url = "https://files.pythonhosted.org/packages/c1/98/b0dfb4f542b2033a3341aa5f1bdd97024230a4ad3670c5b0839d54e3dcab/statsmodels-0.14.6-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e443e7077a6e2d3faeea72f5a92c9f12c63722686eb80bb40a0f04e4a7e267ad", size = 10090802, upload-time = "2025-12-05T23:13:20.653Z" }, + { url = "https://files.pythonhosted.org/packages/34/0e/2408735aca9e764643196212f9069912100151414dd617d39ffc72d77eee/statsmodels-0.14.6-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3414e40c073d725007a6603a18247ab7af3467e1af4a5e5a24e4c27bc26673b4", size = 10337587, upload-time = "2025-12-05T23:13:37.597Z" }, + { url = "https://files.pythonhosted.org/packages/0f/36/4d44f7035ab3c0b2b6a4c4ebb98dedf36246ccbc1b3e2f51ebcd7ac83abb/statsmodels-0.14.6-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:a518d3f9889ef920116f9fa56d0338069e110f823926356946dae83bc9e33e19", size = 10363350, upload-time = "2025-12-05T23:13:53.08Z" }, + { url = "https://files.pythonhosted.org/packages/26/33/f1652d0c59fa51de18492ee2345b65372550501ad061daa38f950be390b6/statsmodels-0.14.6-cp314-cp314-win_amd64.whl", hash = "sha256:151b73e29f01fe619dbce7f66d61a356e9d1fe5e906529b78807df9189c37721", size = 9588010, upload-time = "2025-12-05T23:14:07.28Z" }, +] + [[package]] name = "sympy" version = "1.14.0" @@ -3787,6 +3976,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3f/0e/fa3b193432cfc60c93b42f3be03365f5f909d2b3ea410295cf36df739e31/widgetsnbextension-4.0.15-py3-none-any.whl", hash = "sha256:8156704e4346a571d9ce73b84bee86a29906c9abfd7223b7228a28899ccf3366", size = 
2196503, upload-time = "2025-11-01T21:15:53.565Z" }, ] +[[package]] +name = "xlsxwriter" +version = "3.2.9" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/46/2c/c06ef49dc36e7954e55b802a8b231770d286a9758b3d936bd1e04ce5ba88/xlsxwriter-3.2.9.tar.gz", hash = "sha256:254b1c37a368c444eac6e2f867405cc9e461b0ed97a3233b2ac1e574efb4140c", size = 215940, upload-time = "2025-09-16T00:16:21.63Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3a/0c/3662f4a66880196a590b202f0db82d919dd2f89e99a27fadef91c4a33d41/xlsxwriter-3.2.9-py3-none-any.whl", hash = "sha256:9a5db42bc5dff014806c58a20b9eae7322a134abb6fce3c92c181bfb275ec5b3", size = 175315, upload-time = "2025-09-16T00:16:20.108Z" }, +] + [[package]] name = "zipp" version = "3.23.0"