Re-implementation of peptide toxicity prediction using protein language model (pLM) embeddings with a novel deep learning architecture.
| Metric | Value |
|---|---|
| Accuracy | 75.54% |
| Sensitivity | 69.57% |
| Specificity | 81.43% |
| MCC | 0.5138 |
| AUC | 0.7550 |
Dataset: 1388 sequences (694 toxic, 694 non-toxic) — Train: 1110, Val: 139, Test: 139
This project predicts whether a peptide sequence is toxic or non-toxic. Instead of learning embeddings from scratch, we leverage pre-trained protein language models that have learned meaningful representations from millions of protein sequences.
Traditional approach (original paper):
Amino acid sequence → Learned embedding (nn.Embedding) → Model → Prediction
Our approach:
Amino acid sequence → Pre-trained pLM (ESM-2 + ProtT5) → Model → Prediction
Advantages:
- pLMs have seen millions of proteins during pre-training
- They capture evolutionary, structural, and functional information
- Transfer learning: leverage knowledge from large-scale pre-training
Original architecture (dual branch, cross-attention fusion):

```
         Input Sequence
          /          \
[Learned Embedding]  [Molecular Graph]
         |                  |
       BiGRU             2D CNN
         |                  |
    Transformer          Linear
         |                  |
         └── Cross-Attention ─┘
                  |
              Classifier
```
Our architecture (single unified pathway):

```
     Input Sequence
      /         \
   ESM-2       ProtT5
   (480D)      (1024D)
      \         /
     Concatenate
          |
  Linear (1504 → 512)
          |
  Positional Encoding
          |
  Transformer Encoder
      (2 layers)
          |
     Mean Pooling
          |
         MLP
          |
  [Toxic / Non-toxic]
```
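The diagram above can be sketched as a small PyTorch module. This is a minimal illustration of the pathway, not the exact `src/model.py` implementation; the class name `ToxiPepSketch`, the head count, and the 128-unit MLP hidden layer are assumptions:

```python
import math
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    """Standard sinusoidal positional encoding (Vaswani et al., 2017)."""
    def __init__(self, d_model: int, max_len: int = 512):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(max_len).unsqueeze(1).float()
        div = torch.exp(torch.arange(0, d_model, 2).float()
                        * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer("pe", pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        return x + self.pe[:, : x.size(1)]

class ToxiPepSketch(nn.Module):
    """Concatenate ESM-2 (480D) and ProtT5 (1024D) per-residue embeddings,
    project to d_model, encode with a 2-layer Transformer, mean-pool, classify."""
    def __init__(self, d_model=512, n_heads=8, n_layers=2, n_classes=2):
        super().__init__()
        self.project = nn.Linear(480 + 1024, d_model)       # 1504 → 512
        self.pos_enc = PositionalEncoding(d_model)
        layer = nn.TransformerEncoderLayer(d_model, n_heads,
                                           dim_feedforward=1024,
                                           batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)
        self.mlp = nn.Sequential(nn.Linear(d_model, 128), nn.ReLU(),
                                 nn.Linear(128, n_classes))

    def forward(self, esm_emb, prott5_emb):
        x = torch.cat([esm_emb, prott5_emb], dim=-1)  # (batch, seq_len, 1504)
        x = self.pos_enc(self.project(x))
        x = self.encoder(x)
        x = x.mean(dim=1)                             # mean pooling → (batch, 512)
        return self.mlp(x)                            # (batch, 2)
```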
| Aspect | Original Paper | Our Implementation |
|---|---|---|
| Sequence embedding | Learned (nn.Embedding) | Pre-trained ESM-2 |
| Structure encoding | Molecular graphs + 2D CNN | Pre-trained ProtT5 |
| Fusion strategy | Cross-attention (late) | Concatenation (early) |
| Architecture | Dual branch | Single unified pathway |
| Recurrent layers | BiGRU | None |
- Model used: `esm2_t12_35M_UR50D` (lightweight version, 35M parameters)
- Trained on ~250 million protein sequences
- Captures sequence patterns and evolutionary information
- Output: 480-dimensional vector per amino acid
- Model used: `prot_t5_xl_uniref50`
- Based on the T5 architecture, trained on protein sequences from UniRef50
- Captures contextual sequence information via encoder-decoder pre-training
- Output: 1024-dimensional vector per amino acid
- ESM-2: Captures evolutionary patterns via masked language modeling
- ProtT5: Captures contextual patterns via T5 sequence-to-sequence pre-training
- Combined: More robust and diverse representation of the peptide
src/
├── embeddings.py    # Extract ESM-2 and ProtT5 embeddings
├── model.py # Neural network architecture
├── dataset.py # Data loading and preprocessing
└── main.py # Training pipeline
Extracts pre-trained embeddings from pLM models.
```python
extractor = EmbeddingExtractor()
esm_emb = extractor.extract_esm(sequences)          # List of (seq_len, 480) tensors
prostt5_emb = extractor.extract_prostt5(sequences)  # List of (seq_len, 1024) tensors
```

Key functions:
- `load_esm()`: Loads ESM-2 model (`esm2_t12_35M_UR50D`) from Facebook's library
- `load_prostt5()`: Loads ProtT5 (`prot_t5_xl_uniref50`) from HuggingFace
- `pad_embeddings()`: Pads variable-length sequences to fixed size
- `precompute_and_save()`: CLI to precompute embeddings once
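As a sketch of what the padding step involves, assuming `pad_embeddings()` takes a list of per-sequence `(seq_len, dim)` tensors; the returned boolean mask is an assumption and the real signature may differ:

```python
import torch

def pad_embeddings(embeddings, max_len=None):
    """Pad a list of (seq_len, dim) tensors to a common length.

    Returns a (n, max_len, dim) tensor plus a boolean mask that is
    True at real residue positions and False at padding.
    """
    if max_len is None:
        max_len = max(e.size(0) for e in embeddings)
    dim = embeddings[0].size(1)
    padded = torch.zeros(len(embeddings), max_len, dim)
    mask = torch.zeros(len(embeddings), max_len, dtype=torch.bool)
    for i, emb in enumerate(embeddings):
        length = min(emb.size(0), max_len)  # truncate if longer than max_len
        padded[i, :length] = emb[:length]
        mask[i, :length] = True
    return padded, mask
```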
Defines the neural network architecture.
Classes:
- `PositionalEncoding`: Adds position information (Transformers have no inherent order)
- `TransformerEncoder`: Self-attention layers to model residue interactions
- `FocalLoss`: Loss function for class imbalance (down-weights easy examples)
- `ToxiPepModel`: Main model combining all components
Model forward pass:
```
# 1. Concatenate embeddings
combined = concat([esm, prostt5], dim=-1)  # (batch, seq_len, 1504)

# 2. Project to lower dimension
x = Linear(1504 → 512)(combined)

# 3. Add positional encoding
x = x + positional_encoding

# 4. Transformer layers (self-attention)
x = transformer_encoder(x)

# 5. Pool over sequence (mean)
x = mean(x, dim=1)  # (batch, 512)

# 6. Classify
output = MLP(x)  # (batch, 2)
```

Handles data loading and preprocessing.
Key functions:
- `load_sequences_from_fasta()`: Parse FASTA files, extract sequences and labels
- `EmbeddingDataset`: PyTorch Dataset for pre-computed embeddings (ESM: 480D, ProtT5: 1024D)
- `create_data_splits()`: Split data into train/val/test with stratification
- `get_class_weights()`: Compute weights for imbalanced classes
Why stratified splits?
```
Without stratification:
  Train: 90% negative, 10% positive
  Test:  70% negative, 30% positive  ← Different distribution!

With stratification:
  Train: 80% negative, 20% positive
  Test:  80% negative, 20% positive  ← Same distribution
```
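A stratified 80/10/10 split can be built from two calls to scikit-learn's `train_test_split`; a sketch of how `create_data_splits()` might work (the exact signature in `src/dataset.py` may differ):

```python
from sklearn.model_selection import train_test_split

def create_data_splits(X, y, val_frac=0.1, test_frac=0.1, seed=42):
    """80/10/10 stratified split: label proportions are preserved in every subset."""
    # First carve off the test set.
    X_tmp, X_test, y_tmp, y_test = train_test_split(
        X, y, test_size=test_frac, stratify=y, random_state=seed)
    # Then split the remainder into train and validation.
    # val_frac is relative to the full dataset, so rescale it.
    rel_val = val_frac / (1 - test_frac)
    X_train, X_val, y_train, y_val = train_test_split(
        X_tmp, y_tmp, test_size=rel_val, stratify=y_tmp, random_state=seed)
    return (X_train, y_train), (X_val, y_val), (X_test, y_test)
```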
Training pipeline with proper methodology.
Training flow:
1. Load pre-computed embeddings
2. Split: 80% train, 10% val, 10% test
3. For each epoch:
- Train on train_loader
- Evaluate on val_loader
- Save model if val improves
- Early stopping if no improvement
4. Final evaluation on test_loader (once!)
5. Print metrics: ACC, Sensitivity, Specificity, MCC, AUC
Why validation set?
- Original code only had train/test
- Problem: If you tune hyperparameters based on test performance, you're "cheating"
- Solution: Use validation set for tuning, test set only for final evaluation
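The early-stopping logic in step 3 can be captured by a small helper. This is a sketch of one common formulation (tracking validation loss; the actual pipeline may monitor a different metric):

```python
class EarlyStopping:
    """Stop training once validation loss has not improved for `patience` epochs."""
    def __init__(self, patience: int = 10):
        self.patience = patience
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss: float) -> bool:
        """Call once per epoch; returns True when training should stop."""
        if val_loss < self.best:
            self.best = val_loss
            self.bad_epochs = 0   # improvement: reset counter (save model here)
        else:
            self.bad_epochs += 1  # no improvement this epoch
        return self.bad_epochs >= self.patience
```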
Install dependencies:

```
pip install torch fair-esm transformers scikit-learn
```

Precompute embeddings:

```
python src/embeddings.py -i data/train.fasta -o data/train_embeddings.pt
```

This step is slow (it requires running the pLMs) but only needs to be done once.
```
python src/main.py --data data/train_embeddings.pt --epochs 100
```

Arguments:

| Argument | Default | Description |
|---|---|---|
| `--data` | required | Path to embeddings file |
| `--batch_size` | 32 | Batch size |
| `--epochs` | 100 | Maximum epochs |
| `--lr` | 1e-4 | Learning rate |
| `--patience` | 10 | Early stopping patience |
| `--loss` | focal | Loss function (`focal` or `ce`) |
The script outputs:
- `best_model.pth`: Trained model weights
- `best_model_results.pt`: Predictions and metrics
Peptide datasets are often imbalanced (more non-toxic than toxic).
Dataset: 900 non-toxic, 100 toxic
Naive model: Predict all non-toxic → 90% accuracy!
But: 0% sensitivity (misses all toxic peptides)
1. Focal Loss
FL = -α(1-pt)^γ * log(pt)
- pt high (easy example) → (1-pt)^γ small → low loss
- pt low (hard example) → (1-pt)^γ large → high loss
Forces model to focus on hard-to-classify examples.
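A minimal PyTorch implementation of the formula above; the project's `FocalLoss` in `src/model.py` may differ in defaults (α=1, γ=2 are the values commonly used in the literature and are assumptions here):

```python
import torch
import torch.nn.functional as F

class FocalLoss(torch.nn.Module):
    """FL = -alpha * (1 - p_t)^gamma * log(p_t), averaged over the batch."""
    def __init__(self, alpha: float = 1.0, gamma: float = 2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, logits, targets):
        # cross_entropy returns -log(p_t) per sample, so p_t = exp(-ce)
        ce = F.cross_entropy(logits, targets, reduction="none")
        pt = torch.exp(-ce)
        return (self.alpha * (1 - pt) ** self.gamma * ce).mean()
```

Because `(1 - p_t)^γ ≤ 1`, focal loss never exceeds plain cross-entropy; it only shrinks the contribution of confidently classified examples.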
2. Class Weights
weight[class] = total_samples / (n_classes × count[class])
Example:
- weight_negative = 1000 / (2 × 900) = 0.56
- weight_positive = 1000 / (2 × 100) = 5.0
Positive class contributes 9× more to the loss.
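The weighting rule above as code; a sketch of what `get_class_weights()` computes:

```python
from collections import Counter

def get_class_weights(labels):
    """Inverse-frequency weights: weight[c] = total / (n_classes * count[c])."""
    counts = Counter(labels)
    total = len(labels)
    n_classes = len(counts)
    return {c: total / (n_classes * n) for c, n in counts.items()}
```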
| Metric | Formula | Interpretation |
|---|---|---|
| Accuracy | (TP+TN)/Total | Overall correctness |
| Sensitivity | TP/(TP+FN) | How well we detect toxic |
| Specificity | TN/(TN+FP) | How well we detect non-toxic |
| MCC | See below | Balanced metric (-1 to +1) |
| AUC | Area under ROC | Ranking quality |
Why MCC?

MCC = (TP·TN − FP·FN) / √((TP+FP)(TP+FN)(TN+FP)(TN+FN))

- Accuracy is misleading with imbalanced data
- MCC considers all four confusion matrix values
- MCC = 0: random predictions
- MCC = +1: perfect predictions
- MCC = −1: completely wrong
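The same formula from raw confusion-matrix counts (in practice `sklearn.metrics.matthews_corrcoef` computes this directly from label arrays):

```python
import math

def mcc(tp: int, tn: int, fp: int, fn: int) -> float:
    """Matthews correlation coefficient; returns 0.0 when any marginal is empty."""
    num = tp * tn - fp * fn
    den = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    return num / den if den else 0.0
```

Note how the naive all-negative classifier from the imbalance example scores 90% accuracy but MCC = 0.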
Data leakage: When information from test set influences training.
- ❌ Computing class weights on full dataset → ✅ Compute only on train set
- ❌ Tuning hyperparameters on test set → ✅ Use validation set
- ❌ Evaluating on test every epoch → ✅ Only at the end
- ❌ Random split without stratification → ✅ Stratified split
- Python 3.8+
- PyTorch 2.0+
- fair-esm (for ESM-2)
- transformers (for ProtT5)
- scikit-learn (for metrics and splits)