lueasf/toxipep

ToxiPep: Peptide Toxicity Prediction with pLM Embeddings

Re-implementation of peptide toxicity prediction using protein language model (pLM) embeddings with a novel deep learning architecture.

Results

Metric        Value
Accuracy      75.54%
Sensitivity   69.57%
Specificity   81.43%
MCC           0.5138
AUC           0.7550

Dataset: 1388 sequences (694 toxic, 694 non-toxic) — Train: 1110, Val: 139, Test: 139

Overview

This project predicts whether a peptide sequence is toxic or non-toxic. Instead of learning embeddings from scratch, we leverage pre-trained protein language models that have learned meaningful representations from millions of protein sequences.

Why pLM Embeddings?

Traditional approach (original paper):

Amino acid sequence → Learned embedding (nn.Embedding) → Model → Prediction

Our approach:

Amino acid sequence → Pre-trained pLM (ESM-2 + ProtT5) → Model → Prediction

Advantages:

  • pLMs have seen millions of proteins during pre-training
  • They capture evolutionary, structural, and functional information
  • Transfer learning: leverage knowledge from large-scale pre-training

Architecture

Original Paper Architecture

                Input Sequence
                /           \
    [Learned Embedding]    [Molecular Graph]
           |                      |
        BiGRU                  2D CNN
           |                      |
     Transformer              Linear
           |                      |
           └──── Cross-Attention ─┘
                      |
                  Classifier

Our Architecture: Early Fusion + Transformer

            Input Sequence
            /            \
       ESM-2           ProtT5
       (480D)         (1024D)
            \            /
             Concatenate
                  |
            Linear (1504 → 512)
                  |
         Positional Encoding
                  |
          Transformer Encoder
             (2 layers)
                  |
            Mean Pooling
                  |
               MLP
                  |
           [Toxic / Non-toxic]
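The diagram above can be sketched as a compact PyTorch module (the class name `ToxiPepSketch` is hypothetical, and the positional-encoding step is omitted here for brevity; the repository's actual model may differ in details):

```python
import torch
import torch.nn as nn

class ToxiPepSketch(nn.Module):
    """Early fusion: concatenated pLM embeddings -> Transformer -> mean pool -> MLP."""

    def __init__(self, esm_dim=480, prott5_dim=1024, d_model=512,
                 n_layers=2, n_heads=8):
        super().__init__()
        self.project = nn.Linear(esm_dim + prott5_dim, d_model)  # 1504 -> 512
        layer = nn.TransformerEncoderLayer(d_model, n_heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)
        self.head = nn.Sequential(nn.Linear(d_model, 128), nn.ReLU(),
                                  nn.Linear(128, 2))

    def forward(self, esm_emb, prott5_emb):
        x = torch.cat([esm_emb, prott5_emb], dim=-1)  # (batch, seq_len, 1504)
        x = self.project(x)                           # (batch, seq_len, 512)
        x = self.encoder(x)                           # self-attention over residues
        x = x.mean(dim=1)                             # mean pool over sequence
        return self.head(x)                           # (batch, 2) logits
```

Early fusion keeps the whole model a single pathway: both embeddings enter one Transformer, instead of two branches joined by cross-attention as in the original paper.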

Key Differences

Aspect              Original Paper              Our Implementation
Sequence embedding  Learned (nn.Embedding)      Pre-trained ESM-2
Structure encoding  Molecular graphs + 2D CNN   Pre-trained ProtT5
Fusion strategy     Cross-attention (late)      Concatenation (early)
Architecture        Dual branch                 Single unified pathway
Recurrent layers    BiGRU                       None

Embeddings Explained

ESM-2 (Evolutionary Scale Modeling)

  • Model used: esm2_t12_35M_UR50D (lightweight version, 35M parameters)
  • Trained on ~250 million protein sequences
  • Captures sequence patterns and evolutionary information
  • Output: 480-dimensional vector per amino acid

ProtT5 (Protein T5)

  • Model used: prot_t5_xl_uniref50
  • Based on T5 architecture, trained on protein sequences from UniRef50
  • Captures contextual sequence information via encoder-decoder pre-training
  • Output: 1024-dimensional vector per amino acid

Why Both?

  • ESM-2: Captures evolutionary patterns via masked language modeling
  • ProtT5: Captures contextual patterns via T5 sequence-to-sequence pre-training
  • Combined: More robust and diverse representation of the peptide

Project Structure

src/
├── embeddings.py   # Extract ESM-2 and ProtT5 embeddings
├── model.py        # Neural network architecture
├── dataset.py      # Data loading and preprocessing
└── main.py         # Training pipeline

File Descriptions

embeddings.py

Extracts pre-trained embeddings from pLM models.

extractor = EmbeddingExtractor()
esm_emb = extractor.extract_esm(sequences)      # List of (seq_len, 480) tensors
prostt5_emb = extractor.extract_prostt5(sequences)  # List of (seq_len, 1024) tensors

Key functions:

  • load_esm(): Loads ESM-2 model (esm2_t12_35M_UR50D) from Facebook's library
  • load_prostt5(): Loads ProtT5 (prot_t5_xl_uniref50) from HuggingFace
  • pad_embeddings(): Pads variable-length sequences to fixed size
  • precompute_and_save(): CLI to precompute embeddings once
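A hedged sketch of what `pad_embeddings()` might look like, assuming it also returns a boolean mask marking the real (non-padded) positions (the repository's actual signature may differ):

```python
import torch

def pad_embeddings(embeddings, max_len=None):
    """Pad a list of (seq_len, dim) tensors to one (n, max_len, dim) tensor.

    Also returns a (n, max_len) boolean mask: True at real positions,
    False at padding, so downstream pooling can ignore padded residues.
    """
    max_len = max_len or max(e.shape[0] for e in embeddings)
    dim = embeddings[0].shape[1]
    padded = torch.zeros(len(embeddings), max_len, dim)
    mask = torch.zeros(len(embeddings), max_len, dtype=torch.bool)
    for i, emb in enumerate(embeddings):
        n = min(emb.shape[0], max_len)
        padded[i, :n] = emb[:n]   # copy (and truncate if longer than max_len)
        mask[i, :n] = True
    return padded, mask
```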

model.py

Defines the neural network architecture.

Classes:

  • PositionalEncoding: Adds position information (Transformers have no inherent order)
  • TransformerEncoder: Self-attention layers to model residue interactions
  • FocalLoss: Loss function for class imbalance (down-weights easy examples)
  • ToxiPepModel: Main model combining all components
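A minimal sketch of the `PositionalEncoding` component, assuming the standard sinusoidal scheme from the original Transformer paper (the repository's exact variant may differ):

```python
import math
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    """Add sinusoidal position information to (batch, seq_len, d_model) inputs."""

    def __init__(self, d_model, max_len=512):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(max_len).unsqueeze(1).float()
        div = torch.exp(torch.arange(0, d_model, 2).float()
                        * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)   # even dimensions
        pe[:, 1::2] = torch.cos(pos * div)   # odd dimensions
        self.register_buffer("pe", pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        # Truncate the table to the actual sequence length and add it.
        return x + self.pe[:, : x.size(1)]
```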

Model forward pass:

# 1. Concatenate embeddings
combined = concat([esm, prostt5], dim=-1)  # (batch, seq_len, 1504)

# 2. Project to lower dimension
x = Linear(1504, 512)(combined)

# 3. Add positional encoding
x = x + positional_encoding

# 4. Transformer layers (self-attention)
x = transformer_encoder(x)

# 5. Pool over sequence (mean)
x = mean(x, dim=1)  # (batch, 512)

# 6. Classify
output = MLP(x)  # (batch, 2)

dataset.py

Handles data loading and preprocessing.

Key functions:

  • load_sequences_from_fasta(): Parse FASTA files, extract sequences and labels
  • EmbeddingDataset: PyTorch Dataset for pre-computed embeddings (ESM: 480D, ProtT5: 1024D)
  • create_data_splits(): Split data into train/val/test with stratification
  • get_class_weights(): Compute weights for imbalanced classes

Why stratified splits?

Without stratification:
  Train: 90% negative, 10% positive
  Test:  70% negative, 30% positive  ← Different distribution!

With stratification:
  Train: 80% negative, 20% positive
  Test:  80% negative, 20% positive  ← Same distribution
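The stratified 80/10/10 split can be sketched with scikit-learn's `train_test_split` (a sketch only; the exact signature of `create_data_splits()` in the repo may differ):

```python
from sklearn.model_selection import train_test_split

def create_data_splits(indices, labels, seed=42):
    """Stratified 80/10/10 split: each split keeps the class ratio of `labels`."""
    # First carve off 20% as a temporary pool, preserving class proportions.
    train_idx, tmp_idx, _, tmp_y = train_test_split(
        indices, labels, test_size=0.2, stratify=labels, random_state=seed)
    # Then split that pool in half into validation and test, again stratified.
    val_idx, test_idx, _, _ = train_test_split(
        tmp_idx, tmp_y, test_size=0.5, stratify=tmp_y, random_state=seed)
    return train_idx, val_idx, test_idx
```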

main.py

Training pipeline with proper methodology.

Training flow:

1. Load pre-computed embeddings
2. Split: 80% train, 10% val, 10% test
3. For each epoch:
   - Train on train_loader
   - Evaluate on val_loader
   - Save model if val improves
   - Early stopping if no improvement
4. Final evaluation on test_loader (once!)
5. Print metrics: ACC, Sensitivity, Specificity, MCC, AUC
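The training flow above, sketched as a generic loop (assumes batches of `(inputs, labels)` and a single-input model; the repository's actual loop will differ in details such as metric logging and checkpoint paths):

```python
import torch

def train_with_early_stopping(model, train_loader, val_loader, loss_fn,
                              epochs=100, patience=10, lr=1e-4):
    """Keep the best-validation weights; stop after `patience` epochs without improvement."""
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    best_val, bad_epochs, best_state = float("inf"), 0, None
    for epoch in range(epochs):
        model.train()
        for xb, yb in train_loader:
            opt.zero_grad()
            loss_fn(model(xb), yb).backward()
            opt.step()
        model.eval()
        with torch.no_grad():
            val = sum(loss_fn(model(xb), yb).item() for xb, yb in val_loader)
        if val < best_val:                 # validation improved: snapshot weights
            best_val, bad_epochs = val, 0
            best_state = {k: v.clone() for k, v in model.state_dict().items()}
        else:                              # no improvement this epoch
            bad_epochs += 1
            if bad_epochs >= patience:
                break
    model.load_state_dict(best_state)      # restore the best checkpoint
    return model
```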

Why validation set?

  • Original code only had train/test
  • Problem: If you tune hyperparameters based on test performance, you're "cheating"
  • Solution: Use validation set for tuning, test set only for final evaluation

Usage

Step 1: Install Dependencies

pip install torch fair-esm transformers scikit-learn

Step 2: Precompute Embeddings

python src/embeddings.py -i data/train.fasta -o data/train_embeddings.pt

This step is slow (requires running pLMs) but only needs to be done once.

Step 3: Train Model

python src/main.py --data data/train_embeddings.pt --epochs 100

Arguments:

Argument      Default   Description
--data        required  Path to embeddings file
--batch_size  32        Batch size
--epochs      100       Maximum epochs
--lr          1e-4      Learning rate
--patience    10        Early stopping patience
--loss        focal     Loss function (focal or ce)

Step 4: Results

The script outputs:

  • best_model.pth: Trained model weights
  • best_model_results.pt: Predictions and metrics

Handling Class Imbalance

Peptide datasets are often imbalanced (more non-toxic than toxic).

Problem

Dataset: 900 non-toxic, 100 toxic
Naive model: Predict all non-toxic → 90% accuracy!
But: 0% sensitivity (misses all toxic peptides)

Solutions Implemented

1. Focal Loss

FL = -α(1-pt)^γ * log(pt)

- pt high (easy example) → (1-pt)^γ small → low loss
- pt low (hard example) → (1-pt)^γ large → high loss

Forces model to focus on hard-to-classify examples.
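A minimal sketch of a focal-loss module implementing the formula above, built on cross-entropy since -log(pt) is exactly the per-sample cross-entropy (a sketch; the repository's `FocalLoss` may differ in defaults):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    """FL = -alpha * (1 - pt)^gamma * log(pt): down-weights easy examples."""

    def __init__(self, alpha=1.0, gamma=2.0):
        super().__init__()
        self.alpha, self.gamma = alpha, gamma

    def forward(self, logits, targets):
        ce = F.cross_entropy(logits, targets, reduction="none")  # -log(pt) per sample
        pt = torch.exp(-ce)                                      # recover pt
        return (self.alpha * (1 - pt) ** self.gamma * ce).mean()
```

With gamma = 0 the modulating factor is 1, so the loss reduces to ordinary cross-entropy; raising gamma shifts weight toward hard examples.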

2. Class Weights

weight[class] = total_samples / (n_classes × count[class])

Example:
- weight_negative = 1000 / (2 × 900) = 0.56
- weight_positive = 1000 / (2 × 100) = 5.0

Positive class contributes 9× more to the loss.
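The weight formula above as a small helper (a sketch; the repo's `get_class_weights()` may differ):

```python
from collections import Counter

def get_class_weights(labels, n_classes=2):
    """weight[c] = total_samples / (n_classes * count[c]); rarer classes weigh more."""
    counts = Counter(labels)
    total = len(labels)
    return [total / (n_classes * counts[c]) for c in range(n_classes)]
```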

Evaluation Metrics

Metric       Formula                                           Interpretation
Accuracy     (TP+TN)/Total                                     Overall correctness
Sensitivity  TP/(TP+FN)                                        How well we detect toxic
Specificity  TN/(TN+FP)                                        How well we detect non-toxic
MCC          (TP·TN−FP·FN)/√((TP+FP)(TP+FN)(TN+FP)(TN+FN))     Balanced metric (-1 to +1)
AUC          Area under ROC curve                              Ranking quality

Why MCC?

  • Accuracy is misleading with imbalanced data
  • MCC considers all four confusion matrix values
  • MCC = 0: random predictions
  • MCC = 1: perfect predictions
  • MCC = -1: completely wrong
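The five reported metrics can be computed from scikit-learn primitives; a sketch (the function name `evaluate` is hypothetical):

```python
from sklearn.metrics import confusion_matrix, matthews_corrcoef, roc_auc_score

def evaluate(y_true, y_pred, y_score):
    """Compute ACC, sensitivity, specificity, MCC, and AUC from labels,
    hard predictions, and positive-class scores."""
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
    return {
        "accuracy":    (tp + tn) / (tp + tn + fp + fn),
        "sensitivity": tp / (tp + fn),   # recall on the toxic class
        "specificity": tn / (tn + fp),   # recall on the non-toxic class
        "mcc":         matthews_corrcoef(y_true, y_pred),
        "auc":         roc_auc_score(y_true, y_score),
    }
```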

Avoiding Data Leakage

Data leakage: When information from test set influences training.

Common Mistakes (Avoided Here)

  1. ❌ Computing class weights on full dataset → ✅ Compute only on train set
  2. ❌ Tuning hyperparameters on test set → ✅ Use validation set
  3. ❌ Evaluating on test every epoch → ✅ Only at the end
  4. ❌ Random split without stratification → ✅ Stratified split

Requirements

  • Python 3.8+
  • PyTorch 2.0+
  • fair-esm (for ESM-2)
  • transformers (for ProtT5)
  • scikit-learn (for metrics and splits)
