diff --git a/.gitignore b/.gitignore index 15fc612..f379760 100644 --- a/.gitignore +++ b/.gitignore @@ -67,3 +67,8 @@ pytest-*.xml !web/public/assets/**/*.svg node_modules/ .next/ + +# Sample data +!data/sample/peptide-metadata.tsv +!data/sample/reactivity-zscore.tsv +!data/sample/protein-targets.fasta diff --git a/README.md b/README.md index 86c937b..ce00121 100644 --- a/README.md +++ b/README.md @@ -52,7 +52,7 @@ The repository profile is the source of truth for reproducing experiments end-to Stage 1 normalize dataset inputs (PV1/CWP/BKP) to a shared training contract Stage 2 generate ESM-2 per-residue embeddings Stage 3 build residue-level label shards -Stage 4 train FFNN (seeded or ensemble-kfold, DDP-aware) +Stage 4 train FFNN (unified n-fold interface, DDP-aware) Stage 5 optional Optuna tuning (DDP-aware) Stage 6 predict residue masks from checkpoint/manifest Stage 7 evaluate residue metrics (+ optional Cocci peptide compare) @@ -62,6 +62,8 @@ Stage 7 evaluate residue metrics (+ optional Cocci peptide compare) ### Stage 1: Multi-Dataset Prepare (PV1/CWP/BKP) +***Note:** [`data/sample/`](data/sample/) contains the expected dataset sample formats. + **CLI:** `pepseqpred-prepare-dataset` (`src/pepseqpred/apps/prepare_dataset_cli.py`) This stage is the recommended entrypoint when training on one or more of: @@ -70,7 +72,7 @@ This stage is the recommended entrypoint when training on one or more of: - CWP/Cocci (fungal) - BKP (bacterial) -It normalizes source-specific metadata and FASTA headers into a shared PV1-compatible contract so downstream embedding, label generation, and training CLIs can be reused unchanged. +It normalizes source-specific metadata and FASTA headers into a shared PV1-compatible contract (i.e., ID= AC= OXX=) so downstream embedding, label generation, and training CLIs can be reused unchanged. 
**Core module** @@ -91,11 +93,11 @@ It normalizes source-specific metadata and FASTA headers into a shared PV1-compa ```bash pepseqpred-prepare-dataset \ - localdata/PV1/PV1_meta_2020-11-23_cleaned.tsv \ - localdata/PV1/prepared \ + data/PV1/PV1_meta_2020-11-23_cleaned.tsv \ + data/PV1/prepared \ --dataset-kind pv1 \ - --protein-fasta localdata/PV1/PV1_targets.fasta \ - --z-file localdata/PV1/PV1_zscores.tsv + --protein-fasta data/PV1/PV1_targets.fasta \ + --z-file data/PV1/PV1_zscores.tsv ``` **CWP/Cocci inputs and command** @@ -107,12 +109,12 @@ pepseqpred-prepare-dataset \ ```bash pepseqpred-prepare-dataset \ - localdata/Cocci/CWP_metadata.tsv \ - localdata/Cocci/prepared \ + data/Cocci/CWP_metadata.tsv \ + data/Cocci/prepared \ --dataset-kind cwp \ - --protein-fasta localdata/Cocci/CWP_targets.faa \ - --reactive-codes localdata/Cocci/CWP_reactive_Z20N4.tsv \ - --nonreactive-codes localdata/Cocci/CWP_nonreactive_Z20N4.tsv + --protein-fasta data/Cocci/CWP_targets.faa \ + --reactive-codes data/Cocci/CWP_reactive_Z20N4.tsv \ + --nonreactive-codes data/Cocci/CWP_nonreactive_Z20N4.tsv ``` **BKP inputs and command** @@ -124,12 +126,12 @@ pepseqpred-prepare-dataset \ ```bash pepseqpred-prepare-dataset \ - localdata/BKP/BKP_metadata.tsv \ - localdata/BKP/prepared \ + data/BKP/BKP_metadata.tsv \ + data/BKP/prepared \ --dataset-kind bkp \ - --protein-fasta localdata/BKP/BKP.faa \ - --reactive-codes localdata/BKP/BKP_reactive_Z20N4.tsv \ - --nonreactive-codes localdata/BKP/BKP_nonreactive_Z20N4.tsv + --protein-fasta data/BKP/BKP.faa \ + --reactive-codes data/BKP/BKP_reactive_Z20N4.tsv \ + --nonreactive-codes data/BKP/BKP_nonreactive_Z20N4.tsv ``` **Dataset-specific grouping used for leakage-aware splitting (`--split-type id-family`)** @@ -150,7 +152,7 @@ pepseqpred-prepare-dataset \ **Inputs** -- metadata TSV (PV1-style) +- metadata TSV - z-score TSV **Core modules** @@ -190,7 +192,7 @@ pepseqpred-preprocess data/meta.tsv data/zscores.tsv --save ```bash pepseqpred-esm \ 
--fasta-file data/targets.fasta \ - --out-dir localdata/esm2 \ + --out-dir data/esm2 \ --embedding-key-mode id-family \ --key-delimiter - \ --model-name esm2_t33_650M_UR50D \ @@ -222,8 +224,8 @@ pepseqpred-esm \ ```bash pepseqpred-labels \ data/input_data_20_4_10_all.tsv \ - localdata/labels/labels_shard_000.pt \ - --emb-dir localdata/esm2/artifacts/pts/shard_000 \ + data/labels/labels_shard_000.pt \ + --emb-dir data/esm2/artifacts/pts/shard_000 \ --restrict-to-embeddings \ --calc-pos-weight \ --embedding-key-delim - @@ -238,10 +240,11 @@ pepseqpred-labels \ **CLI:** `pepseqpred-train-ffnn` (`src/pepseqpred/apps/train_ffnn_cli.py`) -**Modes** +**Unified run interface** -- `seeded`: split/train seed pairs define runs -- `ensemble-kfold`: K-fold members per set, optional multiple set seeds +- `--n-folds 1`: one holdout run per split/train seed pair (uses `--val-frac`) +- `--n-folds K` (`K > 1`): K-fold members per split/train seed pair set +- `--split-seeds` and `--train-seeds` are paired by index; if both are omitted, both default to `--seed` **Core modules** @@ -253,12 +256,12 @@ pepseqpred-labels \ ```bash pepseqpred-train-ffnn \ - --embedding-dirs localdata/esm2/artifacts/pts/shard_000 \ - --label-shards localdata/labels/labels_shard_000.pt \ + --embedding-dirs data/esm2/artifacts/pts/shard_000 \ + --label-shards data/labels/labels_shard_000.pt \ --epochs 1 \ --subset 100 \ - --save-path localdata/models/ffnn_smoke \ - --results-csv localdata/models/ffnn_smoke/runs.csv + --save-path data/models/ffnn_smoke \ + --results-csv data/models/ffnn_smoke/runs.csv ``` **Submit one SLURM training job with multiple datasets (PV1 + CWP + BKP)** @@ -314,7 +317,7 @@ Notes: - run checkpoint(s), usually `fully_connected.pt` - per-run CSV (`runs.csv` or `multi_run_results.csv`) - aggregate `multi_run_summary.json` -- ensemble manifest JSON in `ensemble-kfold` mode +- ensemble manifest JSON when `--n-folds > 1` ### Stage 5: Optuna Tuning (Optional) @@ -329,14 +332,50 @@ Notes: 
```bash pepseqpred-train-ffnn-optuna \ - --embedding-dirs localdata/esm2/artifacts/pts/shard_000 \ - --label-shards localdata/labels/labels_shard_000.pt \ + --embedding-dirs data/esm2/artifacts/pts/shard_000 \ + --label-shards data/labels/labels_shard_000.pt \ --n-trials 2 \ --epochs 1 \ - --save-path localdata/models/optuna_smoke \ - --csv-path localdata/models/optuna_smoke/trials.csv + --save-path data/models/optuna_smoke \ + --csv-path data/models/optuna_smoke/trials.csv ``` +**Current Optuna search space** + +The current study samples the following hyperparameters per trial: + +| Hyperparameter (`best_params` key) | Type | Search space (current implementation) | Controlled by | +| --- | --- | --- | --- | +| `depth` | integer | `[depth_min, depth_max]` | `--depth-min`, `--depth-max` | +| `width_step` | categorical | `{16, 32, 64}` | fixed in code | +| `base_width` | integer | `[width_min, width_max]` with `step=width_step` | `--width-min`, `--width-max` | +| `shape_ratio` | float | `[0.60, 0.95]` | sampled only when `--arch-mode` is `bottleneck` or `pyramid` | +| `dropout` | float | `[0.00, 0.25]` | fixed in code | +| `use_layer_norm` | categorical | `{True, False}` | fixed in code | +| `use_residual` | categorical | `{True, False}` | fixed in code | +| `learning_rate` | float (log) | `[lr_min, lr_max]` | `--lr-min`, `--lr-max` | +| `weight_decay` | float (log) | `[wd_min, wd_max]` | `--wd-min`, `--wd-max` | +| `batch_size` | categorical | values from `--batch-sizes` CSV | `--batch-sizes` | + +Architecture shaping behavior: + +- `--arch-mode flat`: hidden widths are `[base_width] * depth` +- `--arch-mode bottleneck`: widths decrease by `shape_ratio` across layers +- `--arch-mode pyramid`: widths increase by `shape_ratio` across layers + +Not tuned by Optuna in the current setup: + +- `pos_weight` (fixed for the study via `--pos-weight`, or computed once from label shards if omitted) +- split strategy and validation fraction (`--split-type`, `--val-frac`) +- 
sequence windowing (`--window-size`, `--stride`) and data-loader behavior +- trial budget/pruning controls (`--n-trials`, `--epochs`, `--pruner-warmup`, `--timeout-s`) +- optimization target metric selection (`--metric`) is user-selected, then maximized by Optuna + +HPC default override note: + +- The CLI default for `--batch-sizes` is `32,64,128`. +- The SLURM wrapper `scripts/hpc/trainffnnoptuna.sh` currently overrides this to `256,512,1024` unless changed via env var. + **Outputs** - trial rows CSV @@ -362,9 +401,9 @@ pepseqpred-train-ffnn-optuna \ ```bash pepseqpred-predict \ - localdata/models/run_001/fully_connected.pt \ + data/models/run_001/fully_connected.pt \ data/inference_targets.fasta \ - --output-fasta localdata/predictions/predictions.fasta + --output-fasta data/predictions/predictions.fasta ``` **Output** @@ -393,10 +432,10 @@ pepseqpred-predict \ ```bash pepseqpred-eval-ffnn \ - localdata/models/run_001/fully_connected.pt \ - --embedding-dirs localdata/esm2/artifacts/pts/shard_000 \ - --label-shards localdata/labels/labels_shard_000.pt \ - --output-json localdata/eval/ffnn_eval_summary.json + data/models/run_001/fully_connected.pt \ + --embedding-dirs data/esm2/artifacts/pts/shard_000 \ + --label-shards data/labels/labels_shard_000.pt \ + --output-json data/eval/ffnn_eval_summary.json ``` **Output** @@ -473,7 +512,7 @@ Bundled pretrained registry currently includes: | `pepseqpred-preprocess` | `apps/preprocess_cli.py` | metadata + z-score preprocessing | | `pepseqpred-esm` | `apps/esm_cli.py` | ESM-2 embedding generation | | `pepseqpred-labels` | `apps/labels_cli.py` | residue label shard generation | -| `pepseqpred-train-ffnn` | `apps/train_ffnn_cli.py` | seeded or ensemble-kfold training | +| `pepseqpred-train-ffnn` | `apps/train_ffnn_cli.py` | unified holdout/K-fold FFNN training (`--n-folds`) | | `pepseqpred-train-ffnn-optuna` | `apps/train_ffnn_optuna_cli.py` | Optuna tuning | | `pepseqpred-predict` | `apps/prediction_cli.py` | FASTA inference 
from checkpoint/manifest | | `pepseqpred-eval-ffnn` | `apps/evaluate_ffnn_cli.py` | residue-level evaluation | @@ -490,7 +529,7 @@ These wrappers are production-facing interfaces and should be treated as first-c | `trainffnnoptuna.sh` | Optuna | GPU, `4xa100`, `20` CPU, `448G`, `48:00:00` | | `predictepitope.sh` | Predict | GPU, `a100`, `4` CPU, `32G`, `00:30:00` | | `evaluateffnn.sh` | End-to-end eval pipeline | GPU, `a100`, `8` CPU, `128G`, `04:00:00` | -| `evalffnnsweep.sh` | Seeded eval batch submitter | wrapper script (calls `evaluateffnn.sh`) | +| `evalffnnsweep.sh` | Set-indexed eval batch submitter | wrapper script (calls `evaluateffnn.sh`) | | `preprocessdata.sh` | Preprocess helper | local helper, not a SLURM script | ### Important HPC notes diff --git a/localdata/.gitkeep b/data/.gitkeep similarity index 100% rename from localdata/.gitkeep rename to data/.gitkeep diff --git a/data/sample/peptide-metadata.tsv b/data/sample/peptide-metadata.tsv new file mode 100644 index 0000000..d107ade --- /dev/null +++ b/data/sample/peptide-metadata.tsv @@ -0,0 +1,2 @@ +CodeName Category SpeciesID Species Protein AlignStart AlignStop FullName Peptide Encoding +PV1_000673 SetCover 130310 Human mastadenovirus D ID=A0A2Z5WIK7_ADE08 AC=A0A2Z5WIK7 OXX=31545,130310,10509,10508_157_187 FNHTCNIQNLTLLFVNLTHNGAYIGYTKDG TTCAACCATACTTGCAACATTCAGAACCTGACCCTGCTGTTCGTAAACCTGACCCACAACGGTGCGTATATCGGCTACACCAAAGACGGT \ No newline at end of file diff --git a/data/sample/protein-targets.fasta b/data/sample/protein-targets.fasta new file mode 100644 index 0000000..c2bfe70 --- /dev/null +++ b/data/sample/protein-targets.fasta @@ -0,0 +1,2 @@ +>ID=A0A2Z5WIK7_ADE08 AC=A0A2Z5WIK7 OXX=31545,130310,10509,10508 
+MNTLTSVVLLSLLVAFSQAGIINLNVLWGINLTLVGPLDLPVTWYDKKGMQFCIGNTIKNPQIKHSCDQQNLTLLNADKSHERTYLGYRHDSKGKVDYKVTVIPPPPTTRKPLSEPHYVTVTMDHNITLVGPLNLPVTWYDGEGNKFCDGEKVEHAEFNHTCNIQNLTLLFVNLTHNGAYIGYTKDGSDRELYEVSVKTLFKNGAKQSKVEQGNTAQSGGKKTKTEHTNHSAKTKSTNNLQPTQLYVRPFTNVSLTGPPNGKVIWYDGELNDPCEQKYKLRTFCNQQNLTLINVTSTYNGIYYGTDEKDKANRYRIKVNTTNHKTVKIKPHTKKPSAKQEKQFELQVTKTNKNQSQIPSATVAIVAGVIAGFVTLIIVFLCYICCRKRSRAYNHMVDPLLSFSY \ No newline at end of file diff --git a/data/sample/reactivity-zscore.tsv b/data/sample/reactivity-zscore.tsv new file mode 100644 index 0000000..3726364 --- /dev/null +++ b/data/sample/reactivity-zscore.tsv @@ -0,0 +1,2 @@ +Sequence name VW_001 VW_002 VW_003 VW_004 VW_006 VW_007 VW_008 VW_009 VW_010 VW_011 VW_012 VW_013 VW_014 VW_015 VW_016 VW_017 VW_018 VW_019 VW_020 VW_021 VW_022 VW_023 VW_024 VW_025 VW_026 VW_027 VW_028 VW_029 VW_030 VW_031 VW_032 VW_033 VW_034 VW_035 VW_036 VW_037 VW_038 VW_039 VW_040 VW_041 VW_042 VW_043 VW_044 VW_045 VW_046 VW_047 VW_048 VW_049 VW_050 VW_051 VW_053 VW_054 VW_055 VW_056 VW_057 VW_058 VW_059 VW_060 VW_062 VW_063 VW_064 VW_065 VW_066 VW_067 VW_068 VW_069 VW_070 VW_071 VW_072 VW_073 VW_074 VW_075 VW_076 VW_077 VW_078 VW_079 VW_080 VW_081 VW_082 VW_083 VW_084 VW_085 VW_086 VW_087 VW_088 VW_089 VW_090 VW_091 VW_092 VW_093 VW_094 VW_095 VW_096 VW_097 VW_098 VW_099 VW_100 VW_101 VW_102 VW_103 VW_104 VW_105 VW_106 VW_107 VW_108 VW_109 VW_110 VW_111 VW_112 VW_113 VW_114 VW_115 VW_116 VW_117 VW_118 VW_119 VW_120 VW_121 VW_122 VW_123 VW_124 VW_125 VW_126 VW_127 VW_128 VW_129 VW_130 VW_131 VW_132 VW_133 VW_134 VW_135 VW_136 VW_137 VW_138 VW_139 VW_140 VW_141 VW_142 VW_143 VW_144 VW_145 VW_146 VW_148 VW_149 VW_150 VW_151 VW_152 VW_153 VW_154 VW_155 VW_156 VW_157 VW_158 VW_159 VW_160 VW_161 VW_162 VW_164 VW_165 VW_166 VW_168 VW_169 VW_170 VW_171 VW_172 VW_173 VW_174 VW_175 VW_176 VW_177 VW_178 VW_179 VW_180 VW_181 VW_182 VW_183 VW_184 VW_185 VW_186 VW_187 VW_188 VW_189 VW_190 VW_191 VW_192 VW_193 VW_194 VW_195 VW_196 VW_197 VW_198 VW_199 
VW_200 VW_201 VW_202 VW_203 VW_204 VW_205 VW_206 VW_207 VW_208 VW_209 VW_210 VW_211 VW_212 VW_213 VW_214 VW_215 VW_216 VW_217 VW_218 VW_219 VW_220 VW_221 VW_222 VW_223 VW_224 VW_225 VW_226 VW_227 VW_228 VW_229 VW_230 VW_231 VW_232 VW_233 VW_234 VW_235 VW_236 VW_237 VW_238 VW_239 VW_240 VW_241 VW_242 VW_243 VW_244 VW_245 VW_246 VW_247 VW_248 VW_249 VW_250 VW_251 VW_252 VW_253 VW_254 VW_255 VW_256 VW_257 VW_258 VW_259 VW_260 VW_261 VW_262 VW_263 VW_264 VW_265 VW_266 VW_267 VW_269 VW_270 VW_271 VW_272 VW_273 VW_275 VW_276 VW_277 VW_278 VW_279 VW_280 VW_281 VW_282 VW_283 VW_284 VW_285 VW_286 VW_287 VW_288 VW_289 VW_290 VW_291 VW_292 VW_293 VW_294 VW_295 VW_296 VW_297 VW_298 VW_299 VW_300 VW_301 VW_302 VW_303 VW_304 VW_305 VW_306 VW_307 VW_308 VW_309 VW_310 VW_311 VW_312 VW_313 VW_314 VW_315 VW_316 VW_317 VW_318 VW_319 VW_320 VW_321 VW_322 VW_323 VW_324 VW_325 VW_326 VW_327 VW_328 VW_329 VW_330 VW_331 VW_332 VW_333 VW_334 VW_335 VW_336 VW_337 VW_338 VW_339 VW_340 VW_341 VW_342 VW_344 VW_345 VW_346 VW_347 VW_348 VW_349 VW_350 VW_351 VW_352 VW_353 VW_354 VW_355 VW_356 VW_357 VW_358 VW_359 VW_360 VW_361 VW_362 VW_363 VW_364 VW_365 VW_366 VW_367 VW_368 VW_369 VW_370 VW_371 VW_372 VW_373 VW_374 VW_375 VW_376 VW_377 VW_378 VW_379 VW_380 VW_381 VW_382 VW_383 VW_384 VW_385 VW_386 VW_387 VW_388 VW_389 VW_390 VW_391 VW_392 VW_393 VW_394 VW_395 VW_397 VW_398 VW_399 VW_400 +PV1_000673 1.21 -0.15 -0.18 -0.09 1.61 0.13 -0.23 1.01 1.63 -0.4 -0.17 -0.63 0.45 -0.24 1.22 -0.5 0.35 -0.15 -0.08 0.64 -0.57 -0.2 0.08 -0.3 0.06 0.28 -0.46 2.86 0.17 -0.65 -0.2 -0.08 0.73 1.37 1.29 0.21 -0.05 -0.53 0.9 0.15 -0.83 0.08 -0.21 -0.52 0.08 -0.97 1.03 -0.14 -0.45 -0.74 0.94 -0.39 -0.82 0.53 -0.6 -0.86 2.15 2.13 1.14 -0.52 -0.18 -1.03 -0.16 -0.03 1.02 -0.37 -0.28 -0.43 -0.96 0.83 1.13 1.18 -0.01 0.03 0.14 0.81 -0.1 -0.66 -0.2 3.55 -0.44 0.44 -0.74 -0.3 -0.42 -0.51 0.39 0.68 -0.94 -0.31 -0.6 -0.41 -0.22 -0.56 0.18 0.2 -0.6 0.81 -0.08 1.51 -0.71 -0.16 -0.18 -0.85 -0.18 -0.61 0.95 -0.73 -0.06 0.28 -0.22 
-0.42 0.74 -0.31 -0.1 -0.55 0.85 -0.63 1.08 0.28 -0.19 0.04 -0.04 0.6 0.08 -0.23 -0.39 -0.84 -1.07 -0.26 0.42 0.24 0.62 -0.53 -0.02 -0.31 0.89 -0.38 0.11 -0.81 -0.34 -0.24 2.31 0.4 -0.11 1.62 -0.21 0.09 -1.3 0.23 -0.65 0.94 -0.75 -0.15 -0.6 -0.69 -1.14 -0.04 0.51 -1.13 -0.01 -0.17 0 -0.37 -0.73 -0.56 0.71 0.01 0.25 0.33 0.33 0.36 -0.77 -0.77 -0.35 0.15 0.22 -0.14 -0.53 0.72 0.39 0.09 -0.41 0.44 1.71 0.56 -0.32 -0.19 0.67 0.21 -0.99 -0.15 -0.31 0.4 0.56 1.06 1.01 0.51 0.53 -0.34 -0.54 0.23 -0.33 -0.71 0.25 -0.03 -0.4 -0.24 0.94 -0.65 0.42 0 -0.68 0.23 -0.05 0.47 0.49 -0.4 0.79 0.61 0.98 -0.64 -0.67 0.16 -0.37 -0.16 -0.67 0.56 -0.95 0.5 2.14 0.56 0.1 -0.15 -0.82 -0.56 -0.46 0.02 -1.05 0.31 -0.41 1.45 -0.22 1.38 1.98 0.17 -0.92 1.53 0.36 -0.66 -0.66 0.46 0.47 0.51 -0.26 -0.02 -0.8 -0.98 0.53 1.07 0.82 -0.28 0.1 1.67 0.78 -0.25 1.71 -0.34 -0.41 -0.31 -0.01 -0.73 -1.18 1.69 -0.42 1.61 -0.65 0.69 1.04 1.82 0.29 -0.37 3.94 -0.47 -0.48 0.95 -0.06 -0.37 -1.23 0.95 0.54 -0.72 -0.65 -0.27 -0.25 -0.67 -0.11 0.2 0.73 -0.94 1.28 0.71 0.62 -0.26 -0.55 1.41 -1.1 0.95 -0.46 0.4 0.06 -0.75 -0.34 -0.46 1.18 0.45 -0.46 -0.73 0.35 -0.45 -0.18 0.29 1.52 -0.01 -0.69 0.71 1.16 0.12 -0.51 0.35 1.48 -0.42 0.78 1.26 1.14 -0.32 0.05 -0.01 0.04 -0.8 -0.8 0.27 0.71 1.57 1.15 2.55 -0.88 0.59 -0.05 -0.35 -0.61 1.01 1.77 -0.31 2.29 -0.09 -0.81 0.94 0.91 0.05 0.08 -0.69 1.51 -0.44 0.6 -1.09 0.42 0.35 0.71 1.15 0.37 -0.77 0.94 -0.92 -1.05 -0.69 0.17 -0.59 0.85 0.34 1.15 -0.71 0.05 -0.58 -0.81 0.76 2.36 2.36 -0.23 0 diff --git a/notebooks/pretrained_sample.ipynb b/notebooks/pretrained_sample.ipynb index 5735bf7..f71d7f7 100644 --- a/notebooks/pretrained_sample.ipynb +++ b/notebooks/pretrained_sample.ipynb @@ -43,6 +43,16 @@ ")" ] }, + { + "cell_type": "markdown", + "id": "ba93a090", + "metadata": {}, + "source": [ + "### Requirements\n", + "\n", + "It is highly recommended you run this tool with at least one GPU, preferably 1-4 GPUs on a cluster. 
CPU usage has not been fully tested so use at your own risk." + ] + }, { "cell_type": "code", "execution_count": null, @@ -61,7 +71,7 @@ "print(\"cuda device count:\", torch.cuda.device_count())\n", "print(\"cuda[0]:\", torch.cuda.get_device_name(0))\n", "\n", - "DEVICE = \"cuda\"" + "DEVICE = \"cuda\" # change to \"cpu\" to run using CPUs" ] }, { @@ -101,6 +111,20 @@ "print(\"classmethod:\", predictor_cls.pretrained_meta.get(\"model_id\"), \"members:\", predictor_cls.n_members)" ] }, + { + "cell_type": "markdown", + "id": "21dc1072", + "metadata": {}, + "source": [ + "### Single Sequence Pretrained Example\n", + "\n", + "Using one of the available pretrained examples, you can predict exact epitope locations within the overall sequence.\n", + "\n", + "Typical header format (derived from FASTA store): ID=[id],ACC=[acc],OXX=[oxx].
\n", + "\n", + "A protein sequence is represented as a string of residues, where each character corresponds to one of the 21 amino acids: A, R, N, D, C, E, Q, G, H, I, L, K, M, F, P, U, S, T, W, Y, and V." + ] + }, { "cell_type": "code", "execution_count": null, @@ -122,6 +146,40 @@ "assert result.meta.get(\"pretrained\", {}).get(\"model_id\") == \"flagship2-v1\"" ] }, + { + "cell_type": "markdown", + "id": "b423db7f", + "metadata": {}, + "source": [ + "### Prediction Example Using FASTA\n", + "\n", + "A common method of running large-scale prediction studies is by passing a FASTA file as input to one of the FASTA prediction APIs.
\n", + "\n", + "A FASTA file used as input should be structured in the following way:
\n", + "```text\n", + ">ID=[id1],ACC=[acc1],OXX=[oxx1]\n", + "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\n", + ">ID=[id2],ACC=[acc2],OXX=[oxx2]\n", + "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\n", + ">ID=[id3],ACC=[acc3],OXX=[oxx3]\n", + "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\n", + "...\n", + "```\n", + "\n", + "Where the header comes after the '>' character, followed by the target protein sequence on a newline.
\n", + "\n", + "If an output file is requested, it follows the same FASTA format, but replaces the sequence string with a binary string, where 1=epitope and 0=not epitope:
\n", + "```text\n", + ">ID=[id1],ACC=[acc1],OXX=[oxx1]\n", + "00000000000000001111000000000000000000000111111000000000\n", + ">ID=[id2],ACC=[acc2],OXX=[oxx2]\n", + "00000000000000000000111111111100000000000000000000000000\n", + ">ID=[id3],ACC=[acc3],OXX=[oxx3]\n", + "11111000000000000000000000000000000000000000000000011111\n", + "...\n", + "```" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/pyproject.toml b/pyproject.toml index e685f2c..b4b3730 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "pepseqpred" -version = "1.1.0" +version = "1.1.1" description = "Residue-level epitope prediction pipeline for peptide/protein workflows." readme = "README.pypi.md" requires-python = ">=3.12" diff --git a/scripts/hpc/trainffnn.sh b/scripts/hpc/trainffnn.sh index a845946..6517700 100644 --- a/scripts/hpc/trainffnn.sh +++ b/scripts/hpc/trainffnn.sh @@ -23,12 +23,9 @@ usage() { echo " $0 /scratch/$USER/embeddings/shard1 /scratch/$USER/embeddings/shard2 -- /scratch/$USER/labels/labels_00.pt /scratch/$USER/labels/labels_01.pt" echo "" echo "Optional environment variables:" - echo " TRAIN_MODE default: seeded (seeded or ensemble-kfold)" - echo " N_FOLDS default: 5 (ensemble-kfold only)" + echo " N_FOLDS default: 1 (1=single holdout, >1=K-fold ensemble)" echo " SPLIT_SEEDS default: 11,22,33,44,55" echo " TRAIN_SEEDS default: 101,202,303,404,505" - echo " FOLD_SEED default: unset (legacy single-set fallback)" - echo " ENSEMBLE_MEMBER_TRAIN_SEEDS default: unset (legacy per-fold train seeds CSV)" } # require at least one embedding dir, separator (--), one label shard @@ -64,10 +61,7 @@ EPOCHS="${EPOCHS:-10}" BEST_MODEL_METRIC="${BEST_MODEL_METRIC:-pr_auc}" SPLIT_SEEDS="${SPLIT_SEEDS:-11,22,33,44,55}" TRAIN_SEEDS="${TRAIN_SEEDS:-101,202,303,404,505}" -TRAIN_MODE="${TRAIN_MODE:-seeded}" # seeded or ensemble-kfold -N_FOLDS="${N_FOLDS:-5}" -FOLD_SEED="${FOLD_SEED:-}" 
-ENSEMBLE_MEMBER_TRAIN_SEEDS="${ENSEMBLE_MEMBER_TRAIN_SEEDS:-}" +N_FOLDS="${N_FOLDS:-1}" BATCH_SIZE="${BATCH_SIZE:-256}" # ensure batch size is 4 times what you would do for one GPU (for example. 256 = 64 * 4) LR="${LR:-0.001}" WD="${WD:-0.0}" @@ -101,20 +95,13 @@ else LAUNCHER="" fi -TRAIN_MODE_ARGS=(--train-mode "$TRAIN_MODE") -if [ "$TRAIN_MODE" = "ensemble-kfold" ]; then - TRAIN_MODE_ARGS+=(--n-folds "$N_FOLDS") - TRAIN_MODE_ARGS+=(--ensemble-manifest "$ENSEMBLE_MANIFEST") - if [ -n "$SPLIT_SEEDS" ] && [ -n "$TRAIN_SEEDS" ]; then - TRAIN_MODE_ARGS+=(--split-seeds "$SPLIT_SEEDS") - TRAIN_MODE_ARGS+=(--train-seeds "$TRAIN_SEEDS") - else - [ -n "$FOLD_SEED" ] && TRAIN_MODE_ARGS+=(--fold-seed "$FOLD_SEED") - [ -n "$ENSEMBLE_MEMBER_TRAIN_SEEDS" ] && TRAIN_MODE_ARGS+=(--ensemble-train-seeds "$ENSEMBLE_MEMBER_TRAIN_SEEDS") - fi -else - TRAIN_MODE_ARGS+=(--split-seeds "$SPLIT_SEEDS") - TRAIN_MODE_ARGS+=(--train-seeds "$TRAIN_SEEDS") +TRAIN_ARGS=( + --n-folds "$N_FOLDS" + --split-seeds "$SPLIT_SEEDS" + --train-seeds "$TRAIN_SEEDS" +) +if [ "$N_FOLDS" -gt 1 ]; then + TRAIN_ARGS+=(--ensemble-manifest "$ENSEMBLE_MANIFEST") fi VAL_CURVE_ARGS=() @@ -131,7 +118,7 @@ ${LAUNCHER} torchrun --nproc_per_node=4 train_ffnn.pyz \ --hidden-sizes "$HIDDEN_SIZES" \ --dropouts "$DROPOUTS" \ --epochs "$EPOCHS" \ - "${TRAIN_MODE_ARGS[@]}" \ + "${TRAIN_ARGS[@]}" \ --batch-size "$BATCH_SIZE" \ --lr "$LR" \ --wd "$WD" \ diff --git a/scripts/tools/cocci_eval_pipeline.py b/scripts/tools/cocci_eval_pipeline.py index 9e71f81..ece1d15 100644 --- a/scripts/tools/cocci_eval_pipeline.py +++ b/scripts/tools/cocci_eval_pipeline.py @@ -65,7 +65,7 @@ def parse_protein_id_from_prediction_header(header: str) -> str: def build_fullname(protein_id: str, oxx: str) -> str: - """Builds canonical PV1-style fullname.""" + """Builds canonical PV1-style fullname (i.e., ID= AC= OXX=).""" return f"ID={protein_id} AC={protein_id} OXX={oxx}" diff --git a/src/pepseqpred/apps/prepare_dataset_cli.py 
b/src/pepseqpred/apps/prepare_dataset_cli.py index bbba5d4..5e68534 100644 --- a/src/pepseqpred/apps/prepare_dataset_cli.py +++ b/src/pepseqpred/apps/prepare_dataset_cli.py @@ -1,6 +1,7 @@ """prepare_dataset_cli.py -Normalize PV1/CWP/BKP sources into a shared PV1-compatible training contract. +Normalize PV1/CWP/BKP sources into a shared PV1-compatible training +contract (i.e., ID= AC= OXX=). """ import argparse import time @@ -15,7 +16,7 @@ def main() -> None: parser = argparse.ArgumentParser( description=( "Prepare dataset-specific metadata/labels/targets into a PV1-compatible " - "contract for embedding, label generation, and training." + "contract (i.e., ID= AC= OXX=) for embedding, label generation, and training." ) ) parser.add_argument( diff --git a/src/pepseqpred/apps/train_ffnn_cli.py b/src/pepseqpred/apps/train_ffnn_cli.py index fe5b943..a43ddfa 100644 --- a/src/pepseqpred/apps/train_ffnn_cli.py +++ b/src/pepseqpred/apps/train_ffnn_cli.py @@ -130,6 +130,36 @@ def _as_optional_int(value: Any) -> int | None: return None +def _legacy_train_mode_label(n_folds: int) -> str: + """Resolves legacy-compatible train_mode label from unified n_folds input.""" + return "ensemble-kfold" if int(n_folds) > 1 else "seeded" + + +_LEGACY_TRAIN_FLAGS: Dict[str, str] = { + "--train-mode": ( + "Removed. Use --n-folds 1 for holdout runs, or --n-folds K (K>1) for K-fold ensemble runs." + ), + "--fold-seed": ( + "Removed. Use --split-seeds (paired with --train-seeds) to control per-set fold assignment seeds." + ), + "--ensemble-train-seeds": ( + "Removed. Use --train-seeds as per-set training seeds." 
+ ), +} + + +def _match_legacy_train_flags(tokens: Sequence[str]) -> List[str]: + """Finds removed legacy train-mode flags in parser unknown-token output.""" + found: List[str] = [] + for token in tokens: + token_s = str(token).strip() + for flag in _LEGACY_TRAIN_FLAGS: + if token_s == flag or token_s.startswith(f"{flag}="): + if flag not in found: + found.append(flag) + return found + + def _resolve_pr_zoom_limits( fold_evaluations: Sequence[Mapping[str, Any]], baseline_y: float | None, @@ -642,126 +672,14 @@ def _build_run_plans( protein_ids: List[str], family_groups: Dict[str, str] ) -> Tuple[List[RunPlan], Dict[str, Any]]: - """Builds run plans for either seeded or ensemble-kfold training modes.""" + """Builds run plans from unified seed lists and n-folds configuration.""" if len(protein_ids) == 0: raise ValueError("No proteins found to train on") - if args.train_mode == "ensemble-kfold": - n_folds = int(args.n_folds) - if n_folds < 2: - raise ValueError( - "--n-folds must be >= 2 when --train-mode ensemble-kfold") - - uses_set_seed_lists = (args.split_seeds is not None) or ( - args.train_seeds is not None) - - # Preferred mode: each (split_seed, train_seed) pair defines one K-fold set. 
- if uses_set_seed_lists: - if args.split_seeds is None or args.train_seeds is None: - raise ValueError( - "Provide both --split-seeds and --train-seeds for ensemble-kfold set replication") - if args.fold_seed is not None: - raise ValueError( - "--fold-seed cannot be combined with --split-seeds in ensemble-kfold mode") - if args.ensemble_train_seeds is not None: - raise ValueError( - "--ensemble-train-seeds cannot be combined with --split-seeds/--train-seeds") - - set_split_seeds = parse_int_csv(args.split_seeds, "--split-seeds") - set_train_seeds = parse_int_csv(args.train_seeds, "--train-seeds") - if len(set_split_seeds) != len(set_train_seeds): - raise ValueError( - "--split-seeds and --train-seeds must be the same length") - ensemble_seed_mode = "set-paired" - legacy_per_fold_train_seeds: List[int] = [] - else: - # Backward-compatible mode: single K-fold set with optional per-fold train seeds. - fold_seed = int(args.fold_seed) if args.fold_seed is not None else int(args.seed) - set_split_seeds = [int(fold_seed)] - set_train_seeds = [int(args.seed)] - - if args.ensemble_train_seeds is not None: - legacy_per_fold_train_seeds = parse_int_csv( - args.ensemble_train_seeds, "--ensemble-train-seeds") - if len(legacy_per_fold_train_seeds) != n_folds: - raise ValueError("--ensemble-train-seeds must match --n-folds") - ensemble_seed_mode = "legacy-per-fold-train-seeds" - else: - legacy_per_fold_train_seeds = [] - ensemble_seed_mode = "set-paired" - - n_sets = len(set_split_seeds) - run_plans: List[RunPlan] = [] - global_run_index = 1 - for set_index, (set_split_seed, set_train_seed) in enumerate( - zip(set_split_seeds, set_train_seeds), - start=1 - ): - if args.split_type == "id-family": - fold_splits = build_grouped_kfold_splits( - protein_ids, n_folds=n_folds, seed=int(set_split_seed), groups=family_groups - ) - else: - fold_splits = build_kfold_splits( - protein_ids, n_folds=n_folds, seed=int(set_split_seed) - ) - - if n_sets > 1: - set_dir_name = 
f"set_{set_index:03d}_split_{int(set_split_seed)}_train_{int(set_train_seed)}" - else: - set_dir_name = None - - for fold_index, (train_ids_all, val_ids_all) in enumerate(fold_splits, start=1): - if args.split_type == "id-family": - _check_family_leakage( - train_ids_all, val_ids_all, family_groups) + n_folds = int(args.n_folds) + if n_folds < 1: + raise ValueError("--n-folds must be >= 1") - if ensemble_seed_mode == "legacy-per-fold-train-seeds": - fold_train_seed = int(legacy_per_fold_train_seeds[fold_index - 1]) - else: - fold_train_seed = int(set_train_seed) - - if set_dir_name is None: - save_dir_name = ( - f"fold_{fold_index:02d}_split_{int(set_split_seed)}_train_{int(fold_train_seed)}") - else: - save_dir_name = f"{set_dir_name}/fold_{fold_index:02d}" - - run_plans.append( - RunPlan( - run_index=global_run_index, - train_mode=args.train_mode, - split_seed=int(set_split_seed), - train_seed=int(fold_train_seed), - train_ids_all=list(train_ids_all), - val_ids_all=list(val_ids_all), - save_dir_name=save_dir_name, - fold_index=fold_index, - n_folds=n_folds, - ensemble_set_index=set_index, - ensemble_set_split_seed=int(set_split_seed), - ensemble_set_train_seed=int(set_train_seed), - ensemble_set_dir_name=set_dir_name - ) - ) - global_run_index += 1 - - split_meta: Dict[str, Any] = { - "split_seeds": [int(x) for x in set_split_seeds], - "train_seeds": [int(x) for x in set_train_seeds], - "n_folds": int(n_folds), - "n_sets": int(n_sets), - "ensemble_seed_mode": ensemble_seed_mode - } - if args.fold_seed is not None: - split_meta["fold_seed"] = int(args.fold_seed) - if ensemble_seed_mode == "legacy-per-fold-train-seeds": - split_meta["ensemble_train_seeds_per_fold"] = [ - int(x) for x in legacy_per_fold_train_seeds - ] - return run_plans, split_meta - - # seeded/single mode (backward-compatible default) if args.split_seeds is None and args.train_seeds is None: split_seeds = [int(args.seed)] train_seeds = [int(args.seed)] @@ -778,36 +696,105 @@ def _build_run_plans( 
"--split-seeds and --train-seeds must be the same length" ) + n_sets = len(split_seeds) + train_mode = _legacy_train_mode_label(n_folds) + + if n_folds == 1: + run_plans: List[RunPlan] = [] + for run_index, (split_seed, train_seed) in enumerate(zip(split_seeds, train_seeds), start=1): + if args.split_type == "id-family": + train_ids_all, val_ids_all = split_ids_grouped( + protein_ids, args.val_frac, split_seed, family_groups + ) + _check_family_leakage(train_ids_all, val_ids_all, family_groups) + else: + train_ids_all, val_ids_all = split_ids( + protein_ids, args.val_frac, split_seed + ) + + if len(train_ids_all) == 0: + raise ValueError("Global split produced 0 train IDs") + + run_plans.append( + RunPlan( + run_index=run_index, + train_mode=train_mode, + split_seed=int(split_seed), + train_seed=int(train_seed), + train_ids_all=list(train_ids_all), + val_ids_all=list(val_ids_all), + save_dir_name=f"run_{run_index:03d}_split_{int(split_seed)}_train_{int(train_seed)}", + n_folds=1 + ) + ) + + return run_plans, { + "split_seeds": [int(x) for x in split_seeds], + "train_seeds": [int(x) for x in train_seeds], + "n_folds": 1, + "n_sets": int(n_sets), + "train_mode": train_mode + } + run_plans = [] - for run_index, (split_seed, train_seed) in enumerate(zip(split_seeds, train_seeds), start=1): + global_run_index = 1 + for set_index, (set_split_seed, set_train_seed) in enumerate( + zip(split_seeds, train_seeds), + start=1 + ): if args.split_type == "id-family": - train_ids_all, val_ids_all = split_ids_grouped( - protein_ids, args.val_frac, split_seed, family_groups + fold_splits = build_grouped_kfold_splits( + protein_ids, n_folds=n_folds, seed=int(set_split_seed), groups=family_groups ) - _check_family_leakage(train_ids_all, val_ids_all, family_groups) else: - train_ids_all, val_ids_all = split_ids( - protein_ids, args.val_frac, split_seed + fold_splits = build_kfold_splits( + protein_ids, n_folds=n_folds, seed=int(set_split_seed) ) - if len(train_ids_all) == 0: - raise 
ValueError("Global split produced 0 train IDs") + if n_sets > 1: + set_dir_name = f"set_{set_index:03d}_split_{int(set_split_seed)}_train_{int(set_train_seed)}" + else: + set_dir_name = None + + for fold_index, (train_ids_all, val_ids_all) in enumerate(fold_splits, start=1): + if args.split_type == "id-family": + _check_family_leakage( + train_ids_all, val_ids_all, family_groups) - run_plans.append( - RunPlan( - run_index=run_index, - train_mode=args.train_mode, - split_seed=int(split_seed), - train_seed=int(train_seed), - train_ids_all=list(train_ids_all), - val_ids_all=list(val_ids_all), - save_dir_name=f"run_{run_index:03d}_split_{int(split_seed)}_train_{int(train_seed)}" + fold_train_seed = int(set_train_seed) + + if set_dir_name is None: + save_dir_name = ( + f"fold_{fold_index:02d}_split_{int(set_split_seed)}_train_{int(fold_train_seed)}") + else: + save_dir_name = f"{set_dir_name}/fold_{fold_index:02d}" + + run_plans.append( + RunPlan( + run_index=global_run_index, + train_mode=train_mode, + split_seed=int(set_split_seed), + train_seed=int(fold_train_seed), + train_ids_all=list(train_ids_all), + val_ids_all=list(val_ids_all), + save_dir_name=save_dir_name, + fold_index=fold_index, + n_folds=n_folds, + ensemble_set_index=set_index, + ensemble_set_split_seed=int(set_split_seed), + ensemble_set_train_seed=int(set_train_seed), + ensemble_set_dir_name=set_dir_name + ) ) - ) + global_run_index += 1 return run_plans, { "split_seeds": [int(x) for x in split_seeds], - "train_seeds": [int(x) for x in train_seeds] + "train_seeds": [int(x) for x in train_seeds], + "n_folds": int(n_folds), + "n_sets": int(n_sets), + "ensemble_seed_mode": "set-paired", + "train_mode": train_mode } @@ -904,11 +891,6 @@ def main() -> None: default="id-family", choices=["id", "id-family"], help="Data partition type, use ID only or ID and taxonomic family.") - parser.add_argument("--train-mode", - type=str, - default="seeded", - choices=["seeded", "ensemble-kfold"], - help="Training mode: 
seeded (single/multi-seed runs) or ensemble-kfold.") parser.add_argument("--num-workers", action="store", dest="num_workers", @@ -953,23 +935,15 @@ def main() -> None: parser.add_argument("--split-seeds", type=str, default=None, - help="CSV split seeds (e.g., 11,22,33). In ensemble-kfold mode, each split seed defines one K-fold set.") + help="CSV split seeds (e.g., 11,22,33). Each split seed defines one run set.") parser.add_argument("--train-seeds", type=str, default=None, - help="CSV train seeds (e.g., 44,55,66). In ensemble-kfold mode, each train seed pairs with one split seed to define one K-fold set.") + help="CSV train seeds (e.g., 44,55,66). Each train seed pairs with one split seed by index.") parser.add_argument("--n-folds", type=int, - default=5, - help="Number of folds for --train-mode ensemble-kfold.") - parser.add_argument("--fold-seed", - type=int, - default=None, - help="Seed for fold assignment in single-set ensemble-kfold fallback mode (default: --seed).") - parser.add_argument("--ensemble-train-seeds", - type=str, - default=None, - help="Legacy single-set ensemble mode only: CSV per-fold train seeds; length must equal --n-folds.") + default=1, + help="Number of folds per set. Use 1 for holdout split mode and >1 for K-fold ensemble mode.") parser.add_argument("--best-model-metric", type=str, default="loss", @@ -986,7 +960,7 @@ def main() -> None: default=False, help=( "If set, save per-epoch validation ROC/PR curve data and plots. " - "In ensemble-kfold mode, also writes set-level fold consistency ROC/PR plots " + "When --n-folds > 1, also writes set-level fold consistency ROC/PR plots " "using each fold's best epoch." 
)) parser.add_argument("--val-curve-max-points", @@ -1004,9 +978,21 @@ def main() -> None: parser.add_argument("--ensemble-manifest", type=Path, default=None, - help="Optional JSON output path for ensemble manifest (written in ensemble-kfold mode).") - - args = parser.parse_args() + help="Optional JSON output path for ensemble manifest (written when --n-folds > 1).") + + args, unknown = parser.parse_known_args() + legacy_flags = _match_legacy_train_flags(unknown) + if len(legacy_flags) > 0: + details = " ".join( + f"{flag}: {_LEGACY_TRAIN_FLAGS[flag]}" + for flag in legacy_flags + ) + raise ValueError( + "Legacy train-mode flags are no longer supported. " + + details + ) + if len(unknown) > 0: + parser.error(f"unrecognized arguments: {' '.join(unknown)}") logger = setup_logger(json_lines=True, json_indent=2, name="train_ffnn_cli") @@ -1096,16 +1082,15 @@ def main() -> None: run_plans, split_meta = _build_run_plans(args, protein_ids, family_groups) if rank == 0: logger.info("run_plan_init", extra={"extra": { - "train_mode": args.train_mode, + "train_mode": str(split_meta["train_mode"]), "n_runs": len(run_plans), "split_type": args.split_type, "missing_family_ids": missing_family_ids, "n_folds": int(split_meta["n_folds"]) if "n_folds" in split_meta else None, "n_sets": int(split_meta["n_sets"]) if "n_sets" in split_meta else None, - "fold_seed": int(split_meta["fold_seed"]) if "fold_seed" in split_meta else None, "ensemble_seed_mode": split_meta.get("ensemble_seed_mode") }}) - if args.train_mode == "ensemble-kfold": + if int(split_meta["n_folds"]) > 1: logger.info("ensemble_kfold_note", extra={"extra": { "message": "--val-frac is ignored in ensemble-kfold mode because folds define validation sets." 
}}) @@ -1375,7 +1360,7 @@ def main() -> None: df_runs = pd.DataFrame(run_rows) summary_payload = { "n_runs": int(len(run_rows)), - "train_mode": str(args.train_mode), + "train_mode": str(split_meta["train_mode"]), "split_type": str(args.split_type), "best_model_metric": str(args.best_model_metric), "split_seeds": [int(x) for x in split_meta["split_seeds"]], @@ -1396,15 +1381,9 @@ def main() -> None: summary_payload["n_folds"] = int(split_meta["n_folds"]) if "n_sets" in split_meta: summary_payload["n_sets"] = int(split_meta["n_sets"]) - if "fold_seed" in split_meta: - summary_payload["fold_seed"] = int(split_meta["fold_seed"]) if "ensemble_seed_mode" in split_meta: summary_payload["ensemble_seed_mode"] = str( split_meta["ensemble_seed_mode"]) - if "ensemble_train_seeds_per_fold" in split_meta: - summary_payload["ensemble_train_seeds_per_fold"] = [ - int(x) for x in split_meta["ensemble_train_seeds_per_fold"] - ] summary_path = args.save_path / "multi_run_summary.json" summary_path.write_text( json.dumps(_sanitize_for_json(summary_payload), @@ -1412,7 +1391,7 @@ def main() -> None: encoding="utf-8" ) - if args.train_mode == "ensemble-kfold": + if int(split_meta["n_folds"]) > 1: n_sets = int(split_meta.get("n_sets", 1)) sets_map: Dict[int, Dict[str, Any]] = {} for row in run_rows: @@ -1500,7 +1479,7 @@ def main() -> None: set_payload = { "schema_version": 1, "ensemble_type": "kfold_majority_vote", - "train_mode": str(args.train_mode), + "train_mode": str(split_meta["train_mode"]), "split_type": str(args.split_type), "n_folds": int(split_meta["n_folds"]), "set_index": int(entry["set_index"]), @@ -1549,7 +1528,7 @@ def main() -> None: root_manifest_payload = { "schema_version": 2, "ensemble_type": "kfold_majority_vote", - "train_mode": str(args.train_mode), + "train_mode": str(split_meta["train_mode"]), "split_type": str(args.split_type), "n_folds": int(split_meta["n_folds"]), "n_sets": int(n_sets), diff --git a/src/pepseqpred/core/io/read.py 
b/src/pepseqpred/core/io/read.py index be67788..d83bda4 100644 --- a/src/pepseqpred/core/io/read.py +++ b/src/pepseqpred/core/io/read.py @@ -2,8 +2,9 @@ Input parsing utilities for PepSeqPred tabular and FASTA data. -Provides helpers to read PV1-style FASTA files, metadata TSVs, and z-score -reactivity tables into normalized pandas DataFrames. Also includes CLI CSV arguments parsers for downstream training and predictions. +Provides helpers to read PV1-style (i.e., ID= AC= OXX=) FASTA files, metadata TSVs, +and z-score reactivity tables into normalized pandas DataFrames. +Also includes CLI CSV arguments parsers for downstream training and predictions. """ import re @@ -14,7 +15,7 @@ def read_fasta(fasta_path: Path | str, full_name: bool = False) -> pd.DataFrame: """ - Parses a FASTA file with PV1-style headers into a pandas DataFrame. + Parses a FASTA file with PV1-style headers (i.e., ID= AC= OXX=) into a pandas DataFrame. Expected pattern (example): >ID=A8D0M1_ADE02 AC=A8D0M1 OXX=10515,129951,10509,10508 @@ -92,7 +93,7 @@ def read_metadata(meta_path: Path | str, peptide_end_idx: str = "AlignStop", drop_cols: Optional[Iterable[str]] = None) -> pd.DataFrame: """ - Parses a metadata TSV file with PV1-style headers into a pandas DataFrame. + Parses a metadata TSV file with PV1-style headers (i.e., ID= AC= OXX=) into a pandas DataFrame. Expected columns: CodeName, Category, SpeciesID, Species, Protein, AlignStart, AlignStop, FullName, diff --git a/src/pepseqpred/core/labels/builder.py b/src/pepseqpred/core/labels/builder.py index 1aabcb6..8c15211 100644 --- a/src/pepseqpred/core/labels/builder.py +++ b/src/pepseqpred/core/labels/builder.py @@ -2,8 +2,8 @@ Label building utilities for PepSeqPred peptide metadata. -Provides helpers to parse PV1-style metadata, map peptides to protein embeddings, -and build residue-level label tensors and peptide metadata for training. 
+Provides helpers to parse PV1-style (i.e., ID= AC= OXX=) metadata, map peptides +to protein embeddings, and build residue-level label tensors and peptide metadata for training. """ import logging diff --git a/src/pepseqpred/core/preprocess/preparedataset.py b/src/pepseqpred/core/preprocess/preparedataset.py index a87d9ec..3e0f7b6 100644 --- a/src/pepseqpred/core/preprocess/preparedataset.py +++ b/src/pepseqpred/core/preprocess/preparedataset.py @@ -3,7 +3,7 @@ Dataset normalization adapter for multi-source training preparation. This module converts PV1/CWP/BKP source inputs into a common PV1-compatible -contract used by existing embedding, label, and training CLIs. +contract (i.e., ID= AC= OXX=) used by existing embedding, label, and training CLIs. """ import csv import json @@ -16,7 +16,7 @@ def _build_fullname(protein_id: str, group_numeric: int) -> str: - """Builds a PV1-style fullname for normalized outputs.""" + """Builds a PV1-style fullname (i.e., ID= AC= OXX=) for normalized outputs.""" return f"ID={protein_id} AC={protein_id} OXX=0,0,0,{int(group_numeric)}" @@ -140,7 +140,7 @@ def _build_nonpv1_fasta_index(fasta_path: Path | str) -> Tuple[Dict[str, str], D def _build_pv1_fasta_index(fasta_path: Path | str) -> Tuple[Dict[str, str], Dict[str, List[str]]]: """ - Builds PV1 protein_id -> sequence mapping from PV1-style FASTA headers. + Builds PV1 protein_id -> sequence mapping from PV1-style FASTA headers (i.e., ID= AC= OXX=). """ seqs_by_id: Dict[str, List[str]] = {} for header, seq in _read_fasta_records(fasta_path): @@ -620,7 +620,7 @@ def prepare_dataset( logger: Optional[logging.Logger] = None ) -> Dict[str, Any]: """ - Converts dataset-specific sources into a shared PV1-compatible contract. + Converts dataset-specific sources into a shared PV1-compatible contract (i.e., ID= AC= OXX=). 
Outputs under `output_dir`: - prepared_targets.fasta diff --git a/tests/integration/test_train_clis_inprocess.py b/tests/integration/test_train_clis_inprocess.py index 04af1dd..f94c04e 100644 --- a/tests/integration/test_train_clis_inprocess.py +++ b/tests/integration/test_train_clis_inprocess.py @@ -132,8 +132,6 @@ def test_train_ffnn_cli_ensemble_kfold_inprocess(training_artifacts, tmp_path: P "0.1", "--split-type", "id-family", - "--train-mode", - "ensemble-kfold", "--n-folds", "2", "--split-seeds", @@ -200,8 +198,6 @@ def test_train_ffnn_cli_ensemble_kfold_with_aggregate_val_curves( "0.1", "--split-type", "id-family", - "--train-mode", - "ensemble-kfold", "--n-folds", "2", "--split-seeds", diff --git a/tests/unit/apps/test_train_cli_coverage.py b/tests/unit/apps/test_train_cli_coverage.py index fff23f4..3872f46 100644 --- a/tests/unit/apps/test_train_cli_coverage.py +++ b/tests/unit/apps/test_train_cli_coverage.py @@ -89,6 +89,32 @@ def test_train_cli_helper_parsers_and_numeric_summary(): train_cli._parse_plot_formats("png,jpg") +@pytest.mark.parametrize( + ("legacy_flag", "legacy_value"), + [ + ("--train-mode", "ensemble-kfold"), + ("--fold-seed", "17"), + ("--ensemble-train-seeds", "11,12"), + ], +) +def test_train_ffnn_cli_rejects_removed_legacy_mode_flags( + legacy_flag: str, legacy_value: str +): + with pytest.raises(ValueError, match="Legacy train-mode flags are no longer supported"): + _run_main( + train_cli.main, + [ + "train_ffnn_cli.py", + "--embedding-dirs", + "dummy_embedding_dir", + "--label-shards", + "dummy_labels.pt", + legacy_flag, + legacy_value, + ], + ) + + def test_train_ffnn_cli_real_no_valid_score_with_val_curve_artifacts(): case_dir = _mk_case_dir("ffnn_no_valid") emb_dir, label_shard = _write_training_artifacts( @@ -180,8 +206,6 @@ def test_train_ffnn_cli_real_ensemble_manifest_generation(): "8", "--dropouts", "0.1", - "--train-mode", - "ensemble-kfold", "--split-type", "id-family", "--n-folds", diff --git 
a/tests/unit/apps/test_train_ffnn_run_plans.py b/tests/unit/apps/test_train_ffnn_run_plans.py index 3fb949b..692cb81 100644 --- a/tests/unit/apps/test_train_ffnn_run_plans.py +++ b/tests/unit/apps/test_train_ffnn_run_plans.py @@ -7,13 +7,10 @@ def _args(**overrides): base = argparse.Namespace( - train_mode="ensemble-kfold", n_folds=2, split_seeds=None, train_seeds=None, - fold_seed=None, seed=42, - ensemble_train_seeds=None, split_type="id-family", val_frac=0.5, ) @@ -22,9 +19,8 @@ def _args(**overrides): return base -def test_build_run_plans_ensemble_set_pairs_map_to_full_kfold_sets(): +def test_build_run_plans_kfold_set_pairs_map_to_full_kfold_sets(): args = _args( - train_mode="ensemble-kfold", n_folds=2, split_seeds="101,202", train_seeds="11,22", @@ -41,6 +37,7 @@ def test_build_run_plans_ensemble_set_pairs_map_to_full_kfold_sets(): assert meta["n_folds"] == 2 assert meta["split_seeds"] == [101, 202] assert meta["train_seeds"] == [11, 22] + assert meta["train_mode"] == "ensemble-kfold" assert meta["ensemble_seed_mode"] == "set-paired" set_1 = [plan for plan in plans if plan.ensemble_set_index == 1] @@ -70,7 +67,6 @@ def test_build_run_plans_ensemble_set_pairs_map_to_full_kfold_sets(): def test_build_run_plans_ensemble_set_pairs_require_matching_lengths(): args = _args( - train_mode="ensemble-kfold", n_folds=2, split_seeds="101,202", train_seeds="11", @@ -82,45 +78,36 @@ def test_build_run_plans_ensemble_set_pairs_require_matching_lengths(): _build_run_plans(args, ids, {}) -def test_build_run_plans_ensemble_legacy_per_fold_train_seeds(): +def test_build_run_plans_single_fold_keeps_holdout_behavior(): args = _args( - train_mode="ensemble-kfold", - n_folds=2, - split_seeds=None, - train_seeds=None, - fold_seed=17, - seed=5, - ensemble_train_seeds="11,12", - split_type="id", - ) - ids = ["p1", "p2", "p3", "p4"] - - plans, meta = _build_run_plans(args, ids, {}) - - assert len(plans) == 2 - assert meta["ensemble_seed_mode"] == "legacy-per-fold-train-seeds" - assert 
meta["ensemble_train_seeds_per_fold"] == [11, 12] - assert all(plan.split_seed == 17 for plan in plans) - assert [plan.train_seed for plan in plans] == [11, 12] - assert [plan.ensemble_set_index for plan in plans] == [1, 1] - - -def test_build_run_plans_seeded_mode_keeps_prior_behavior(): - args = _args( - train_mode="seeded", + n_folds=1, split_seeds="11,22", train_seeds="101,202", split_type="id", - val_frac=0.5, ) ids = ["p1", "p2", "p3", "p4"] plans, meta = _build_run_plans(args, ids, {}) assert len(plans) == 2 - assert [plan.run_index for plan in plans] == [1, 2] assert [plan.split_seed for plan in plans] == [11, 22] assert [plan.train_seed for plan in plans] == [101, 202] + assert all(plan.train_mode == "seeded" for plan in plans) + assert all(plan.n_folds == 1 for plan in plans) + assert all(plan.fold_index is None for plan in plans) assert all(plan.ensemble_set_index is None for plan in plans) + assert all(plan.save_dir_name.startswith("run_") for plan in plans) + assert all(len(plan.train_ids_all) > 0 for plan in plans) assert meta["split_seeds"] == [11, 22] assert meta["train_seeds"] == [101, 202] + assert meta["n_folds"] == 1 + assert meta["n_sets"] == 2 + assert meta["train_mode"] == "seeded" + + +def test_build_run_plans_requires_n_folds_at_least_one(): + args = _args(n_folds=0, split_type="id") + ids = ["p1", "p2", "p3", "p4"] + + with pytest.raises(ValueError, match="--n-folds must be >= 1"): + _build_run_plans(args, ids, {})