8 changes: 5 additions & 3 deletions toolbox/models/embedding/embedder/base_embedder.py
@@ -25,8 +25,7 @@ def embed(self, sequences, output_path):
         batch_index = 0
         for prot_id, prot_seq in sequences.items():
             embeddings_pure = self.get_embedding(prot_id, prot_seq)
-            # Embeddings now include CLS and EOS tokens, so shape[0] = len(seq) + 2
-            assert len(prot_seq) + 2 == embeddings_pure.shape[0], f'Invalid character in {prot_id}: expected {len(prot_seq) + 2}, got {embeddings_pure.shape[0]}'
+            assert self.validate_embedding(prot_seq, embeddings_pure), f'Invalid embedding shape for {prot_id}: expected shape[0] to match sequence length'
             embeddings_pure_batch[prot_id] = embeddings_pure
             torch.cuda.empty_cache()
             if len(embeddings_pure_batch) >= self.batch_size:
@@ -48,4 +47,7 @@ def embed(self, sequences, output_path):
             p.join()
         for prot_id in ids:
             final_index[prot_id] = str(file_path)
-        return final_index
+        return final_index
+
+    def validate_embedding(self, prot_seq, embeddings):
+        return len(prot_seq) + 2 == embeddings.shape[0]
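This refactor turns the hardcoded "+ 2" shape check into an overridable hook, so embedders whose tokenizers add a different number of special tokens can supply their own rule. A minimal sketch of such an override, using a hypothetical subclass that is not part of this PR:

class PerResidueEmbedder(BaseEmbedder):
    # Hypothetical model that emits exactly one vector per residue,
    # with no CLS/EOS rows, so the base class's len(seq) + 2 rule would fail.
    def validate_embedding(self, prot_seq, embeddings):
        return len(prot_seq) == embeddings.shape[0]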
12 changes: 9 additions & 3 deletions toolbox/models/embedding/embedder/embedder_type.py
@@ -1,15 +1,21 @@
 from enum import Enum
 from toolbox.models.embedding.embedder.esm2_embedder import ESM2Embedder
 from toolbox.models.embedding.embedder.esmc_embedder import ESMCEmbedder
+from toolbox.models.embedding.embedder.glm2_embedder import GLM2Embedder
 from toolbox.models.embedding.embedder.base_embedder import BaseEmbedder
 
 class EmbedderType(Enum):
     ESM2_T30_150M = ("esm2_t30_150M_UR50D", ESM2Embedder, 640)
     ESM2_T33_650M = ("esm2_t33_650M_UR50D", ESM2Embedder, 1280)
     ESMC_300M = ("esmc_300m", ESMCEmbedder, 960)
     ESMC_600M = ("esmc_600m", ESMCEmbedder, 1152)
+    GLM2_150M = ("gLM2_150M", GLM2Embedder, 640)
+    GLM2_650M = ("gLM2_650M", GLM2Embedder, 1280)
 
-    def __init__(self, value, embedder_class: BaseEmbedder, embedding_size: int):
+    def __init__(self, value, embedder_class: type[BaseEmbedder], embedding_size: int):
         self._value_ = value
-        self.embedder_class: BaseEmbedder = embedder_class
-        self.embedding_size: int = embedding_size
+        self.embedder_class: type[BaseEmbedder] = embedder_class
+        self.embedding_size: int = embedding_size
 
+    def create_embedder(self) -> BaseEmbedder:
+        return self.embedder_class(model_name=self.value)
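create_embedder forwards the enum's own value as model_name, so call sites select a checkpoint by picking a member instead of wiring constructor arguments themselves. A usage sketch, assuming the members defined above:

from toolbox.models.embedding.embedder.embedder_type import EmbedderType

# Instantiates GLM2Embedder(model_name="gLM2_150M") under the hood.
embedder = EmbedderType.GLM2_150M.create_embedder()
print(EmbedderType.GLM2_150M.embedding_size)  # 640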
4 changes: 2 additions & 2 deletions toolbox/models/embedding/embedder/esmc_embedder.py
@@ -14,9 +14,9 @@
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 class ESMCEmbedder(BaseEmbedder):
-    def __init__(self, device=None, batch_size=1000):
+    def __init__(self, device=None, batch_size=1000, model_name="esmc_600m"):
         super().__init__(device, batch_size)
-        self.model = ESMC.from_pretrained("esmc_600m").to(self.device)
+        self.model = ESMC.from_pretrained(model_name).to(self.device)
 
     def get_embedding(self, prot_id, prot_seq):
         protein = ESMProtein(sequence=prot_seq)
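Threading model_name through the constructor also closes a gap visible in the enum: before this change, ESMC_300M's embedder class would still load the hardcoded "esmc_600m" weights. A sketch of the corrected behavior, assuming the factory above:

# The enum value now selects the checkpoint that is actually loaded.
embedder = EmbedderType.ESMC_300M.create_embedder()
# Equivalent to: ESMCEmbedder(model_name="esmc_300m")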
35 changes: 35 additions & 0 deletions toolbox/models/embedding/embedder/glm2_embedder.py
@@ -0,0 +1,35 @@
+import torch
+from transformers import AutoModel, AutoTokenizer
+from .base_embedder import BaseEmbedder
+
+PREP_SIGN = "<+>"
+
+
+class GLM2Embedder(BaseEmbedder):
+    def __init__(self, device=None, batch_size=1000, model_name="gLM2_150M"):
+        super().__init__(device, batch_size)
+        self.model_name = model_name
+        use_bfloat16 = self.device.type == "cuda"
+        dtype = torch.bfloat16 if use_bfloat16 else torch.float32
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            f"tattabio/{model_name}", trust_remote_code=True
+        )
+        self.model = (
+            AutoModel.from_pretrained(
+                f"tattabio/{model_name}",
+                torch_dtype=dtype,
+                trust_remote_code=True,
+            )
+            .to(self.device)
+        )
+
+    def get_embedding(self, prot_id, prot_seq):
+        sequence = PREP_SIGN + prot_seq
+        inputs = self.tokenizer([sequence], return_tensors="pt")
+        inputs = {k: v.to(self.device) for k, v in inputs.items()}
+        outputs = self.model(inputs["input_ids"], output_hidden_states=True)
+        embeddings = outputs.last_hidden_state[0]
+        return embeddings.to("cpu").detach().to(torch.float32).numpy()
+
+    def validate_embedding(self, prot_seq, embeddings):
+        return len(prot_seq) + 1 == embeddings.shape[0]
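Because get_embedding prepends the single PREP_SIGN token ("<+>") to the sequence, each embedding carries one extra row, which is what the len(prot_seq) + 1 override encodes. A usage sketch with a placeholder ID and sequence:

# Placeholder inputs for illustration only.
embedder = GLM2Embedder(model_name="gLM2_150M")
emb = embedder.get_embedding("P12345", "MKTAYIAKQR")
assert emb.shape[0] == len("MKTAYIAKQR") + 1  # one extra row for the "<+>" prefix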
6 changes: 3 additions & 3 deletions toolbox/models/embedding/embedding.py
@@ -66,9 +66,9 @@ def run(self):
         if self.structures_dataset.embedder_type is None:
             self.structures_dataset.embedder_type = EmbedderType.ESM2_T33_650M
 
-        # Get the embedder class and create an instance
-        embedder_class = self.structures_dataset.embedder_type.embedder_class
-        index_of_new_embeddings = embedder_class().embed(sequences, self.outputs_dir)
+        # Get the embedder instance and run embedding
+        embedder = self.structures_dataset.embedder_type.create_embedder()
+        index_of_new_embeddings = embedder.embed(sequences, self.outputs_dir)
 
         present_embeddings.update(index_of_new_embeddings)
 
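With the factory in place, the run() flow reduces to: resolve a default member, build its embedder, embed. A condensed sketch of that flow, with dataset, sequences, and outputs_dir standing in for the real attributes:

# Placeholder names; mirrors the updated run() logic above.
embedder_type = dataset.embedder_type or EmbedderType.ESM2_T33_650M
embedder = embedder_type.create_embedder()
index_of_new_embeddings = embedder.embed(sequences, outputs_dir)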