diff --git a/toolbox/models/embedding/embedder/embedder_type.py b/toolbox/models/embedding/embedder/embedder_type.py index 45a93a3..276612d 100644 --- a/toolbox/models/embedding/embedder/embedder_type.py +++ b/toolbox/models/embedding/embedder/embedder_type.py @@ -1,8 +1,6 @@ from enum import Enum from toolbox.models.embedding.embedder.esm2_embedder import ESM2Embedder from toolbox.models.embedding.embedder.esmc_embedder import ESMCEmbedder -from toolbox.models.embedding.embedder.esmc_performance_embedder import ESMCPerformanceEmbedder -from toolbox.models.embedding.embedder.esm2_performance_embedder import ESM2PerformanceEmbedder from toolbox.models.embedding.embedder.base_embedder import BaseEmbedder class EmbedderType(Enum): @@ -10,8 +8,6 @@ class EmbedderType(Enum): ESM2_T33_650M = ("esm2_t33_650M_UR50D", ESM2Embedder, 1280) ESMC_300M = ("esmc_300m", ESMCEmbedder, 960) ESMC_600M = ("esmc_600m", ESMCEmbedder, 1152) - ESMC_600M_PERFORMANCE = ("esmc_600m_performance", ESMCPerformanceEmbedder, 1152) - ESM2_T33_650M_PERFORMANCE = ("esm2_t33_650M_UR50D_performance", ESM2PerformanceEmbedder, 1280) def __init__(self, value, embedder_class: BaseEmbedder, embedding_size: int): self._value_ = value diff --git a/toolbox/models/embedding/embedder/esm2_performance_embedder.py b/toolbox/models/embedding/embedder/esm2_performance_embedder.py deleted file mode 100644 index ba052a7..0000000 --- a/toolbox/models/embedding/embedder/esm2_performance_embedder.py +++ /dev/null @@ -1,81 +0,0 @@ -from typing import Dict -from pathlib import Path -import torch -import tqdm -import h5py -import gc # Added for garbage collection -from transformers import AutoTokenizer, AutoModelForMaskedLM -from multiprocessing import Process -from toolbox.models.embedding.utils import save_batch -from .base_embedder import BaseEmbedder - -from toolbox.models.embedding.embedder.performance_comparison_embedder import PerformanceComparisonEmbedder - -from toolbox.utlis.logging import logger -# Parameters -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - -class ESM2PerformanceEmbedder(PerformanceComparisonEmbedder): - def __init__(self, device=None, batch_size=1000, model_name='esm2_t33_650M_UR50D'): - super().__init__(device, batch_size, "both") - self.model_name = model_name - self.tokenizer = AutoTokenizer.from_pretrained(f"facebook/{model_name}") - self.model = AutoModelForMaskedLM.from_pretrained(f"facebook/{model_name}").to(self.device).to(torch.float32) - - - def get_embedding(self, prot_id, prot_seq): - inputs = self.tokenizer(prot_seq, return_tensors="pt") - inputs = {k: v.to(self.device) for k, v in inputs.items()} - outputs = self.model(**inputs, output_hidden_states=True) - embeddings = outputs.hidden_states[-1] - embeddings = embeddings.detach().to('cpu', non_blocking=True)[0, 1:-1] - embeddings = embeddings.to(torch.float32).numpy() if embeddings.dtype != torch.float32 else embeddings.numpy() - return embeddings - - def get_embedding_profiled(self, prot_id, prot_seq): - import time - - timing = {} - - # Time tokenization - t0 = time.perf_counter() - inputs = self.tokenizer(prot_seq, return_tensors="pt") - t1 = time.perf_counter() - timing["tokenization"] = t1 - t0 - - # Time device transfer - t0 = time.perf_counter() - inputs = {k: v.to(self.device) for k, v in inputs.items()} - t1 = time.perf_counter() - timing["to_device"] = t1 - t0 - - # Time model inference - t0 = time.perf_counter() - outputs = self.model(**inputs, output_hidden_states=True) - embeddings = outputs.hidden_states[-1] - t1 = time.perf_counter() - timing["inference"] = t1 - t0 - - # Time post-processing - t0 = time.perf_counter() - embeddings = embeddings.detach().to('cpu', non_blocking=True)[0, 1:-1] - t1 = time.perf_counter() - timing["post_processing_get_embeddings"] = t1 - t0 - t0 = time.perf_counter() - embeddings = embeddings.to(torch.float32).numpy() if embeddings.dtype != torch.float32 else embeddings.numpy() - t1 = time.perf_counter() - timing["post_processing_convert_to_numpy"] = t1 - t0 - - return embeddings, timing - - def _initialize_model(self): - if self.model is not None: - del self.model - torch.cuda.empty_cache() - gc.collect() - - logger.info(f"Loading ESM2 model: {self.model_name}") - - self.model = AutoModelForMaskedLM.from_pretrained(f"facebook/{self.model_name}").to(self.device).to(torch.float32) - self.model.eval() - logger.info(f"ESM2 model loaded successfully on {self.device}") \ No newline at end of file diff --git a/toolbox/models/embedding/embedder/esmc_performance_embedder.py b/toolbox/models/embedding/embedder/esmc_performance_embedder.py deleted file mode 100644 index f80ae81..0000000 --- a/toolbox/models/embedding/embedder/esmc_performance_embedder.py +++ /dev/null @@ -1,122 +0,0 @@ -""" -ESMC Performance Comparison Embedder - -Concrete implementation of the performance comparison embedder using the ESMC (Evolutionary Scale Modeling Compound) model. -This embedder compares parallel batch saving vs batch-end saving strategies specifically for ESMC embeddings. -""" - -import torch -from pathlib import Path -from typing import Dict -import gc - -from toolbox.models.embedding.embedder.performance_comparison_embedder import PerformanceComparisonEmbedder -from toolbox.utlis.logging import logger - -# Import ESMC dependencies - same as the original ESMC embedder -try: - from esm.models.esmc import ESMC - from esm.sdk.api import ESMProtein, LogitsConfig - ESMC_AVAILABLE = True -except ImportError as e: - logger.warning(f"ESMC not available: {e}") - ESMC_AVAILABLE = False - - -class ESMCPerformanceEmbedder(PerformanceComparisonEmbedder): - """ - ESMC-based performance comparison embedder. - - Uses the ESMC model to generate embeddings while comparing the performance - of parallel batch saving vs batch-end saving strategies. - """ - - def __init__(self, device=None, batch_size=1000, enable_detailed_profiling=True, - comparison_mode="both", model_name="esmc_600m"): - """ - Initialize the ESMC performance comparison embedder. - - Args: - device: PyTorch device to use for computation - batch_size: Number of sequences to process before saving a batch (for parallel mode) - enable_detailed_profiling: If True, enables per-sequence profiling - comparison_mode: "both", "parallel_only", or "batch_end_only" - model_name: ESMC model name to use - """ - if not ESMC_AVAILABLE: - raise ImportError("ESMC is not available. Please install the required dependencies.") - - super().__init__(device, batch_size, enable_detailed_profiling, comparison_mode) - self.model_name = model_name - self.model = None - - logger.info(f"Initialized ESMC Performance Embedder with model: {model_name}") - - def _initialize_model(self): - """Initialize the ESMC model.""" - if self.model is not None: - # Clean up existing model - del self.model - torch.cuda.empty_cache() - gc.collect() - - logger.info(f"Loading ESMC model: {self.model_name}") - - try: - self.model = ESMC.from_pretrained(self.model_name).to(self.device) - self.model.eval() - logger.info(f"ESMC model loaded successfully on {self.device}") - except Exception as e: - logger.error(f"Failed to load ESMC model: {e}") - raise - - def get_embedding(self, prot_id: str, prot_seq: str): - """ - Generate embedding for a protein sequence using ESMC. - - Args: - prot_id: Unique identifier for the protein - prot_seq: Protein sequence string - - Returns: - np.ndarray: Embedding tensor of shape (seq_length, embedding_dim) - """ - - - # Create ESM protein object - protein = ESMProtein(sequence=prot_seq) - - # Encode the protein - protein_tensor = self.model.encode(protein) - - # Get logits with embeddings - logits_output = self.model.logits( - protein_tensor, - LogitsConfig(sequence=True, return_embeddings=True) - ) - - # Extract embeddings (remove start/end tokens) - embeddings = logits_output.embeddings[0, 1:-1, :].to('cpu').detach().to(torch.float32).numpy() - - # Clean up intermediate tensors - del logits_output, protein_tensor, protein - - return embeddings - - - def embed(self, sequences: Dict[str, str], output_path: Path) -> Dict[str, str]: - """ - Run ESMC performance comparison between parallel batch saving and batch-end saving. - - Args: - sequences: Dictionary mapping protein IDs to sequences - output_path: Path to save embedding files and performance reports - - Returns: - dict: Final index mapping protein IDs to file paths - """ - logger.info(f"Starting ESMC performance comparison for {len(sequences)} sequences") - logger.info(f"Model: {self.model_name}, Device: {self.device}") - - # Run the parent comparison - return super().embed(sequences, output_path) diff --git a/toolbox/models/embedding/embedder/performance_comparison_embedder.py b/toolbox/models/embedding/embedder/performance_comparison_embedder.py deleted file mode 100644 index 90543ca..0000000 --- a/toolbox/models/embedding/embedder/performance_comparison_embedder.py +++ /dev/null @@ -1,528 +0,0 @@ -""" -Performance Comparison Embedder - -This module provides a specialized embedder that compares the performance of two different -saving strategies: -1. Parallel batch saving - saves batches as they are generated (current approach) -2. Batch-end saving - accumulates all embeddings and saves them in one operation at the end - -The embedder provides comprehensive timing analysis and metrics to compare the efficiency -of both approaches. -""" - -from abc import abstractmethod -from pathlib import Path -from multiprocessing import Process -from contextlib import nullcontext -import torch -import gc -import time -import json -import os -import threading -from typing import Dict, List, Tuple, Any -import h5py - -from toolbox.models.embedding.embedder.base_embedder import BaseEmbedder -from toolbox.models.embedding.utils_profiled import ( - create_shared_memory_batch, - save_batch_profiled_from_shared_memory, -) -from toolbox.utlis.logging import logger - - -def save_all_embeddings_at_end(output_path: Path, all_embeddings: Dict[str, torch.Tensor]) -> Dict[str, Any]: - """ - Save all embeddings in a single operation at the end. - - Args: - output_path: Directory to save the embeddings - all_embeddings: Dictionary mapping protein IDs to their embedding tensors - - Returns: - dict: Time performance metrics for the operation - """ - start_time = time.perf_counter() - output_file = output_path / "all_embeddings.h5" - - # Initialize timing log entry - log_entry = { - "strategy": "batch_end_save", - "num_sequences": len(all_embeddings), - "start_time": start_time - } - - try: - # Time the file writing process - with h5py.File(output_file, 'w') as f: - for seq_id, embedding in all_embeddings.items(): - f.create_dataset(seq_id, data=embedding) - f.flush() - - end_time = time.perf_counter() - total_duration = end_time - start_time - - # Calculate time performance metrics - log_entry.update({ - "end_time": end_time, - "total_duration": total_duration, - "sequences_per_second": len(all_embeddings) / total_duration if total_duration > 0 else 0, - "success": True - }) - - except Exception as e: - error_time = time.perf_counter() - log_entry.update({ - "end_time": error_time, - "total_duration": error_time - start_time, - "success": False, - "error": str(e) - }) - raise - - return log_entry - - -class PerformanceComparisonEmbedder(BaseEmbedder): - """ - Performance comparison embedder that tests both parallel batch saving and batch-end saving strategies. - - This embedder runs the embedding process twice using different saving strategies: - 1. Parallel batch saving (saves batches as they are generated in parallel processes) - 2. Batch-end saving (accumulates all embeddings and saves them at the end) - - It provides comprehensive performance analysis and comparison between the two approaches. - """ - - def __init__(self, device=None, batch_size=1000, comparison_mode="both"): - """ - Initialize the performance comparison embedder. - - Args: - device: PyTorch device to use for computation - batch_size: Number of sequences to process before saving a batch (for parallel mode) - enable_detailed_profiling: If True, enables per-sequence profiling - comparison_mode: "both", "parallel_only", or "batch_end_only" - """ - # OVERWRITE_BATCH_SIZE = 100 - super().__init__(device, batch_size) - self.comparison_mode = comparison_mode - self.model = None # Will be set by concrete implementations - - @abstractmethod - def get_embedding(self, prot_id, prot_seq): - """ - Abstract method to generate embeddings for a protein sequence. - Must be implemented by concrete embedder classes. - """ - pass - - @abstractmethod - def _initialize_model(self): - """ - Abstract method to initialize the embedding model. - Must be implemented by concrete embedder classes. - """ - pass - - def _run_parallel_process_parall_save_batch_embedding(self, sequences: Dict[str, str], output_path: Path) -> Tuple[Dict[str, str], Dict[str, Any]]: - """ - Run embedding process with parallel batch saving strategy. - Processes batches on main thread and saves to h5 in separate processes. - - Args: - sequences: Dictionary mapping protein IDs to sequences - output_path: Path to save embedding batch files - """ - parallel_output_path = output_path / "parallel_batches" - parallel_output_path.mkdir(exist_ok=True, parents=True) - - parallel_start_time = time.perf_counter() - performance_metrics = { - "strategy": "parallel_batch_save", - "start_time": parallel_start_time, - "total_sequences": len(sequences) - } - - logger.info(f"Starting parallel batch embedding strategy for {len(sequences)} sequences") - - # embeddings_pure_batch = {} - - batch_ids = [] - batch_embeddings = [] - - save_batch_processes = [] - batch_index = 0 - sequence_count = 0 - - timing_metrics = { - "tokenization": 0, - "to_device": 0, - "inference": 0, - "post_processing_get_embeddings": 0, - "post_processing_convert_to_numpy": 0, - "save_batch": 0, - "cache_clear": 0, - "model_eval": 0 - } - generate_embedding_total_time = 0 - - sequences_count = len(sequences) - - try: - t0 = time.perf_counter() - self.model.eval() - t1 = time.perf_counter() - timing_metrics["model_eval"] += t1 - t0 - with torch.inference_mode(): - for prot_id, prot_seq in sequences.items(): - sequence_count += 1 - - # Generate embedding - # generate_embedding_start_time = time.perf_counter() - # embeddings_pure, timing = self.get_embedding_profiled(prot_id, prot_seq) - # generate_embedding_end_time = time.perf_counter() - # generate_embedding_total_time += generate_embedding_end_time - generate_embedding_start_time - - # timing_metrics["tokenization"] += timing["tokenization"] - # timing_metrics["to_device"] += timing["to_device"] - # timing_metrics["inference"] += timing["inference"] - # timing_metrics["post_processing_get_embeddings"] += timing["post_processing_get_embeddings"] - # timing_metrics["post_processing_convert_to_numpy"] += timing["post_processing_convert_to_numpy"] - - t0 = time.perf_counter() - embeddings_pure = self.get_embedding(prot_id, prot_seq) - t1 = time.perf_counter() - generate_embedding_total_time += t1 - t0 - - # Validate embedding dimensions - # assert len(prot_seq) == embeddings_pure.shape[0], f'Invalid character in {prot_id}' - - batch_ids.append(prot_id) - batch_embeddings.append(embeddings_pure) - - # Save batch when full - if len(batch_ids) >= self.batch_size: - t0 = time.perf_counter() - embeddings_pure_batch = dict(zip(batch_ids, batch_embeddings)) - # Create shared memory and spawn save process - metadata = create_shared_memory_batch(embeddings_pure_batch) - p = Process(target=save_batch_profiled_from_shared_memory, - args=(parallel_output_path, batch_index, metadata)) - p.start() - save_batch_processes.append((p, parallel_output_path / f"batch_{batch_index}.h5", - batch_ids.copy())) - - batch_ids.clear() - batch_embeddings.clear() - batch_index += 1 - t1 = time.perf_counter() - timing_metrics["save_batch"] += t1 - t0 - - if sequence_count % 1000 == 0: - logger.info(f"Parallel mode: Processed {sequence_count}/{sequences_count} sequences") - - # t0 = time.perf_counter() - # torch.cuda.empty_cache() - # t1 = time.perf_counter() - # timing_metrics["cache_clear"] += t1 - t0 - - # Handle remaining sequences in final batch - if len(batch_ids) > 0: - t0 = time.perf_counter() - embeddings_pure_batch = dict(zip(batch_ids, batch_embeddings)) - metadata = create_shared_memory_batch(embeddings_pure_batch) - p = Process(target=save_batch_profiled_from_shared_memory, - args=(parallel_output_path, batch_index, metadata)) - p.start() - save_batch_processes.append((p, parallel_output_path / f"batch_{batch_index}.h5", - batch_ids.copy())) - t1 = time.perf_counter() - timing_metrics["save_batch"] += t1 - t0 - finally: - # Cleanup model memory - t0 = time.perf_counter() - if hasattr(self, 'model') and self.model is not None: - torch.cuda.empty_cache() - gc.collect() - t1 = time.perf_counter() - timing_metrics["cache_clear"] += t1 - t0 - - # Wait for all save processes to complete - process_join_start = time.perf_counter() - final_index = {} - - for p, batch_file, prot_ids in save_batch_processes: - p.join() - for prot_id in prot_ids: - final_index[prot_id] = str(batch_file) - - process_join_end = time.perf_counter() - parallel_end_time = time.perf_counter() - - performance_metrics.update({ - "end_time": parallel_end_time, - "total_duration": parallel_end_time - parallel_start_time, - "sequences_per_second": len(sequences) / (parallel_end_time - parallel_start_time), - "process_join_time": process_join_end - process_join_start, - "generate_embedding_total_time": generate_embedding_total_time, - "timing_metrics": timing_metrics - }) - - logger.info(f"Parallel batch strategy completed in {performance_metrics['total_duration']:.2f}s") - return final_index, performance_metrics - - def _run_batch_end_embedding(self, sequences: Dict[str, str], output_path: Path) -> Tuple[Dict[str, str], Dict[str, Any]]: - """ - Run embedding process with batch-end saving strategy. - - Args: - sequences: Dictionary mapping protein IDs to sequences - output_path: Path to save embedding files - - Returns: - tuple: (final_index, performance_metrics) - """ - batch_end_start_time = time.perf_counter() - batch_end_output_path = output_path / "batch_end" - batch_end_output_path.mkdir(exist_ok=True, parents=True) - - performance_metrics = { - "strategy": "batch_end_save", - "start_time": batch_end_start_time, - "total_sequences": len(sequences) - } - - logger.info(f"Starting batch-end embedding strategy for {len(sequences)} sequences") - - # Accumulate all embeddings - all_embeddings = {} - - generate_embedding_total_time = 0 - - timing_metrics = { - "tokenization": 0, - "to_device": 0, - "inference": 0, - "post_processing_get_embeddings": 0, - "post_processing_convert_to_numpy": 0, - "save_batch": 0, - "cache_clear": 0 - } - - try: - self.model.eval() - with torch.inference_mode(): - sequence_count = 0 - - for prot_id, prot_seq in sequences.items(): - sequence_count += 1 - - # Generate embedding - # generate_embedding_start_time = time.perf_counter() - # embeddings_pure, timing = self.get_embedding_profiled(prot_id, prot_seq) - # generate_embedding_end_time = time.perf_counter() - # generate_embedding_total_time += generate_embedding_end_time - generate_embedding_start_time - # timing_metrics["tokenization"] += timing["tokenization"] - # timing_metrics["to_device"] += timing["to_device"] - # timing_metrics["inference"] += timing["inference"] - # timing_metrics["post_processing_get_embeddings"] += timing["post_processing_get_embeddings"] - # timing_metrics["post_processing_convert_to_numpy"] += timing["post_processing_convert_to_numpy"] - - - t0 = time.perf_counter() - embeddings_pure = self.get_embedding(prot_id, prot_seq) - t1 = time.perf_counter() - generate_embedding_total_time += t1 - t0 - - # Validate embedding dimensions - assert len(prot_seq) == embeddings_pure.shape[0], f'Invalid character in {prot_id}' - - # Store embedding (accumulate in memory) - all_embeddings[prot_id] = embeddings_pure - - if sequence_count % 50 == 0: - logger.info(f"Batch-end mode: Processed {sequence_count}/{len(sequences)} sequences") - t0 = time.perf_counter() - torch.cuda.empty_cache() - t1 = time.perf_counter() - timing_metrics["cache_clear"] += t1 - t0 - - finally: - # Cleanup model memory but keep embeddings - t0 = time.perf_counter() - if hasattr(self, 'model') and self.model is not None: - torch.cuda.empty_cache() - gc.collect() - t1 = time.perf_counter() - timing_metrics["cache_clear"] += t1 - t0 - - embedding_generation_end = time.perf_counter() - - # Save all embeddings at once - logger.info(f"Saving all {len(all_embeddings)} embeddings to disk...") - save_start_time = time.perf_counter() - - save_metrics = save_all_embeddings_at_end(batch_end_output_path, all_embeddings) - - save_end_time = time.perf_counter() - batch_end_end_time = time.perf_counter() - - # Create final index - output_file = batch_end_output_path / "all_embeddings.h5" - final_index = {prot_id: str(output_file) for prot_id in all_embeddings.keys()} - - # Update performance metrics with time data only - performance_metrics.update({ - "end_time": batch_end_end_time, - "total_duration": batch_end_end_time - batch_end_start_time, - "embedding_generation_time": embedding_generation_end - batch_end_start_time, - "saving_time": save_end_time - save_start_time, - "sequences_per_second": len(sequences) / (batch_end_end_time - batch_end_start_time), - "save_metrics": save_metrics, - "generate_embedding_total_time": generate_embedding_total_time, - "timing_metrics": timing_metrics - }) - - logger.info(f"Batch-end strategy completed in {performance_metrics['total_duration']:.2f}s") - return final_index, performance_metrics - - def embed(self, sequences: Dict[str, str], output_path: Path) -> Dict[str, str]: - """ - Run performance comparison between parallel batch saving and batch-end saving. - - Args: - sequences: Dictionary mapping protein IDs to sequences - output_path: Path to save embedding files and performance reports - - Returns: - dict: Final index mapping protein IDs to file paths (from the last run strategy) - """ - - if len(sequences) > 3000: - sequences = dict(list(sequences.items())[:3000]) - - # Create output directories - output_path.mkdir(exist_ok=True, parents=True) - performance_log_path = output_path / "performance_comparison.json" - - # Initialize model - self._initialize_model() - - logger.info(f"Starting performance comparison for {len(sequences)} sequences") - logger.info(f"Device: {self.device}, Batch size: {self.batch_size}, Mode: {self.comparison_mode}") - - comparison_results = { - "experiment_metadata": { - "total_sequences": len(sequences), - "batch_size": self.batch_size, - "comparison_mode": self.comparison_mode, - "timestamp": time.time() - }, - "strategies": {} - } - - final_index = {} - - try: - # Run parallel batch strategy - if self.comparison_mode in ["both", "parallel_only"]: - logger.info("=" * 60) - logger.info("STARTING PARALLEL BATCH SAVING STRATEGY") - logger.info("=" * 60) - - # Re-initialize model for fair comparison - self._initialize_model() - - parallel_index, parallel_metrics = self._run_parallel_process_parall_save_batch_embedding(sequences, output_path) - comparison_results["strategies"]["parallel_batch_save"] = parallel_metrics - final_index = parallel_index - - # Clean up model to free memory - if hasattr(self, 'model') and self.model is not None: - del self.model - self.model = None - torch.cuda.empty_cache() - gc.collect() - - logger.info(f"Parallel strategy: {parallel_metrics['total_duration']:.2f}s, {parallel_metrics['sequences_per_second']:.2f} seq/s") - - # Run batch-end strategy - # if self.comparison_mode in ["both", "batch_end_only"]: - # logger.info("=" * 60) - # logger.info("STARTING BATCH-END SAVING STRATEGY") - # logger.info("=" * 60) - - # # Re-initialize model for fair comparison - # self._initialize_model() - - # batch_end_index, batch_end_metrics = self._run_batch_end_embedding(sequences, output_path) - # comparison_results["strategies"]["batch_end_save"] = batch_end_metrics - - # # If we only ran batch-end, use its index - # if self.comparison_mode == "batch_end_only": - # final_index = batch_end_index - - # logger.info(f"Batch-end strategy: {batch_end_metrics['total_duration']:.2f}s, {batch_end_metrics['sequences_per_second']:.2f} seq/s") - - # Generate comparison summary - # if len(comparison_results["strategies"]) > 1: - # parallel_metrics = comparison_results["strategies"]["parallel_batch_save"] - # batch_end_metrics = comparison_results["strategies"]["batch_end_save"] - - # comparison_results["comparison_summary"] = { - # "parallel_total_time": parallel_metrics["total_duration"], - # "batch_end_total_time": batch_end_metrics["total_duration"], - # "time_difference": batch_end_metrics["total_duration"] - parallel_metrics["total_duration"], - # "time_difference_percent": ((batch_end_metrics["total_duration"] - parallel_metrics["total_duration"]) / parallel_metrics["total_duration"]) * 100, - # "parallel_throughput": parallel_metrics["sequences_per_second"], - # "batch_end_throughput": batch_end_metrics["sequences_per_second"], - # "throughput_difference_percent": ((batch_end_metrics["sequences_per_second"] - parallel_metrics["sequences_per_second"]) / parallel_metrics["sequences_per_second"]) * 100, - # "parallel_process_join_time": parallel_metrics.get("process_join_time", 0), - # "batch_end_saving_time": batch_end_metrics.get("saving_time", 0), - # "recommendation": self._generate_recommendation(parallel_metrics, batch_end_metrics) - # } - - # logger.info("=" * 60) - # logger.info("PERFORMANCE COMPARISON SUMMARY") - # logger.info("=" * 60) - # summary = comparison_results["comparison_summary"] - # logger.info(f"Parallel batch saving: {summary['parallel_total_time']:.2f}s ({summary['parallel_throughput']:.2f} seq/s)") - # logger.info(f"Batch-end saving: {summary['batch_end_total_time']:.2f}s ({summary['batch_end_throughput']:.2f} seq/s)") - # logger.info(f"Time difference: {summary['time_difference']:.2f}s ({summary['time_difference_percent']:.1f}%)") - # logger.info(f"Throughput difference: {summary['throughput_difference_percent']:.1f}%") - # logger.info(f"Recommendation: {summary['recommendation']}") - - finally: - # Final cleanup - if hasattr(self, 'model') and self.model is not None: - del self.model - self.model = None - torch.cuda.empty_cache() - gc.collect() - - # Save performance comparison results - try: - with open(performance_log_path, 'w') as f: - json.dump(comparison_results, f, indent=2, default=str) - logger.info(f"Performance comparison results saved to: {performance_log_path}") - except Exception as e: - logger.warning(f"Failed to save performance comparison results: {e}") - - return final_index - - def _generate_recommendation(self, parallel_metrics: Dict[str, Any], batch_end_metrics: Dict[str, Any]) -> str: - """Generate a recommendation based on the time performance comparison.""" - parallel_time = parallel_metrics["total_duration"] - batch_end_time = batch_end_metrics["total_duration"] - - time_diff_percent = ((batch_end_time - parallel_time) / parallel_time) * 100 - - if abs(time_diff_percent) < 5: - return "Performance is similar between both strategies. Use parallel batch saving for better memory efficiency." - elif time_diff_percent < -10: - return "Batch-end saving is significantly faster." - elif time_diff_percent > 10: - return "Parallel batch saving is significantly faster." - else: - return "Performance difference is moderate. Choose based on your specific requirements." diff --git a/toolbox/models/embedding/utils_profiled.py b/toolbox/models/embedding/utils_profiled.py deleted file mode 100644 index b5360c5..0000000 --- a/toolbox/models/embedding/utils_profiled.py +++ /dev/null @@ -1,213 +0,0 @@ -from pathlib import Path -from typing import Dict -import torch -import h5py -import time -import os -import threading -import json -import numpy as np -from multiprocessing import shared_memory - - -def save_batch_profiled(output_path: Path, batch_index: int, embeddings_pure: Dict[str, torch.Tensor]): - """ - Profiled version of save_batch function that logs detailed timing information. - - This function saves a batch of embeddings to an H5 file while capturing - comprehensive performance metrics including I/O timing, tensor sizes, - and throughput measurements. - - Args: - output_path: Directory to save the batch file - batch_index: Index of the current batch - embeddings_pure: Dictionary mapping protein IDs to their embedding tensors - """ - # Start timing and metadata collection - start_time = time.perf_counter() - start_timestamp = time.time() - - batch_file = output_path / f"batch_{batch_index}.h5" - log_file = output_path / f"save_batch_profile_{batch_index}_{os.getpid()}.json" - - # Initialize comprehensive log entry - log_entry = { - "batch_index": batch_index, - "process_id": os.getpid(), - "thread_id": threading.get_ident(), - "start_timestamp": start_timestamp, - "start_perf_counter": start_time, - "num_sequences": len(embeddings_pure), - "batch_file": str(batch_file), - "sequence_ids": list(embeddings_pure.keys()), - "performance_metrics": {}, - "timing_breakdown": {} - } - - try: - # Time the file creation and writing process - file_create_start = time.perf_counter() - - with h5py.File(batch_file, 'w') as f: - file_create_end = time.perf_counter() - - # Track individual dataset creation times - dataset_times = [] - dataset_write_start = time.perf_counter() - - for seq_id, embedding in embeddings_pure.items(): - dataset_start = time.perf_counter() - - # Create dataset with compression for better I/O performance tracking - dataset = f.create_dataset(seq_id, data=embedding) - - dataset_end = time.perf_counter() - - dataset_times.append({ - "seq_id": seq_id, - "duration": dataset_end - dataset_start - }) - - dataset_write_end = time.perf_counter() - - # Force flush to ensure all data is written - f.flush() - - file_close_end = time.perf_counter() - - # Get final file statistics - file_stats = batch_file.stat() - final_file_size = file_stats.st_size - - # Calculate comprehensive timing breakdown - log_entry["timing_breakdown"] = { - "file_creation": file_create_end - file_create_start, - "dataset_writing": dataset_write_end - dataset_write_start, - "file_closing": file_close_end - dataset_write_end, - "total_io_time": file_close_end - file_create_start - } - - # Calculate performance metrics - total_duration = file_close_end - start_time - - log_entry["performance_metrics"] = { - "total_duration": total_duration, - "sequences_per_second": len(embeddings_pure) / total_duration if total_duration > 0 else 0, - "bytes_per_second": final_file_size / total_duration if total_duration > 0 else 0, - "avg_dataset_time": sum(d["duration"] for d in dataset_times) / len(dataset_times) if dataset_times else 0 - } - - # Success metadata - log_entry.update({ - "end_timestamp": time.time(), - "end_perf_counter": file_close_end, - "success": True, - "file_size_bytes": final_file_size, - "dataset_count": len(dataset_times), - "dataset_times": dataset_times - }) - - except Exception as e: - # Error handling with timing - error_time = time.perf_counter() - - log_entry.update({ - "end_timestamp": time.time(), - "end_perf_counter": error_time, - "total_duration": error_time - start_time, - "success": False, - "error": str(e), - "error_type": type(e).__name__, - "error_occurred_at": error_time - start_time - }) - - # Re-raise the exception after logging - raise - - finally: - # Always attempt to write the log file - try: - log_entry["log_write_start"] = time.perf_counter() - - with open(log_file, 'w') as f: - json.dump(log_entry, f, indent=2) - - log_entry["log_write_duration"] = time.perf_counter() - log_entry["log_write_start"] - - except Exception as log_error: - # If logging fails, print to stderr but don't fail the main operation - print(f"Warning: Failed to write profiling log for batch {batch_index}: {log_error}") - print(f"Batch {batch_index} performance summary:") - print(f" Duration: {log_entry.get('performance_metrics', {}).get('total_duration', 'unknown')} seconds") - print(f" Sequences: {log_entry.get('num_sequences', 'unknown')}") - print(f" Success: {log_entry.get('success', 'unknown')}") - - -# Keep the original function available for backward compatibility -def save_batch_original(output_path: Path, batch_index: int, embeddings_pure: Dict[str, torch.Tensor]): - """Original save_batch function without profiling.""" - batch_file = output_path / f"batch_{batch_index}.h5" - with h5py.File(batch_file, 'w') as f: - for seq_id, embedding in embeddings_pure.items(): - f.create_dataset(seq_id, data=embedding) - - -def create_shared_memory_batch(embeddings_pure: Dict[str, torch.Tensor]): - """Serialize embeddings into shared memory blocks and return metadata.""" - batch_metadata = [] - shared_blocks = [] - try: - for prot_id, embedding in embeddings_pure.items(): - if isinstance(embedding, torch.Tensor): - array = embedding.detach().cpu().numpy() - else: - array = np.asarray(embedding) - - shm = shared_memory.SharedMemory(create=True, size=array.nbytes) - shared_blocks.append(shm) - - shm_array = np.ndarray(array.shape, dtype=array.dtype, buffer=shm.buf) - shm_array[...] = array - - batch_metadata.append({ - "prot_id": prot_id, - "shape": array.shape, - "dtype": str(array.dtype), - "shm_name": shm.name, - }) - except Exception: - for shm in shared_blocks: - shm.close() - shm.unlink() - raise - finally: - for shm in shared_blocks: - shm.close() - - return batch_metadata - - -def save_batch_profiled_from_shared_memory(output_path: Path, batch_index: int, batch_metadata): - """Reconstruct embeddings from shared memory and delegate to save_batch_profiled.""" - # embeddings_pure = {} - shared_blocks = [] - try: - batch_ids = [] - batch_embeddings = [] - for meta in batch_metadata: - shm = shared_memory.SharedMemory(name=meta["shm_name"]) - shared_blocks.append(shm) - - batch_ids.append(meta["prot_id"]) - - array = np.ndarray(meta["shape"], dtype=np.dtype(meta["dtype"]), buffer=shm.buf) - - batch_embeddings.append(array) - - embeddings_pure = dict(zip(batch_ids, batch_embeddings)) - - save_batch_original(output_path, batch_index, embeddings_pure) - finally: - for shm in shared_blocks: - shm.close() - shm.unlink() diff --git a/toolbox/models/manage_dataset/compress_experiment/exp.py b/toolbox/models/manage_dataset/compress_experiment/exp.py deleted file mode 100644 index b9dec70..0000000 --- a/toolbox/models/manage_dataset/compress_experiment/exp.py +++ /dev/null @@ -1,174 +0,0 @@ -import h5py -import numpy as np -import zlib -import time -from pathlib import Path -from typing import List, Tuple - - -# Approach 1: Creating an HDF5 dataset for each protein structure -def compress_and_save_h5_individual( - path_for_batch: Path, results: Tuple[List[str], List[str], List[str]] -): - start_time = time.time() - pdbs_file = path_for_batch / "pdbs_individual_gzip.h5" - all_res_pdbs = results[0] - all_contents = results[1] - if len(all_contents) == 0 or len(all_res_pdbs) == 0: - print("No files to save") - return None - if len(all_res_pdbs) != len(all_contents): - print("Wrong length of names and pdb contents") - return None - with h5py.File(pdbs_file, "w") as hf: - for pdb_name, pdb_content in zip(all_res_pdbs, all_contents): - hf.create_dataset( - pdb_name, - data=np.frombuffer(pdb_content.encode("utf-8"), dtype=np.uint8), - compression="gzip", - ) - end_time = time.time() - total_time = end_time - start_time - print(f"Compress time (individual): {total_time}") - return str(pdbs_file) - - -def compress_and_save_h5_individual_lzf( - path_for_batch: Path, results: Tuple[List[str], List[str], List[str]] -): - start_time = time.time() - pdbs_file = path_for_batch / "pdbs_individual_lzf.h5" - all_res_pdbs = results[0] - all_contents = results[1] - if len(all_contents) == 0 or len(all_res_pdbs) == 0: - print("No files to save") - return None - if len(all_res_pdbs) != len(all_contents): - print("Wrong length of names and pdb contents") - return None - with h5py.File(pdbs_file, "w") as hf: - for pdb_name, pdb_content in zip(all_res_pdbs, all_contents): - hf.create_dataset( - pdb_name, - data=np.frombuffer(pdb_content.encode("utf-8"), dtype=np.uint8), - compression="lzf", - ) - end_time = time.time() - total_time = end_time - start_time - print(f"Compress time (individual): {total_time}") - return str(pdbs_file) - - -def compress_and_save_h5_individual_lzf_shuffle( - path_for_batch: Path, results: Tuple[List[str], List[str], List[str]] -): - start_time = time.time() - pdbs_file = path_for_batch / "pdbs_individual_lzf_shuffle.h5" - all_res_pdbs = results[0] - all_contents = results[1] - if len(all_contents) == 0 or len(all_res_pdbs) == 0: - print("No files to save") - return None - if len(all_res_pdbs) != len(all_contents): - print("Wrong length of names and pdb contents") - return None - with h5py.File(pdbs_file, "w") as hf: - for pdb_name, pdb_content in zip(all_res_pdbs, all_contents): - hf.create_dataset( - pdb_name, - data=np.frombuffer(pdb_content.encode("utf-8"), dtype=np.uint8), - compression="lzf", - shuffle=True, - ) - end_time = time.time() - total_time = end_time - start_time - print(f"Compress time (shuffle_individual): {total_time}") - return str(pdbs_file) - - -# Approach 2: Creating an HDF5 dataset for all protein structures -def compress_and_save_h5_combined( - path_for_batch: Path, results: Tuple[List[str], List[str], List[str]] -): - start_time = time.time() - pdbs_file = path_for_batch / "pdbs_combined_gzip.h5" - all_res_pdbs = results[0] - all_contents = results[1] - if len(all_contents) == 0 or len(all_res_pdbs) == 0: - print("No files to save") - return None - if len(all_res_pdbs) != len(all_contents): - print("Wrong length of names and pdb contents") - return None - with h5py.File(pdbs_file, "w") as hf: - combined_content = "|".join(all_contents) - compressed_content = np.frombuffer( - combined_content.encode("utf-8"), dtype=np.uint8 - ) - hf.create_dataset( - ";".join(all_res_pdbs), data=compressed_content, compression="gzip" - ) - end_time = time.time() - total_time = end_time - start_time - print(f"Compress time (combined): {total_time}") - return str(pdbs_file) - - -def compress_and_save_h5_combined_lzf( - path_for_batch: Path, results: Tuple[List[str], List[str], List[str]] -): - start_time = time.time() - pdbs_file = path_for_batch / "pdbs_combined_lzf.h5" - all_res_pdbs = results[0] - all_contents = results[1] - if len(all_contents) == 0 or len(all_res_pdbs) == 0: - print("No files to save") - return None - if len(all_res_pdbs) != len(all_contents): - print("Wrong length of names and pdb contents") - return None - with h5py.File(pdbs_file, "w") as hf: - combined_content = "|".join(all_contents) - compressed_content = np.frombuffer( - combined_content.encode("utf-8"), dtype=np.uint8 - ) - hf.create_dataset( - ";".join(all_res_pdbs), data=compressed_content, compression="lzf" - ) - end_time = time.time() - total_time = end_time - start_time - print(f"Compress time (combined): {total_time}") - return str(pdbs_file) - - -def compress_and_save_h5_combined_lzf_shuffle( - path_for_batch: Path, results: Tuple[List[str], List[str], List[str]] -): - start_time = time.time() - pdbs_file = path_for_batch / "pdbs_combined_lzf_shuffle.h5" - all_res_pdbs = results[0] - all_contents = results[1] - if len(all_contents) == 0 or len(all_res_pdbs) == 0: - print("No files to save") - return None - if len(all_res_pdbs) != len(all_contents): - print("Wrong length of names and pdb contents") - return None - with h5py.File(pdbs_file, "w") as hf: - combined_content = "|".join(all_contents) - compressed_content = np.frombuffer( - combined_content.encode("utf-8"), dtype=np.uint8 - ) - hf.create_dataset( - ";".join(all_res_pdbs), - data=compressed_content, - compression="lzf", - shuffle=True, - ) - end_time = time.time() - total_time = end_time - start_time - print(f"Compress time (shuffle_combined): {total_time}") - return str(pdbs_file) - - -# Approach 3: Compressing data before storing (existing implementation) diff --git a/toolbox/models/manage_dataset/utils.py b/toolbox/models/manage_dataset/utils.py index 5419f6a..01bf60c 100644 --- a/toolbox/models/manage_dataset/utils.py +++ b/toolbox/models/manage_dataset/utils.py @@ -1,5 +1,4 @@ import asyncio -import os import shutil import tempfile import time @@ -7,7 +6,7 @@ import zlib import re -from io import BytesIO, StringIO +from io import BytesIO from itertools import islice from pathlib import Path from typing import List, Tuple, Optional, Dict, Iterable @@ -16,21 +15,13 @@ import biotite.database import biotite.database.rcsb import biotite.database.afdb -import dask import h5py import numpy as np -from dask.distributed import as_completed as dask_as_completed, worker_client +from dask.distributed import worker_client from foldcomp import foldcomp from foldcomp.setup import download -from toolbox.models.manage_dataset.compress_experiment.exp import ( - compress_and_save_h5_combined_lzf_shuffle, - compress_and_save_h5_individual, - compress_and_save_h5_individual_lzf, - compress_and_save_h5_combined, - compress_and_save_h5_combined_lzf, - compress_and_save_h5_individual_lzf_shuffle, -) + from toolbox.models.utils.cif2pdb import cif_to_pdb, binary_cif_to_pdb from toolbox.utlis.logging import logger @@ -67,15 +58,15 @@ def retrieve_cif(pdb: str) -> Tuple[Optional[str], str]: logger.debug(f"Retrying downloading {pdb} {retry_num}") try: - cif_file_io: StringIO = biotite.database.rcsb.fetch(pdb, "cif") - cif_file: str = cif_file_io.getvalue() + cif_file_io = biotite.database.rcsb.fetch(pdb, "cif") + cif_file = cif_file_io.getvalue() except Exception: cif_file = None - if not cif_file: + if cif_file is None: retry_num += 1 - if retry_num > 3: + if cif_file is None: logger.warning(f"Failed retrying {pdb}") return None, pdb @@ -121,20 +112,20 @@ def retrieve_binary_cif(pdb: str) -> tuple[BytesIO | None, str]: retry_num: int = 0 binary_file_bytes_io: Optional[BytesIO] = None - while retry_num <= 3 and binary_file_bytes_io is None: - + while retry_num <= 3: if retry_num > 0: logger.debug(f"Retrying downloading {pdb} {retry_num}") try: - binary_file_bytes_io: BytesIO = biotite.database.rcsb.fetch(pdb, "bcif") + binary_file_bytes_io = biotite.database.rcsb.fetch(pdb, "bcif") except Exception: binary_file_bytes_io = None - if not binary_file_bytes_io: - retry_num += 1 + if binary_file_bytes_io: + break + retry_num += 1 - if retry_num > 3: + if binary_file_bytes_io is None: logger.warning(f"Failed retrying {pdb}") return None, pdb @@ -214,45 +205,6 @@ def compress_and_save_h5( return str(pdbs_file) -def compress_and_save_experiment( - path_for_batch: Path, results: Tuple[List[str], List[str], List[str]] -): - fs = [ - compress_and_save_h5_individual, - compress_and_save_h5_individual_lzf, - compress_and_save_h5_individual_lzf_shuffle, - compress_and_save_h5_combined, - compress_and_save_h5_combined_lzf, - compress_and_save_h5_combined_lzf_shuffle, - compress_and_save_h5, - ] - - descriptions = [ - "individual gzip", - "individual lzf", - "individual lzf shuffle", - "combined gzip", - "combined lzf", - "combined lzf shuffle", - "combined zlib", - ] - - inputs = list(results) - - def get_file_size_mb(file_path): - try: - size_in_bytes = os.path.getsize(file_path) - size_in_mb = size_in_bytes / (1024 * 1024) # Convert bytes to megabytes - return round(size_in_mb, 2) - except Exception: - return None - - for f, desc in zip(fs, descriptions): - logger.debug(desc) - path = f(path_for_batch, inputs) - logger.debug(f"{path} {get_file_size_mb(path)} MB") - - def retrieve_pdb_chunk_to_h5( path_for_batch: Path, pdb_ids: Iterable[str],