TRAFICA: An Open Chromatin Language Model to Improve Transcription Factor Binding Affinity Prediction
- OS: Linux
- NVIDIA GPU (CUDA support is required):
Pre-training TRAFICA with base-level tokenization, which yields the largest number of tokens per sequence, took about 5.8 days on a single NVIDIA A100 GPU.
- Python and other dependencies: environment.yaml
- Create a new conda environment from the provided environment.yaml file:
conda env create -f environment.yaml
- Activate the conda environment:
conda activate TRAFICA
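After activating the environment, a quick check can confirm that the packages imported in the usage examples are available and that PyTorch can see a CUDA GPU. This is a minimal sketch; the package list (torch, transformers, peft) is an assumption taken from the imports in this README, not from environment.yaml, and find_missing is a hypothetical helper.

```python
import importlib.util

# ASSUMPTION: packages taken from the imports in this README's examples,
# not read from environment.yaml.
REQUIRED_PACKAGES = ["torch", "transformers", "peft"]


def find_missing(packages):
    """Return the top-level package names that cannot be found in this interpreter."""
    return [name for name in packages if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = find_missing(REQUIRED_PACKAGES)
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        import torch  # safe: torch was found above

        # False here means PyTorch cannot use a CUDA GPU in this environment.
        print("CUDA available:", torch.cuda.is_available())
```

Run it inside the activated TRAFICA environment; an empty "missing" list plus "CUDA available: True" suggests the setup is ready for the steps that follow.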
- Load the pre-trained model and tokenizer through the Hugging Face interface:
from transformers import AutoTokenizer, AutoConfig, AutoModelForSequenceClassification
# configuration dict for different TRAFICA versions
config_dict = {
'TRAFICA (BPE-1)': {'model_path':'Allanxu/TRAFICA-BPE1',
'tokenizer_path':'Allanxu/TRAFICA-BPE1',
'tokenization':'BPE'
},
'TRAFICA (BPE-2)': {'model_path':'Allanxu/TRAFICA-BPE2',
'tokenizer_path':'zhihan1996/DNABERT-2-117M',
'tokenization':'BPE'
},
'TRAFICA (4-mer)': {'model_path':'Allanxu/TRAFICA-4_mer',
'tokenizer_path':'Allanxu/TRAFICA-4_mer',
'tokenization':'4_mer'
},
'TRAFICA (5-mer)': {'model_path':'Allanxu/TRAFICA-5_mer',
'tokenizer_path':'Allanxu/TRAFICA-5_mer',
'tokenization':'5_mer'
},
'TRAFICA (6-mer)': {'model_path':'Allanxu/TRAFICA-6_mer',
'tokenizer_path':'Allanxu/TRAFICA-6_mer',
'tokenization':'6_mer'
},
'TRAFICA (base-level)': {'model_path':'Allanxu/TRAFICA-Base_level',
'tokenizer_path':'Allanxu/TRAFICA-Base_level',
'tokenization':'Base-level'
}
}
# tokenizer
Tokenizer = AutoTokenizer.from_pretrained(config_dict['TRAFICA (base-level)']['tokenizer_path'], trust_remote_code=True)
# model
config = AutoConfig.from_pretrained(config_dict['TRAFICA (base-level)']['model_path'])
config.num_labels = 1  # a single regression output: the predicted relative binding affinity
model = AutoModelForSequenceClassification.from_pretrained(config_dict['TRAFICA (base-level)']['model_path'], config=config, trust_remote_code=True)
- Load the fine-tuned LoRA module and affinity predictor for a specific TF:
import os
from peft import PeftModel
import torch
lora_path = '/<Path of fine-tuned LoRA>/Base-level/PRJEB3289/10000/ATF7_TGGGCG30NCGT' # example for TF ATF7
# LoRA and Affinity predictor
state_dict = torch.load(os.path.join(lora_path,"predict_head_weights.pth"), weights_only=True)
model.classifier.load_state_dict( state_dict['PREDICT_HEAD'] )
model = PeftModel.from_pretrained(model, os.path.join(lora_path,"lora_adapter"))
- Fine-tuned TF LoRAs: available on Hugging Face (large download)
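Because the fine-tuned TF modules are a large download, it can help to check a directory before passing it as lora_path. The sketch below assumes only what the loading code reads: each TF directory holds a predict_head_weights.pth file and a lora_adapter/ folder. check_lora_dir and find_lora_dirs are hypothetical helpers, not part of the TRAFICA code base.

```python
import os
import tempfile  # only used by the demo below

# Entry names read by the LoRA loading example in this README.
PREDICT_HEAD_FILE = "predict_head_weights.pth"
LORA_ADAPTER_DIR = "lora_adapter"


def check_lora_dir(lora_path):
    """Return a list of problems with a fine-tuned TF directory (empty if it looks complete)."""
    if not os.path.isdir(lora_path):
        return [f"not a directory: {lora_path}"]
    problems = []
    if not os.path.isfile(os.path.join(lora_path, PREDICT_HEAD_FILE)):
        problems.append(f"missing {PREDICT_HEAD_FILE}")
    if not os.path.isdir(os.path.join(lora_path, LORA_ADAPTER_DIR)):
        problems.append(f"missing {LORA_ADAPTER_DIR}/")
    return problems


def find_lora_dirs(root):
    """Return every directory under root that contains both expected entries, sorted."""
    return sorted(dirpath for dirpath, _, _ in os.walk(root) if not check_lora_dir(dirpath))


if __name__ == "__main__":
    # Demo on a throwaway directory that mimics the example lora_path in this README.
    with tempfile.TemporaryDirectory() as root:
        tf_dir = os.path.join(root, "Base-level", "PRJEB3289", "10000", "ATF7_TGGGCG30NCGT")
        os.makedirs(os.path.join(tf_dir, LORA_ADAPTER_DIR))
        open(os.path.join(tf_dir, PREDICT_HEAD_FILE), "wb").close()
        print(find_lora_dirs(root))
```

Pointing find_lora_dirs at the root of the downloaded modules lists the directories that can be used as lora_path; check_lora_dir explains why a given path would fail to load.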
- Make predictions:
from util import piece_sequences  # defined in Src/util.py; run from Src/ or add Src/ to PYTHONPATH
# Input construction
sequences = ['CCAGAAGACAACTTGTAGAAATAAGCAAAA', 'ATTGCGCCCCAGCCCCACACCCACACGCAT']
tokens_batch = piece_sequences(sequences, config_dict['TRAFICA (base-level)']['tokenization'])
# tokens_batch = ['C C A G A A G A C A A C T T G T A G A A A T A A G C A A A A', 'A T T G C G C C C C A G C C C C A C A C C C A C A C G C A T']
inputs = Tokenizer(tokens_batch, return_tensors="pt", padding=True)
# Prediction
with torch.no_grad():
    outputs = model(**inputs)
    logit = outputs.logits
print(f"Predicted relative affinities: {logit.flatten()}")
- Details of TRAFICA pre-training and LoRA fine-tuning
- HT-SELEX Benchmark: Zenodo
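The prediction example builds its input with piece_sequences from Src/util.py, and its comment shows the base-level output: one space-separated token per nucleotide. The sketch below is a hypothetical stand-in, not the repository's implementation. Its base-level branch reproduces that comment; its k-mer branch assumes overlapping k-mers with stride 1 (the DNABERT convention), which this README does not confirm. It also illustrates why base-level tokenization gives the most tokens: a length-L sequence yields L tokens, versus L - k + 1 for overlapping k-mers.

```python
# Hypothetical stand-in for piece_sequences (Src/util.py); NOT the repository's code.
# The base-level output matches the example comment in this README; the k-mer branch
# ASSUMES overlapping k-mers with stride 1, which the README does not confirm.


def piece_sequences_sketch(sequences, tokenization):
    """Split DNA sequences into space-separated tokens for a TRAFICA tokenizer."""
    if tokenization == "Base-level":
        # One token per nucleotide: 'ACGT' -> 'A C G T'
        return [" ".join(seq) for seq in sequences]
    if tokenization.endswith("_mer"):
        k = int(tokenization.split("_")[0])  # '6_mer' -> 6
        # Overlapping k-mers (assumption): 'ACGTA' with k=3 -> 'ACG CGT GTA'
        return [
            " ".join(seq[i:i + k] for i in range(len(seq) - k + 1))
            for seq in sequences
        ]
    raise ValueError(f"tokenization {tokenization!r} is not covered by this sketch")


if __name__ == "__main__":
    seq = "CCAGAAGACAACTTGTAGAAATAAGCAAAA"  # 30 bp example sequence from this README
    for scheme in ["Base-level", "4_mer", "5_mer", "6_mer"]:
        n_tokens = len(piece_sequences_sketch([seq], scheme)[0].split())
        print(f"{scheme}: {n_tokens} tokens")
```

For the 30-bp example sequence this prints 30, 27, 26 and 25 tokens respectively. BPE tokenization is deliberately left out, since how the BPE variants prepare their input is not described here.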
Mr. Yu Xu, Email: csyuxu@comp.hkbu.edu.hk; allanxu20@gmail.com
Prof. Eric Lu Zhang, Email: ericluzhang@hkbu.edu.hk
