diff --git a/toolbox/models/embedding/embedder/base_embedder.py b/toolbox/models/embedding/embedder/base_embedder.py
index 3f5d8fd..e59153d 100644
--- a/toolbox/models/embedding/embedder/base_embedder.py
+++ b/toolbox/models/embedding/embedder/base_embedder.py
@@ -25,8 +25,7 @@ def embed(self, sequences, output_path):
         batch_index = 0
         for prot_id, prot_seq in sequences.items():
             embeddings_pure = self.get_embedding(prot_id, prot_seq)
-            # Embeddings now include CLS and EOS tokens, so shape[0] = len(seq) + 2
-            assert len(prot_seq) + 2 == embeddings_pure.shape[0], f'Invalid character in {prot_id}: expected {len(prot_seq) + 2}, got {embeddings_pure.shape[0]}'
+            assert self.validate_embedding(prot_seq, embeddings_pure), f'Invalid embedding shape for {prot_id}: expected shape[0] to match sequence length'
             embeddings_pure_batch[prot_id] = embeddings_pure
             torch.cuda.empty_cache()
             if len(embeddings_pure_batch) >= self.batch_size:
@@ -48,4 +47,7 @@ def embed(self, sequences, output_path):
            p.join()
        for prot_id in ids:
            final_index[prot_id] = str(file_path)
-        return final_index
\ No newline at end of file
+        return final_index
+
+    def validate_embedding(self, prot_seq, embeddings):
+        return len(prot_seq) + 2 == embeddings.shape[0]
\ No newline at end of file
diff --git a/toolbox/models/embedding/embedder/embedder_type.py b/toolbox/models/embedding/embedder/embedder_type.py
index 276612d..d89e7b5 100644
--- a/toolbox/models/embedding/embedder/embedder_type.py
+++ b/toolbox/models/embedding/embedder/embedder_type.py
@@ -1,6 +1,7 @@
 from enum import Enum
 from toolbox.models.embedding.embedder.esm2_embedder import ESM2Embedder
 from toolbox.models.embedding.embedder.esmc_embedder import ESMCEmbedder
+from toolbox.models.embedding.embedder.glm2_embedder import GLM2Embedder
 from toolbox.models.embedding.embedder.base_embedder import BaseEmbedder
 
 class EmbedderType(Enum):
@@ -8,8 +9,13 @@ class EmbedderType(Enum):
     ESM2_T33_650M = ("esm2_t33_650M_UR50D", ESM2Embedder, 1280)
     ESMC_300M = ("esmc_300m", ESMCEmbedder, 960)
     ESMC_600M = ("esmc_600m", ESMCEmbedder, 1152)
+    GLM2_150M = ("gLM2_150M", GLM2Embedder, 640)
+    GLM2_650M = ("gLM2_650M", GLM2Embedder, 1280)
 
-    def __init__(self, value, embedder_class: BaseEmbedder, embedding_size: int):
+    def __init__(self, value, embedder_class: type[BaseEmbedder], embedding_size: int):
         self._value_ = value
-        self.embedder_class: BaseEmbedder = embedder_class
-        self.embedding_size: int = embedding_size
\ No newline at end of file
+        self.embedder_class: type[BaseEmbedder] = embedder_class
+        self.embedding_size: int = embedding_size
+
+    def create_embedder(self) -> BaseEmbedder:
+        return self.embedder_class(model_name=self.value)
\ No newline at end of file
diff --git a/toolbox/models/embedding/embedder/esmc_embedder.py b/toolbox/models/embedding/embedder/esmc_embedder.py
index 99212af..6329324 100644
--- a/toolbox/models/embedding/embedder/esmc_embedder.py
+++ b/toolbox/models/embedding/embedder/esmc_embedder.py
@@ -14,9 +14,9 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 
 class ESMCEmbedder(BaseEmbedder):
-    def __init__(self, device=None, batch_size=1000):
+    def __init__(self, device=None, batch_size=1000, model_name="esmc_600m"):
         super().__init__(device, batch_size)
-        self.model = ESMC.from_pretrained("esmc_600m").to(self.device)
+        self.model = ESMC.from_pretrained(model_name).to(self.device)
 
     def get_embedding(self, prot_id, prot_seq):
         protein = ESMProtein(sequence=prot_seq)
diff --git a/toolbox/models/embedding/embedder/glm2_embedder.py b/toolbox/models/embedding/embedder/glm2_embedder.py
new file mode 100644
index 0000000..a369c0b
--- /dev/null
+++ b/toolbox/models/embedding/embedder/glm2_embedder.py
@@ -0,0 +1,35 @@
+import torch
+from transformers import AutoModel, AutoTokenizer
+from .base_embedder import BaseEmbedder
+
+PREP_SIGN = "<+>"
+
+
+class GLM2Embedder(BaseEmbedder):
+    def __init__(self, device=None, batch_size=1000, model_name="gLM2_150M"):
+        super().__init__(device, batch_size)
+        self.model_name = model_name
+        use_bfloat16 = self.device.type == "cuda"
+        dtype = torch.bfloat16 if use_bfloat16 else torch.float32
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            f"tattabio/{model_name}", trust_remote_code=True
+        )
+        self.model = (
+            AutoModel.from_pretrained(
+                f"tattabio/{model_name}",
+                torch_dtype=dtype,
+                trust_remote_code=True,
+            )
+            .to(self.device)
+        )
+
+    def get_embedding(self, prot_id, prot_seq):
+        sequence = PREP_SIGN + prot_seq
+        inputs = self.tokenizer([sequence], return_tensors="pt")
+        inputs = {k: v.to(self.device) for k, v in inputs.items()}
+        outputs = self.model(inputs["input_ids"], output_hidden_states=True)
+        embeddings = outputs.last_hidden_state[0]
+        return embeddings.to("cpu").detach().to(torch.float32).numpy()
+
+    def validate_embedding(self, prot_seq, embeddings):
+        return len(prot_seq) + 1 == embeddings.shape[0]
diff --git a/toolbox/models/embedding/embedding.py b/toolbox/models/embedding/embedding.py
index 0a9aac1..cde4812 100644
--- a/toolbox/models/embedding/embedding.py
+++ b/toolbox/models/embedding/embedding.py
@@ -66,9 +66,9 @@ def run(self):
         if self.structures_dataset.embedder_type is None:
             self.structures_dataset.embedder_type = EmbedderType.ESM2_T33_650M
 
-        # Get the embedder class and create an instance
-        embedder_class = self.structures_dataset.embedder_type.embedder_class
-        index_of_new_embeddings = embedder_class().embed(sequences, self.outputs_dir)
+        # Get the embedder instance and run embedding
+        embedder = self.structures_dataset.embedder_type.create_embedder()
+        index_of_new_embeddings = embedder.embed(sequences, self.outputs_dir)
 
         present_embeddings.update(index_of_new_embeddings)
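
Usage note: a minimal sketch of the new factory path, assuming a hypothetical sequences dict (protein IDs mapped to amino-acid strings) and an outputs_dir path; both names are made up here for illustration. Note that GLM2Embedder overrides validate_embedding to expect len(seq) + 1 rows, since gLM2 prepends only the <+> token, while the base class expects len(seq) + 2 to account for the CLS and EOS tokens.

    from pathlib import Path

    from toolbox.models.embedding.embedder.embedder_type import EmbedderType

    # Hypothetical inputs for illustration only.
    sequences = {"P12345": "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"}
    outputs_dir = Path("embeddings_out")

    # create_embedder() instantiates the enum member's embedder class,
    # forwarding the enum value ("gLM2_150M") as model_name.
    embedder = EmbedderType.GLM2_150M.create_embedder()

    # embed() writes batched embeddings under outputs_dir and returns an
    # index mapping each protein ID to the file holding its embedding.
    index_of_new_embeddings = embedder.embed(sequences, outputs_dir)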