Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions fridata.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,13 @@ def add_embedder_argument(parser, required=True):
)


def _parse_comma_separated_pdb_names(value: str) -> list[str]:
names = [s.strip() for s in value.split(",") if s.strip()]
if not names:
raise argparse.ArgumentTypeError("--name requires at least one non-empty PDB key")
return names


def configure_logging(verbose, log_file=None):
"""Configure logging based on verbose flag and optional log file"""
log_level = logging.DEBUG if verbose else logging.INFO
Expand Down Expand Up @@ -234,7 +241,18 @@ def create_parser():
"--mode",
choices=["structure", "content", "keys"],
default="structure",
help="Display mode: structure (groups/datasets/shapes), content (PDB text via read_all_pdbs_from_h5), or keys (protein codes only)",
help="Display mode: structure (groups/datasets/shapes), content (PDB text via read_pdbs_from_h5), or keys (protein codes only)",
)
inspect_h5_parser.add_argument(
"--name",
dest="pdb_names",
type=_parse_comma_separated_pdb_names,
default=None,
metavar="KEYS",
help=(
"Comma-separated PDB keys to show (content mode only; must match keys in the H5). "
"Use inspect_h5 --mode keys to list keys."
),
)

inspect_idx_parser = subparsers.add_parser(
Expand Down Expand Up @@ -307,7 +325,11 @@ def create_parser():
def main():
parser = create_parser()
args = parser.parse_args()


if args.command == "inspect_h5" and getattr(args, "pdb_names", None) is not None:
if args.mode != "content":
parser.error("--name is only valid with --mode content")

# Load config and raise if not found
from toolbox.config import load_config
try:
Expand Down
55 changes: 39 additions & 16 deletions toolbox/models/embedding/embedder/embedder_type.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,44 @@
from __future__ import annotations

from enum import Enum
from toolbox.models.embedding.embedder.esm2_embedder import ESM2Embedder
from toolbox.models.embedding.embedder.esmc_embedder import ESMCEmbedder
from toolbox.models.embedding.embedder.glm2_embedder import GLM2Embedder
from toolbox.models.embedding.embedder.base_embedder import BaseEmbedder
from typing import TYPE_CHECKING, Type

if TYPE_CHECKING:
from toolbox.models.embedding.embedder.base_embedder import BaseEmbedder


def _lazy_embedder_class(member: "EmbedderType") -> Type[BaseEmbedder]:
name = member.name
if name.startswith("ESM2"):
from toolbox.models.embedding.embedder.esm2_embedder import ESM2Embedder

return ESM2Embedder
if name.startswith("ESMC"):
from toolbox.models.embedding.embedder.esmc_embedder import ESMCEmbedder

return ESMCEmbedder
if name.startswith("GLM2"):
from toolbox.models.embedding.embedder.glm2_embedder import GLM2Embedder

return GLM2Embedder
raise ValueError(f"Unknown embedder kind for {name!r}")


class EmbedderType(Enum):
ESM2_T30_150M = ("esm2_t30_150M_UR50D", ESM2Embedder, 640)
ESM2_T33_650M = ("esm2_t33_650M_UR50D", ESM2Embedder, 1280)
ESMC_300M = ("esmc_300m", ESMCEmbedder, 960)
ESMC_600M = ("esmc_600m", ESMCEmbedder, 1152)
GLM2_150M = ("gLM2_150M", GLM2Embedder, 640)
GLM2_650M = ("gLM2_650M", GLM2Embedder, 1280)

def __init__(self, value, embedder_class: type[BaseEmbedder], embedding_size: int):
self._value_ = value
self.embedder_class: type[BaseEmbedder] = embedder_class
self.embedding_size: int = embedding_size
ESM2_T30_150M = ("esm2_t30_150M_UR50D", 640)
ESM2_T33_650M = ("esm2_t33_650M_UR50D", 1280)
ESMC_300M = ("esmc_300m", 960)
ESMC_600M = ("esmc_600m", 1152)
GLM2_150M = ("gLM2_150M", 640)
GLM2_650M = ("gLM2_650M", 1280)

def __init__(self, model_id: str, embedding_size: int):
self._value_ = model_id
self.embedding_size = embedding_size

@property
def embedder_class(self) -> Type[BaseEmbedder]:
return _lazy_embedder_class(self)

def create_embedder(self) -> BaseEmbedder:
return self.embedder_class(model_name=self.value)
return self.embedder_class(model_name=self.value)
3 changes: 2 additions & 1 deletion toolbox/models/manage_dataset/structures_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
chunk
)
from toolbox.models.manage_dataset.distograms.generate_distograms import generate_distograms
from toolbox.models.embedding.embedding import Embedding
from toolbox.models.utils.from_archive import extract_batch_from_archive
from toolbox.models.utils.create_client import create_client, total_workers
from toolbox.utlis.filter_pdb_codes import filter_pdb_codes
Expand Down Expand Up @@ -613,6 +612,8 @@ def generate_distograms(self):

def generate_embeddings(self):
    """Compute embeddings for every structure in this dataset.

    Delegates all work to ``Embedding``, which is constructed with this
    dataset and driven via ``run()``.
    """
    # NOTE(review): imported lazily — the diff moved this import from module
    # level into the method, presumably to break an import cycle or defer
    # heavy dependencies; confirm before moving it back.
    from toolbox.models.embedding.embedding import Embedding

    Embedding(self).run()

Expand Down
31 changes: 19 additions & 12 deletions toolbox/scripts/archive.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@
import os
import zipfile
from pathlib import Path

from dask.distributed import as_completed
from tqdm import tqdm
from tqdm.contrib.logging import logging_redirect_tqdm

from toolbox.models.manage_dataset.utils import read_all_pdbs_from_h5
from toolbox.utlis.logging import logger


def process_h5_file(h5_file, dataset_path, output_dir):
Expand All @@ -17,9 +23,6 @@ def process_h5_file(h5_file, dataset_path, output_dir):
code = p.removesuffix(".pdb")
zipf.writestr(f"{code}.pdb", pdb_file_content)

with open(archive_path / f"{code}.pdb", "w") as f:
f.write(pdb_file_content)

os.system(f"tar -czf {str(archive_path)}.tgz {str(archive_path)}")

return str(archive_path)
Expand All @@ -38,17 +41,21 @@ def create_archive(structures_dataset: "StructuresDataset"):
future = client.submit(process_h5_file, h5_file, dataset_path, output_dir)
futures.append(future)

archive_paths = client.gather(futures)
n = len(futures)
logger.info("Building combined PDB archive from %s H5 shard(s)", n)

# Combine the archives into one archive
current_time = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
final_archive_name = f"archive_pdb_{current_time}.zip"
final_archive_name = f"archive_pdb_{structures_dataset.dataset_dir_name()}_{current_time}.zip"
final_archive_path = Path.cwd() / final_archive_name

with zipfile.ZipFile(final_archive_path, "w") as final_zip:
for idx, archive_path in enumerate(archive_paths):
archive_name_in_final = f"{idx}.zip"
# Read the archive file and write it into the final archive
with open(archive_path, "rb") as f:
archive_data = f.read()
final_zip.writestr(archive_name_in_final, archive_data)
with logging_redirect_tqdm():
with tqdm(total=n, desc="H5 shards → final zip", unit="h5") as pbar:
i = 0
for fut in as_completed(futures):
archive_path = fut.result()
with open(archive_path, "rb") as f:
archive_data = f.read()
final_zip.writestr(f"{i}.zip", archive_data)
i += 1
pbar.update(1)
Loading
Loading