From 868fbf608c167c6d30b0d06cb940747643844873 Mon Sep 17 00:00:00 2001 From: Adam Nowak Date: Sun, 22 Mar 2026 22:54:47 +0000 Subject: [PATCH 1/4] h5 parse content with names --- fridata.py | 26 +++++++++++++++++++-- toolbox/scripts/command_parser.py | 1 + toolbox/utlis/inspect_h5.py | 38 ++++++++++++++++++++++++------- 3 files changed, 55 insertions(+), 10 deletions(-) diff --git a/fridata.py b/fridata.py index f9a73a9..845b233 100644 --- a/fridata.py +++ b/fridata.py @@ -115,6 +115,13 @@ def add_embedder_argument(parser, required=True): ) +def _parse_comma_separated_pdb_names(value: str) -> list[str]: + names = [s.strip() for s in value.split(",") if s.strip()] + if not names: + raise argparse.ArgumentTypeError("--name requires at least one non-empty PDB key") + return names + + def configure_logging(verbose, log_file=None): """Configure logging based on verbose flag and optional log file""" log_level = logging.DEBUG if verbose else logging.INFO @@ -234,7 +241,18 @@ def create_parser(): "--mode", choices=["structure", "content", "keys"], default="structure", - help="Display mode: structure (groups/datasets/shapes), content (PDB text via read_all_pdbs_from_h5), or keys (protein codes only)", + help="Display mode: structure (groups/datasets/shapes), content (PDB text via read_pdbs_from_h5), or keys (protein codes only)", + ) + inspect_h5_parser.add_argument( + "--name", + dest="pdb_names", + type=_parse_comma_separated_pdb_names, + default=None, + metavar="KEYS", + help=( + "Comma-separated PDB keys to show (content mode only; must match keys in the H5). " + "Use inspect_h5 --mode keys to list keys." + ), ) inspect_idx_parser = subparsers.add_parser( @@ -307,7 +325,11 @@ def create_parser(): def main(): parser = create_parser() args = parser.parse_args() - + + if args.command == "inspect_h5" and getattr(args, "pdb_names", None) is not None: + if args.mode != "content": + parser.error("--name is only valid with --mode content") + # Load config and raise if not found from toolbox.config import load_config try: diff --git a/toolbox/scripts/command_parser.py b/toolbox/scripts/command_parser.py index df6e501..0236e5e 100644 --- a/toolbox/scripts/command_parser.py +++ b/toolbox/scripts/command_parser.py @@ -176,6 +176,7 @@ def inspect_h5(self): inspect_h5( Path(self.args.file), mode=getattr(self.args, "mode", "structure"), + names=getattr(self.args, "pdb_names", None), ) def inspect_idx(self): diff --git a/toolbox/utlis/inspect_h5.py b/toolbox/utlis/inspect_h5.py index 878573f..937ec3f 100644 --- a/toolbox/utlis/inspect_h5.py +++ b/toolbox/utlis/inspect_h5.py @@ -3,10 +3,11 @@ import subprocess import tempfile from pathlib import Path +from typing import Optional, Sequence import h5py -from toolbox.models.manage_dataset.utils import read_all_pdbs_from_h5 +from toolbox.models.manage_dataset.utils import read_pdbs_from_h5 from toolbox.utlis.logging import logger @@ -23,13 +24,19 @@ def _inspect_in_vi(content: str) -> None: Path(tmp_path).unlink(missing_ok=True) -def inspect_h5(path: Path, mode: str = "structure") -> None: +def inspect_h5( + path: Path, + mode: str = "structure", + names: Optional[Sequence[str]] = None, +) -> None: """Read h5 file and display in vi editor. Args: path: Path to the HDF5 file. mode: Display mode - 'structure' (groups/datasets/shapes), 'content' - (PDB text via read_all_pdbs_from_h5), or 'keys' (protein codes only). + (PDB text via read_pdbs_from_h5), or 'keys' (protein codes only). + names: When mode is 'content', optional list of PDB keys to include (exact + match). 
Order in the buffer follows this list. Ignored for other modes. """ if not path.exists(): logger.error(f"File not found: {path}") @@ -55,14 +62,29 @@ def walk(name, obj, prefix=""): content = "\n".join(lines) elif mode == "content": - pdbs = read_all_pdbs_from_h5(str(path)) + codes = list(names) if names else None + pdbs = read_pdbs_from_h5(str(path), codes) if pdbs is None: logger.error("Failed to read h5 file") return - for code, pdb_text in pdbs.items(): - lines.append(f"=== {code} ===") - lines.append(pdb_text) - lines.append("") + if codes: + missing = set(codes) - set(pdbs.keys()) + if missing: + logger.warning( + "PDB keys not found in h5 (skipped): %s", + ", ".join(sorted(missing)), + ) + for code in codes: + if code not in pdbs: + continue + lines.append(f"=== {code} ===") + lines.append(pdbs[code]) + lines.append("") + else: + for code, pdb_text in pdbs.items(): + lines.append(f"=== {code} ===") + lines.append(pdb_text) + lines.append("") content = "\n".join(lines) else: with h5py.File(path, "r") as hf: From fa9e7343cf055ce1ae88241f88c2be5cdb9d631b Mon Sep 17 00:00:00 2001 From: Adam Nowak Date: Sun, 22 Mar 2026 23:27:03 +0000 Subject: [PATCH 2/4] Improve fridata module loading speed - optimize imports in command_parser; optimize embedding class imports --- .../embedding/embedder/embedder_type.py | 55 +++++-- .../manage_dataset/structures_dataset.py | 3 +- toolbox/scripts/command_parser.py | 154 +++++++++++------- 3 files changed, 134 insertions(+), 78 deletions(-) diff --git a/toolbox/models/embedding/embedder/embedder_type.py b/toolbox/models/embedding/embedder/embedder_type.py index d89e7b5..383209f 100644 --- a/toolbox/models/embedding/embedder/embedder_type.py +++ b/toolbox/models/embedding/embedder/embedder_type.py @@ -1,21 +1,44 @@ +from __future__ import annotations + from enum import Enum -from toolbox.models.embedding.embedder.esm2_embedder import ESM2Embedder -from toolbox.models.embedding.embedder.esmc_embedder import ESMCEmbedder -from toolbox.models.embedding.embedder.glm2_embedder import GLM2Embedder -from toolbox.models.embedding.embedder.base_embedder import BaseEmbedder +from typing import TYPE_CHECKING, Type + +if TYPE_CHECKING: + from toolbox.models.embedding.embedder.base_embedder import BaseEmbedder + + +def _lazy_embedder_class(member: "EmbedderType") -> Type[BaseEmbedder]: + name = member.name + if name.startswith("ESM2"): + from toolbox.models.embedding.embedder.esm2_embedder import ESM2Embedder + + return ESM2Embedder + if name.startswith("ESMC"): + from toolbox.models.embedding.embedder.esmc_embedder import ESMCEmbedder + + return ESMCEmbedder + if name.startswith("GLM2"): + from toolbox.models.embedding.embedder.glm2_embedder import GLM2Embedder + + return GLM2Embedder + raise ValueError(f"Unknown embedder kind for {name!r}") + class EmbedderType(Enum): - ESM2_T30_150M = ("esm2_t30_150M_UR50D", ESM2Embedder, 640) - ESM2_T33_650M = ("esm2_t33_650M_UR50D", ESM2Embedder, 1280) - ESMC_300M = ("esmc_300m", ESMCEmbedder, 960) - ESMC_600M = ("esmc_600m", ESMCEmbedder, 1152) - GLM2_150M = ("gLM2_150M", GLM2Embedder, 640) - GLM2_650M = ("gLM2_650M", GLM2Embedder, 1280) - - def __init__(self, value, embedder_class: type[BaseEmbedder], embedding_size: int): - self._value_ = value - self.embedder_class: type[BaseEmbedder] = embedder_class - self.embedding_size: int = embedding_size + ESM2_T30_150M = ("esm2_t30_150M_UR50D", 640) + ESM2_T33_650M = ("esm2_t33_650M_UR50D", 1280) + ESMC_300M = ("esmc_300m", 960) + ESMC_600M = ("esmc_600m", 1152) + GLM2_150M = 
("gLM2_150M", 640) + GLM2_650M = ("gLM2_650M", 1280) + + def __init__(self, model_id: str, embedding_size: int): + self._value_ = model_id + self.embedding_size = embedding_size + + @property + def embedder_class(self) -> Type[BaseEmbedder]: + return _lazy_embedder_class(self) def create_embedder(self) -> BaseEmbedder: - return self.embedder_class(model_name=self.value) \ No newline at end of file + return self.embedder_class(model_name=self.value) diff --git a/toolbox/models/manage_dataset/structures_dataset.py b/toolbox/models/manage_dataset/structures_dataset.py index ec275ab..0840efb 100644 --- a/toolbox/models/manage_dataset/structures_dataset.py +++ b/toolbox/models/manage_dataset/structures_dataset.py @@ -40,7 +40,6 @@ chunk ) from toolbox.models.manage_dataset.distograms.generate_distograms import generate_distograms -from toolbox.models.embedding.embedding import Embedding from toolbox.models.utils.from_archive import extract_batch_from_archive from toolbox.models.utils.create_client import create_client, total_workers from toolbox.utlis.filter_pdb_codes import filter_pdb_codes @@ -613,6 +612,8 @@ def generate_distograms(self): def generate_embeddings(self): """Generate embeddings for the dataset.""" + from toolbox.models.embedding.embedding import Embedding + embedding = Embedding(self) embedding.run() diff --git a/toolbox/scripts/command_parser.py b/toolbox/scripts/command_parser.py index 0236e5e..93f33ba 100644 --- a/toolbox/scripts/command_parser.py +++ b/toolbox/scripts/command_parser.py @@ -1,31 +1,19 @@ +from __future__ import annotations + from argparse import Namespace import json import logging import sys import traceback -from pathlib import Path - -from toolbox.models.chains.verify_chains import verify_chains -from toolbox.models.manage_dataset.distograms.generate_distograms import ( - read_distograms_from_file, -) -from toolbox.models.manage_dataset.structures_dataset import FatalDatasetError, StructuresDataset -from toolbox.models.manage_dataset.utils import ( - read_pdbs_from_h5, - format_time, -) -from toolbox.models.utils.create_client import create_client -from toolbox.scripts.archive import create_archive -from toolbox.models.embedding.embedder.embedder_type import EmbedderType - import time +from pathlib import Path +from typing import TYPE_CHECKING -from toolbox.utlis.logging import logger from toolbox.config import Config -from toolbox.viewer.export_index_html import export_index_view -from toolbox.utlis.inspect_h5 import inspect_h5 -from toolbox.utlis.inspect_idx import inspect_idx -from toolbox.utlis.remove_dataset import remove_dataset +from toolbox.utlis.logging import logger + +if TYPE_CHECKING: + from toolbox.models.manage_dataset.structures_dataset import StructuresDataset class CommandParser: @@ -35,6 +23,9 @@ def __init__(self, args: Namespace, config: Config): self.config = config def _create_dataset_from_path_(self) -> StructuresDataset: + from toolbox.models.manage_dataset.structures_dataset import StructuresDataset + from toolbox.models.utils.create_client import create_client + if self.structures_dataset is not None: return self.structures_dataset path = self.args.file_path @@ -59,33 +50,38 @@ def _log_command(self): """Log the complete command line that started the program.""" full_command = " ".join(sys.argv) logger.info(f"Started with command: {full_command}") - + def _configure_dataset_logging(self): """Configure logging to dataset log file if not already specified.""" - if not hasattr(self.args, 'log_file') or self.args.log_file is None: + 
if not hasattr(self.args, "log_file") or self.args.log_file is None: from toolbox.utlis.colored_logging import setup_logging_with_file + log_level = logging.DEBUG if self.args.verbose else logging.INFO - log_format = '%(asctime)s %(levelname)s %(message)s' + log_format = "%(asctime)s %(levelname)s %(message)s" setup_logging_with_file( - level=log_level, - fmt=log_format, - log_file=self.structures_dataset.log_file_path() + level=log_level, + fmt=log_format, + log_file=self.structures_dataset.log_file_path(), ) logger.info(f"Logging configured to: {self.structures_dataset.log_file_path()}") # Log the complete command line for dataset operations self._log_command() def dataset(self): + from toolbox.models.embedding.embedder.embedder_type import EmbedderType + from toolbox.models.manage_dataset.structures_dataset import StructuresDataset + from toolbox.models.manage_dataset.utils import format_time + start = time.time() - + # Convert embedder string to EmbedderType enum if provided embedder_type = None - if hasattr(self.args, 'embedder') and self.args.embedder: + if hasattr(self.args, "embedder") and self.args.embedder: for embedder_enum in EmbedderType: if embedder_enum.value == self.args.embedder: embedder_type = embedder_enum break - + dataset = StructuresDataset( db_type=self.args.db, collection_type=self.args.collection, @@ -101,39 +97,40 @@ def dataset(self): binary_data_download=self.args.binary, is_hpc_cluster=self.args.slurm, input_path=self.args.input_path, - verbose=self.args.verbose if hasattr(self.args, 'verbose') else False, + verbose=self.args.verbose if hasattr(self.args, "verbose") else False, config=self.config, - embedder_type=embedder_type + embedder_type=embedder_type, ) self.structures_dataset = dataset - + # Configure logging to dataset log file if not already specified self._configure_dataset_logging() - + dataset.create_dataset() - + # Print dataset name in special format for shell script parsing dataset_name = dataset.dataset_dir_name() print(f"DATASET_NAME:{dataset_name}") - + end = time.time() logger.info(f"Total time: {format_time(end - start)}") return dataset def generate_embeddings(self): + from toolbox.models.embedding.embedder.embedder_type import EmbedderType + self._create_dataset_from_path_() - + # Configure logging to dataset log file if not already specified self._configure_dataset_logging() - + # Set embedder type if provided - if hasattr(self.args, 'embedder') and self.args.embedder: + if hasattr(self.args, "embedder") and self.args.embedder: for embedder_enum in EmbedderType: if embedder_enum.value == self.args.embedder: self.structures_dataset.embedder_type = embedder_enum break - - + self.structures_dataset.generate_embeddings() def load(self): @@ -142,37 +139,49 @@ def load(self): def generate_sequence(self): self._create_dataset_from_path_() - + # Configure logging to dataset log file if not already specified self._configure_dataset_logging() - + self.structures_dataset.extract_sequence_and_coordinates( self.args.ca_mask, self.args.no_substitution ) def generate_distograms(self): self._create_dataset_from_path_() - + # Configure logging to dataset log file if not already specified self._configure_dataset_logging() - + self.structures_dataset.generate_distograms() def read_distograms(self): + from toolbox.models.manage_dataset.distograms.generate_distograms import ( + read_distograms_from_file, + ) + logger.info(read_distograms_from_file(self.args.file_path)) def read_pdbs(self): - read_pdbs(self.args.file_path, self.args.ids, 
self.args.to_directory, self.args.print) + read_pdbs( + self.args.file_path, self.args.ids, self.args.to_directory, self.args.print + ) def verify_chains(self): + from toolbox.models.chains.verify_chains import verify_chains + self._create_dataset_from_path_() verify_chains(self.structures_dataset, "./toolbox/pdb_seqres.txt") def create_archive(self): + from toolbox.scripts.archive import create_archive + self._create_dataset_from_path_() create_archive(self.structures_dataset) def inspect_h5(self): + from toolbox.utlis.inspect_h5 import inspect_h5 + inspect_h5( Path(self.args.file), mode=getattr(self.args, "mode", "structure"), @@ -180,32 +189,49 @@ def inspect_h5(self): ) def inspect_idx(self): + from toolbox.utlis.inspect_idx import inspect_idx + inspect_idx(Path(self.args.file)) def remove_dataset(self): + from toolbox.utlis.remove_dataset import remove_dataset + remove_dataset(self.args.name, self.config) def create_dashboard(self): # CLI handler for create_dashboard (formerly export_index_view) # Uses global config from args/config loaded in fridata.py + from toolbox.viewer.export_index_html import export_index_view + index_types = None - if hasattr(self.args, 'index_types') and self.args.index_types and self.args.index_types != 'all': - index_types = [s.strip() for s in self.args.index_types.split(',') if s.strip()] + if ( + hasattr(self.args, "index_types") + and self.args.index_types + and self.args.index_types != "all" + ): + index_types = [ + s.strip() + for s in self.args.index_types.split(",") + if s.strip() + ] out_path = export_index_view( config=self.config, - dataset=getattr(self.args, 'dataset', None), - dataset_slug=getattr(self.args, 'dataset_slug', None), - root=getattr(self.args, 'root', None), + dataset=getattr(self.args, "dataset", None), + dataset_slug=getattr(self.args, "dataset_slug", None), + root=getattr(self.args, "root", None), index_types=index_types, - output_dir=getattr(self.args, 'output_dir', None), + output_dir=getattr(self.args, "output_dir", None), ) logger.info(f"Report generated: {out_path}") def input_generation(self): + from toolbox.models.manage_dataset.structures_dataset import FatalDatasetError + from toolbox.models.manage_dataset.utils import format_time + total_time = time.time() is_error = False - try: + try: self.dataset() except FatalDatasetError as e: logger.error("Fatal error! Exiting...") @@ -224,7 +250,7 @@ def input_generation(self): except Exception as e: print_exc(e) is_error = True - try: + try: self.structures_dataset.generate_embeddings() except Exception as e: print_exc(e) @@ -243,23 +269,25 @@ def run(self): self.cleanup() else: raise ValueError(f"Unknown command - {self.args.command}") - + def cleanup(self): - ds = getattr(self, 'structures_dataset', None) - client = getattr(ds, '_client', None) if ds is not None else None + ds = getattr(self, "structures_dataset", None) + client = getattr(ds, "_client", None) if ds is not None else None if client: import warnings + import distributed + warnings.simplefilter("ignore", distributed.comm.core.CommClosedError) # Suppress noisy tornado/asyncio tracebacks that fire during # nanny shutdown (TimeoutError / CancelledError). These are # harmless – the work is already done – but alarming for users. 
-            logging.getLogger('tornado.application').setLevel(logging.CRITICAL)
-            logging.getLogger('distributed.nanny').setLevel(logging.CRITICAL)
-            logging.getLogger('distributed.process').setLevel(logging.CRITICAL)
+            logging.getLogger("tornado.application").setLevel(logging.CRITICAL)
+            logging.getLogger("distributed.nanny").setLevel(logging.CRITICAL)
+            logging.getLogger("distributed.process").setLevel(logging.CRITICAL)
 
-            cluster = getattr(client, 'cluster', None)
+            cluster = getattr(client, "cluster", None)
             try:
                 client.close()
             except Exception:
@@ -271,11 +299,15 @@ def cleanup(self):
                 pass
             ds._client = None
 
+
 def print_exc(e):
     logger.error(f"Error ({type(e)}): {str(e)}")
     logger.error(traceback.format_exc())
 
+
 def read_pdbs(file_path, ids, to_directory, is_print):
+    from toolbox.models.manage_dataset.utils import read_pdbs_from_h5
+
     if ids.exists():
         ids = ids.read_text().splitlines()
 
@@ -283,7 +315,7 @@ def read_pdbs(file_path, ids, to_directory, is_print):
     if is_print:
         logger.info(json.dumps(pdbs_dict))
-    
+
     if to_directory:
         extract_dir: Path = to_directory
         if not extract_dir.exists() and not extract_dir.is_dir():

From 43371307eb754f298ec9fe8c7eb7b6beeecd075b Mon Sep 17 00:00:00 2001
From: Adam Nowak
Date: Mon, 23 Mar 2026 18:41:58 +0000
Subject: [PATCH 3/4] Create dataset - validate that the dataset path is a JSON file

---
 toolbox/scripts/command_parser.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/toolbox/scripts/command_parser.py b/toolbox/scripts/command_parser.py
index 93f33ba..632fa55 100644
--- a/toolbox/scripts/command_parser.py
+++ b/toolbox/scripts/command_parser.py
@@ -34,6 +34,9 @@ def _create_dataset_from_path_(self) -> StructuresDataset:
                 (path / "dataset.json").read_text()
             )
         elif path.is_file():
+            if path.suffix != ".json":
+                logger.error(f"Dataset path is not a dataset JSON file: {path}")
+                raise FileNotFoundError(f"Dataset path is not a dataset JSON file: {path}")
             self.structures_dataset = StructuresDataset.model_validate_json(
                 path.read_text()
             )

From c9f138ea97b6f0658a9e53d8eb5b46a53e831b11 Mon Sep 17 00:00:00 2001
From: Adam Nowak
Date: Mon, 23 Mar 2026 18:42:38 +0000
Subject: [PATCH 4/4] Report progress + parallel collection of results

---
 toolbox/scripts/archive.py | 31 +++++++++++++++++++------------
 1 file changed, 19 insertions(+), 12 deletions(-)

diff --git a/toolbox/scripts/archive.py b/toolbox/scripts/archive.py
index e92a4c1..0c4ca30 100644
--- a/toolbox/scripts/archive.py
+++ b/toolbox/scripts/archive.py
@@ -3,7 +3,13 @@
 import os
 import zipfile
 from pathlib import Path
+
+from dask.distributed import as_completed
+from tqdm import tqdm
+from tqdm.contrib.logging import logging_redirect_tqdm
+
 from toolbox.models.manage_dataset.utils import read_all_pdbs_from_h5
+from toolbox.utlis.logging import logger
 
 
 def process_h5_file(h5_file, dataset_path, output_dir):
@@ -17,9 +23,6 @@ def process_h5_file(h5_file, dataset_path, output_dir):
             code = p.removesuffix(".pdb")
             zipf.writestr(f"{code}.pdb", pdb_file_content)
 
-        with open(archive_path / f"{code}.pdb", "w") as f:
-            f.write(pdb_file_content)
-
-    os.system(f"tar -czf {str(archive_path)}.tgz {str(archive_path)}")
     return str(archive_path)
 
@@ -38,17 +41,21 @@ def create_archive(structures_dataset: "StructuresDataset"):
         future = client.submit(process_h5_file, h5_file, dataset_path, output_dir)
         futures.append(future)
 
-    archive_paths = client.gather(futures)
+    n = len(futures)
+    logger.info("Building combined PDB archive from %s H5 shard(s)", n)
 
-    # Combine the archives into one archive
     current_time = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
-    final_archive_name = 
f"archive_pdb_{current_time}.zip" + final_archive_name = f"archive_pdb_{structures_dataset.dataset_dir_name()}_{current_time}.zip" final_archive_path = Path.cwd() / final_archive_name with zipfile.ZipFile(final_archive_path, "w") as final_zip: - for idx, archive_path in enumerate(archive_paths): - archive_name_in_final = f"{idx}.zip" - # Read the archive file and write it into the final archive - with open(archive_path, "rb") as f: - archive_data = f.read() - final_zip.writestr(archive_name_in_final, archive_data) + with logging_redirect_tqdm(): + with tqdm(total=n, desc="H5 shards → final zip", unit="h5") as pbar: + i = 0 + for fut in as_completed(futures): + archive_path = fut.result() + with open(archive_path, "rb") as f: + archive_data = f.read() + final_zip.writestr(f"{i}.zip", archive_data) + i += 1 + pbar.update(1)