diff --git a/README.md b/README.md
index 9ccba4d..f0b69e8 100644
--- a/README.md
+++ b/README.md
@@ -86,6 +86,8 @@ PYTHONPATH='.' python3 -u ${FRIDATA_PATH}/fridata.py \
   -e ${EMBEDDER_TYPE}
 ```
 
+For subset runs with `--input-path`, new datasets store canonical keys as `{line_from_ids_file}_{chain}` (for example `A0A2K6V5L6_A`), not the full AlphaFold CIF filename stem. The dataset’s `input_structures.idx` maps each canonical key to the source structure filename. Older datasets created before this convention may still use long AF-style keys.
+
 ## Running as a CLI tool
 
 Assuming all `Instalation and activation` steps succeeded.
@@ -106,7 +108,7 @@ python3 -m pip install -e .
 
 ```
 $ fridata <...>
 ```
 
 ## Running on HPC
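To make the new key convention concrete, here is a minimal stdlib-only sketch of how a canonical key relates to its `input_structures.idx` entry. The helper name `make_canonical_key` is hypothetical, for illustration only, and is not part of this codebase:

```python
from pathlib import Path


def make_canonical_key(ids_file_line: str, chain: str) -> str:
    # Hypothetical helper: canonical key = raw ids_file token plus chain id.
    return f"{ids_file_line.strip()}_{chain}"


# A resolved source structure and the key it is stored under:
source = Path("AF-A0A2K6V5L6-F1-model_v6.cif")
key = make_canonical_key("A0A2K6V5L6", "A")
assert key == "A0A2K6V5L6_A"

# input_structures.idx conceptually maps canonical key -> source filename:
index_entry = {key: source.name}
print(index_entry)  # {'A0A2K6V5L6_A': 'AF-A0A2K6V5L6-F1-model_v6.cif'}
```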
"A0A2K6V5L6")) + assert data == "dummy" + assert pdb_code == "A0A2K6V5L6" + assert ext == ".cif" + + +def test_retrieve_single_file_plain_path_uses_stem(tmp_path: Path) -> None: + f = tmp_path / "AF-A0A2K6V5L6-F1-model_v6.cif" + f.write_text("z", encoding="utf-8") + _, pdb_code, _ = retrieve_single_file(str(f)) + assert pdb_code == "AF-A0A2K6V5L6-F1-model_v6" + + +def test_resolve_exact_wins_over_af_pattern(tmp_path: Path) -> None: + exact = tmp_path / "Q5VSL9-2.cif" + _touch(exact) + _touch(tmp_path / "AF-Q5VSL9-2-F1-model_v6.cif") + inv = build_stem_to_paths(tmp_path) + out = resolve_id("Q5VSL9-2", inv) + assert len(out) == 1 + assert out[0].resolve() == exact.resolve() + + +def test_resolve_loose_isoform_fallback(tmp_path: Path) -> None: + iso = tmp_path / "AF-Q5VSL9-2-F1-model_v6.cif" + _touch(iso) + inv = build_stem_to_paths(tmp_path) + out = resolve_id("Q5VSL9", inv) + assert len(out) == 1 + assert out[0].resolve() == iso.resolve() + + +def test_resolve_af_exact_takes_precedence_over_loose_higher_version(tmp_path: Path) -> None: + base_v4 = tmp_path / "AF-Q5VSL9-F1-model_v4.cif" + iso_v6 = tmp_path / "AF-Q5VSL9-2-F1-model_v6.cif" + _touch(base_v4) + _touch(iso_v6) + inv = build_stem_to_paths(tmp_path) + out = resolve_id("Q5VSL9", inv) + assert len(out) == 1 + assert out[0].resolve() == base_v4.resolve() + + +def test_resolve_missing_returns_empty(tmp_path: Path) -> None: + _touch(tmp_path / "OTHER.cif") + inv = build_stem_to_paths(tmp_path) + assert resolve_id("Q5VSL9", inv) == [] + + +def test_resolve_strips_id_not_applied_use_raw_match(tmp_path: Path) -> None: + """Whitespace is stripped by caller in save_extracted_files; resolver is strict.""" + _touch(tmp_path / "Q5VSL9.cif") + inv = build_stem_to_paths(tmp_path) + assert resolve_id(" Q5VSL9", inv) == [] + assert resolve_id("Q5VSL9", inv) != [] diff --git a/toolbox/models/manage_dataset/extract_archive.py b/toolbox/models/manage_dataset/extract_archive.py index 192a3e8..68383fd 100644 --- a/toolbox/models/manage_dataset/extract_archive.py +++ b/toolbox/models/manage_dataset/extract_archive.py @@ -1,15 +1,10 @@ import os -import shutil import time -from typing import Iterable, List, Optional, Tuple, Dict +from typing import Iterable, List, Optional, Tuple, Dict, Pattern, Union import zipfile import tarfile from pathlib import Path -from glob import iglob - -from tqdm import tqdm - from dask.distributed import worker_client from toolbox.models.manage_dataset.index.handle_index import add_new_files_to_index, create_index from toolbox.models.utils.create_client import total_workers @@ -65,6 +60,97 @@ def is_archive(path): return zipfile.is_zipfile(path) or tarfile.is_tarfile(path) +def build_stem_to_paths(extracted_path: Path) -> Dict[str, List[Path]]: + """ + Collect all .pdb / .cif under extracted_path, keyed by filename stem. + Paths are resolved to absolute for reliable opening (archive vs directory). 
+ """ + stem_to_paths: Dict[str, List[Path]] = {} + base = extracted_path.resolve() + if not base.exists(): + return stem_to_paths + for pattern in ("*.pdb", "*.cif"): + for p in base.rglob(pattern): + if not p.is_file(): + continue + resolved = p.resolve() + stem_to_paths.setdefault(resolved.stem, []).append(resolved) + return stem_to_paths + + +def _pick_path_prefer_cif(paths: List[Path]) -> Path: + """If both .cif and .pdb exist for the same stem, prefer .cif.""" + cifs = [p for p in paths if p.suffix.lower() == ".cif"] + if cifs: + return sorted(cifs, key=lambda x: str(x))[0] + return sorted(paths, key=lambda x: str(x))[0] + + +def _paths_at_max_af_version( + stem_to_paths: Dict[str, List[Path]], pattern: Pattern[str] +) -> List[Path]: + """ + Match stems with pattern groups: (fragment F, version V). + Pick the globally highest V, then return one path per matching stem at that V + (prefer .cif when a stem has multiple extensions). + """ + rows: List[Tuple[int, int, str]] = [] + for stem in stem_to_paths: + m = pattern.match(stem) + if m: + fragment = int(m.group(1)) + version = int(m.group(2)) + rows.append((version, fragment, stem)) + if not rows: + return [] + max_v = max(r[0] for r in rows) + stems_at_max = sorted({r[2] for r in rows if r[0] == max_v}) + return [_pick_path_prefer_cif(stem_to_paths[s]) for s in stems_at_max] + + +def resolve_id(requested_id: str, stem_to_paths: Dict[str, List[Path]]) -> List[Path]: + """ + Resolve a requested protein id to file path(s) under input_path / extracted tree. + + 1) Exact stem match (prefer .cif over .pdb). + 2) AF exact: AF-{id}-F{N}-model_v{V} — highest V, all fragments at that V. + 3) AF loose (isoform): AF-{id}-{digits}-F{N}-model_v{V} — same version rule. + """ + if requested_id in stem_to_paths: + return [_pick_path_prefer_cif(stem_to_paths[requested_id])] + + af_exact = re.compile(rf"^AF-{re.escape(requested_id)}-F(\d+)-model_v(\d+)$") + found = _paths_at_max_af_version(stem_to_paths, af_exact) + if found: + return found + + af_loose = re.compile(rf"^AF-{re.escape(requested_id)}-\d+-F(\d+)-model_v(\d+)$") + return _paths_at_max_af_version(stem_to_paths, af_loose) + + +_AF_FRAGMENT_NUM_RE = re.compile(r"-F(\d+)-model_v\d+$") + + +def pick_single_path_for_canonical_id(paths: List[Path]) -> Path: + """ + When multiple structures match one ids_file id (e.g. F1 and F2 at same model version), + keep a single file: lowest AF fragment number F{N}; ties by resolved path string. + Non-AF stems sort after AF (fragment key 10**9). 
+ """ + if not paths: + raise ValueError("paths must be non-empty") + if len(paths) == 1: + return paths[0] + scored: List[Tuple[int, str, Path]] = [] + for p in paths: + stem = p.stem + m = _AF_FRAGMENT_NUM_RE.search(stem) + frag = int(m.group(1)) if m else 10**9 + scored.append((frag, str(p.resolve()), p)) + scored.sort(key=lambda t: (t[0], t[1])) + return scored[0][2] + + def save_extracted_files( structures_dataset: "StructuresDataset", extracted_path: Path, @@ -79,54 +165,48 @@ def save_extracted_files( Path(structures_dataset.structures_path()).mkdir(exist_ok=True, parents=True) pdb_repo_path = structures_dataset.structures_path() - pdb_iterator = iglob(str(extracted_path) + "/**/*.pdb") - direct_pdb_iterator = iglob(str(extracted_path) + "/*.pdb") - - pdb_files_name_to_dir = { - Path(file).name.replace(".pdb", "").replace(".cif", ""): file for file in pdb_iterator - } - - direct_pdb_files_name_to_dir = { - Path(file).name.replace(".pdb", "").replace(".cif", ""): file for file in direct_pdb_iterator - } - - cif_iterator = iglob(str(extracted_path) + "/**/*.cif") - direct_cif_iterator = iglob(str(extracted_path) + "/*.cif") - - cif_files_name_to_dir = { - Path(file).name.replace(".pdb", "").replace(".cif", ""): file for file in cif_iterator - } - direct_cif_files_name_to_dir = { - Path(file).name.replace(".pdb", "").replace(".cif", ""): file for file in direct_cif_iterator + extracted_path = Path(extracted_path) + stem_to_paths = build_stem_to_paths(extracted_path) + picked_by_stem = { + stem: _pick_path_prefer_cif(paths) for stem, paths in stem_to_paths.items() } - files_name_to_dir = {**pdb_files_name_to_dir, **direct_pdb_files_name_to_dir, **cif_files_name_to_dir, **direct_cif_files_name_to_dir} - - logger.debug(f"extracted files: {len(files_name_to_dir)}") - - for name, path in files_name_to_dir.items(): - files_name_to_dir[name] = path.removeprefix(str(structures_dataset.input_path) + '/') - - present_files_set = set(files_name_to_dir.keys()) + logger.debug(f"extracted files: {len(stem_to_paths)} unique stems") + canonical_base_to_source_name: Dict[str, str] = {} + use_canonical_ids = ids is not None if ids is None: - files = list(files_name_to_dir.values()) + files = [str(picked_by_stem[s]) for s in sorted(picked_by_stem)] chunks = list(structures_dataset.chunk(files)) missing_files = None else: - logger.info(f"Searching for {len(ids)} files in {len(present_files_set)} already extracted files") - - ids_set = set(ids) - - wanted_files = present_files_set & ids_set - wanted_files = [f"{structures_dataset.input_path}/{files_name_to_dir[file]}" for file in wanted_files] - missing_files = list(ids_set - present_files_set) - - logger.info(f"Found {len(wanted_files)}, missing {len(missing_files)} out of {len(ids)} requested files") - ids = wanted_files + logger.info( + f"Searching for {len(ids)} ids among {len(stem_to_paths)} stems under {extracted_path}" + ) - chunks = list(structures_dataset.chunk(ids)) + wanted_items: List[Union[str, Tuple[str, str]]] = [] + seen_resolved: set[str] = set() + missing_files = [] + + for raw_id in ids: + rid = raw_id.strip() + resolved_paths = resolve_id(rid, stem_to_paths) + if not resolved_paths: + missing_files.append(raw_id) + continue + chosen = pick_single_path_for_canonical_id(resolved_paths) + key = str(chosen.resolve()) + if key not in seen_resolved: + seen_resolved.add(key) + wanted_items.append((key, rid)) + canonical_base_to_source_name[rid] = chosen.name + + logger.info( + f"Resolved {len(wanted_items)} file path(s), {len(missing_files)} 
missing " + f"out of {len(ids)} requested ids" + ) + chunks = list(structures_dataset.chunk(wanted_items)) mkdir_for_batches(pdb_repo_path, len(chunks)) @@ -164,24 +244,19 @@ def collect(result): input_structures_index = {} - for file_path_str in files_name_to_dir.values(): - file_path = Path(file_path_str) - file_name_without_extension = file_path.stem - - id_with_chain = no_chain_to_chain_dict.get(file_name_without_extension, None) - - if id_with_chain: + if use_canonical_ids: + for id_with_chain in sorted(new_files_index.keys()): + base = id_with_chain.rsplit("_", 1)[0] + src = canonical_base_to_source_name.get(base) + if src is not None: + input_structures_index[id_with_chain] = src + else: + for stem, picked in sorted(picked_by_stem.items(), key=lambda x: x[0]): + file_path = picked + id_with_chain = no_chain_to_chain_dict.get(stem, None) + if id_with_chain is None: + continue input_structures_index[id_with_chain] = file_path.name - # else this protein is missing from the archive - - - # for cif_file_name in cif_files_name_to_dir.keys(): - # match = re.match(r'^AF-(.+?)-F1-model_v\d+$', cif_file_name) - # if match: - # pdb_id = match.group(1) - # chain_id = list(file_to_pdb(retrieve_single_file(cif_files_name_to_dir[cif_file_name])).keys())[0].split('_')[-1] - # files_name_to_dir[pdb_id + "_" + chain_id] = files_name_to_dir[cif_file_name] - # del files_name_to_dir[cif_file_name] try: add_new_files_to_index(structures_dataset.dataset_index_file_path(), new_files_index, structures_dataset.config.data_path) @@ -193,7 +268,9 @@ def collect(result): def retrieve_protein_file_to_h5( - path_for_batch: Path, pdb_ids: Iterable[str], workers: List[str] = None + path_for_batch: Path, + pdb_ids: Iterable[Union[str, Tuple[str, str]]], + workers: List[str] = None, ) -> Tuple[List[str], str]: with worker_client() as client: start_time = time.time() @@ -230,12 +307,26 @@ def retrieve_protein_file_to_h5( return pdb_ids, h5_file_path -def retrieve_single_file(file_path): - file_path = Path(file_path) - file_name = file_path.stem +def retrieve_single_file( + item: Union[str, Tuple[str, str], List[str]], +): + """ + Load structure file for conversion. + + ``item`` is either a path string, or ``(path_str, canonical_pdb_code)`` where + ``canonical_pdb_code`` is the ids_file token used as ``cif_to_pdb`` / PDB key base + (e.g. UniProt accession), not the AF CIF filename stem. + """ + canonical: Optional[str] = None + if isinstance(item, (tuple, list)) and len(item) == 2: + file_path = Path(item[0]) + canonical = str(item[1]) + else: + file_path = Path(item) + pdb_code = canonical if canonical is not None else file_path.stem file_extension = file_path.suffix with open(file_path, "r") as file: - return file.read(), file_name, file_extension + return file.read(), pdb_code, file_extension def file_to_pdb(input_data):