From 189f2372fad6dfa68f8a63f1869755e162bc13d1 Mon Sep 17 00:00:00 2001 From: Davide Date: Tue, 26 May 2026 13:59:56 +0200 Subject: [PATCH 1/7] Add CLI options for output root and project name overrides --- deployment/README.md | 4 ++ deployment/srgan_hpc/cli.py | 34 ++++++++-- .../test_deployment/test_deployment_config.py | 65 +++++++++++++++++++ 3 files changed, 97 insertions(+), 6 deletions(-) diff --git a/deployment/README.md b/deployment/README.md index c162c56..f0ab2fd 100644 --- a/deployment/README.md +++ b/deployment/README.md @@ -33,6 +33,8 @@ srgan-hpc submit patch \ srgan-hpc submit grid \ --config deployment/configs/runtime.default.yaml \ + --output-root /data/srgan_new_area \ + --project-name srgan_new_area \ --lat1 52.3 --lon1 12.9 \ --lat2 52.7 --lon2 13.8 \ --start-date 2025-07-01 \ @@ -62,4 +64,6 @@ AOI submission accepts either a `.shp` file or a directory containing exactly on `staging.item_strategy: mosaic_valid` is the default for STAC staging. When a Cubo cutout intersects multiple Sentinel-2 tiles, the launcher ranks candidate items by valid-data coverage near the cutout center, then fills remaining nodata pixels from the other candidates before inference. Use `item_strategy: fixed_index` only when you explicitly want legacy `staging.image_index` behavior. +For repeated runs with the same settings, keep one standard config file and override only the destination at submit time with `--output-root`; use `--project-name` when you also want a readable run-name prefix. + `deliver-bbox` merges patch outputs per run and writes clipped GeoTIFFs for sharing in GIS tools. diff --git a/deployment/srgan_hpc/cli.py b/deployment/srgan_hpc/cli.py index 79314ba..73e4589 100644 --- a/deployment/srgan_hpc/cli.py +++ b/deployment/srgan_hpc/cli.py @@ -64,6 +64,14 @@ def build_parser() -> argparse.ArgumentParser: def _add_submit_common_args(parser: argparse.ArgumentParser) -> None: parser.add_argument("--config", required=True) + parser.add_argument( + "--output-root", + help="Override config.output_root for this run without creating a new runtime YAML.", + ) + parser.add_argument( + "--project-name", + help="Override config.project_name for this run without creating a new runtime YAML.", + ) parser.add_argument("--start-date", required=True) parser.add_argument("--end-date", required=True) parser.add_argument("--script-path") @@ -71,6 +79,21 @@ def _add_submit_common_args(parser: argparse.ArgumentParser) -> None: parser.add_argument("--verbose", action="store_true") +def _submit_config_overrides(args: argparse.Namespace) -> dict[str, str]: + overrides: dict[str, str] = {} + if args.output_root: + overrides["output_root"] = args.output_root + if args.project_name: + overrides["project_name"] = args.project_name + return overrides + + +def _load_submit_config(args: argparse.Namespace): + from deployment.srgan_hpc.config import load_runtime_config + + return load_runtime_config(args.config, overrides=_submit_config_overrides(args)) + + def _resolve_script_path(script_path: str | None) -> Path: if script_path is None: return bundled_slurm_entrypoint().resolve() @@ -98,13 +121,12 @@ def _handle_validate(args: argparse.Namespace) -> int: def _handle_submit_patch(args: argparse.Namespace) -> int: - from deployment.srgan_hpc.config import load_runtime_config from deployment.srgan_hpc.logging_utils import configure_logging from deployment.srgan_hpc.patching import Patch from deployment.srgan_hpc.submit import submit_patch_run logger = configure_logging(verbose=args.verbose) - config = load_runtime_config(args.config) + config = _load_submit_config(args) patch = Patch( patch_id="patch_000001", latitude=args.lat, @@ -134,13 +156,13 @@ def _handle_submit_patch(args: argparse.Namespace) -> int: def _handle_submit_grid(args: argparse.Namespace) -> int: - from deployment.srgan_hpc.config import load_runtime_config, patch_resolution + from deployment.srgan_hpc.config import patch_resolution from deployment.srgan_hpc.logging_utils import configure_logging from deployment.srgan_hpc.patching import build_patches from deployment.srgan_hpc.submit import submit_grid_run logger = configure_logging(verbose=args.verbose) - config = load_runtime_config(args.config) + config = _load_submit_config(args) patches = build_patches( args.lat1, args.lon1, @@ -178,12 +200,12 @@ def _handle_submit_grid(args: argparse.Namespace) -> int: def _handle_submit_aoi(args: argparse.Namespace) -> int: from deployment.srgan_hpc.aoi import select_aoi_patches - from deployment.srgan_hpc.config import load_runtime_config, patch_resolution + from deployment.srgan_hpc.config import patch_resolution from deployment.srgan_hpc.logging_utils import configure_logging from deployment.srgan_hpc.submit import submit_aoi_run logger = configure_logging(verbose=args.verbose) - config = load_runtime_config(args.config) + config = _load_submit_config(args) aoi_path = args.aoi_path or config.aoi.path if aoi_path is None: raise ValueError("AOI path must be provided via --aoi-path or config.aoi.path") diff --git a/tests/test_deployment/test_deployment_config.py b/tests/test_deployment/test_deployment_config.py index f155c3b..9fa823e 100644 --- a/tests/test_deployment/test_deployment_config.py +++ b/tests/test_deployment/test_deployment_config.py @@ -4,6 +4,7 @@ import pytest +from deployment.srgan_hpc.cli import build_parser from deployment.srgan_hpc.config import ( RuntimeConfig, enabled_product_names, @@ -64,3 +65,67 @@ def test_staging_search_filters_are_loaded(tmp_path: Path) -> None: assert config.staging.search_query == {"s2:mgrs_tile": {"eq": "43PGQ"}} assert config.staging.search_max_items == 1 assert config.staging.search_limit == 1 + + +def test_runtime_config_overrides_output_root_and_project_name(tmp_path: Path) -> None: + config_path = tmp_path / "runtime.yaml" + output_root = tmp_path / "custom-runs" + config_path.write_text( + """ +project_name: from_yaml +output_root: yaml-runs +""", + encoding="utf-8", + ) + + config = load_runtime_config( + config_path, + overrides={ + "project_name": "from_cli", + "output_root": str(output_root), + }, + ) + + assert config.project_name == "from_cli" + assert config.output_root == output_root + + +def test_submit_parser_accepts_output_overrides_for_all_submit_modes() -> None: + parser = build_parser() + common_args = [ + "--config", + "runtime.yaml", + "--output-root", + "/tmp/srgan-run", + "--project-name", + "srgan-run", + "--start-date", + "2025-01-01", + "--end-date", + "2025-01-02", + "--dry-run", + ] + + patch = parser.parse_args( + ["submit", "patch", *common_args, "--lat", "13.0", "--lon", "77.6"] + ) + grid = parser.parse_args( + [ + "submit", + "grid", + *common_args, + "--lat1", + "13.0", + "--lon1", + "77.6", + "--lat2", + "13.1", + "--lon2", + "77.7", + ] + ) + aoi = parser.parse_args(["submit", "aoi", *common_args, "--aoi-path", "area.shp"]) + + for args in (patch, grid, aoi): + assert args.output_root == "/tmp/srgan-run" + assert args.project_name == "srgan-run" From 6b7d1d68124346f881dc391a47b43be1644e0b89 Mon Sep 17 00:00:00 2001 From: Davide Date: Tue, 26 May 2026 15:16:41 +0200 Subject: [PATCH 2/7] Enhance CLI and staging functionality with summary reporting and improved logging --- deployment/srgan_hpc/cli.py | 75 ++++++++++++++++++++++++++++- deployment/srgan_hpc/staging.py | 83 ++++++++++++++++++++++++--------- deployment/srgan_hpc/submit.py | 2 + 3 files changed, 136 insertions(+), 24 deletions(-) diff --git a/deployment/srgan_hpc/cli.py b/deployment/srgan_hpc/cli.py index 73e4589..31756df 100644 --- a/deployment/srgan_hpc/cli.py +++ b/deployment/srgan_hpc/cli.py @@ -94,6 +94,32 @@ def _load_submit_config(args: argparse.Namespace): return load_runtime_config(args.config, overrides=_submit_config_overrides(args)) +def _write_and_print_summary( + *, + run_dir: Path, + config, + submission, + request: dict[str, object], + start_date: str, + end_date: str, +) -> dict[str, str]: + from deployment.srgan_hpc.submission_summary import ( + format_submission_summary, + write_submission_summary, + ) + + summary, summary_json, summary_txt = write_submission_summary( + run_dir=run_dir, + config=config, + submission=submission, + request=request, + start_date=start_date, + end_date=end_date, + ) + print(format_submission_summary(summary)) + return {"json": str(summary_json), "text": str(summary_txt)} + + def _resolve_script_path(script_path: str | None) -> Path: if script_path is None: return bundled_slurm_entrypoint().resolve() @@ -146,9 +172,27 @@ def _handle_submit_patch(args: argparse.Namespace) -> int: dry_run=args.dry_run, ) logger.info("submitted patch run_id=%s run_dir=%s", run_id, run_dir) + summary_paths = _write_and_print_summary( + run_dir=run_dir, + config=config, + submission=submission, + request={ + "type": "patch", + "lat": args.lat, + "lon": args.lon, + "planned_patch_count": 1, + }, + start_date=args.start_date, + end_date=args.end_date, + ) print( json.dumps( - {"run_id": run_id, "run_dir": str(run_dir), "submission": submission}, + { + "run_id": run_id, + "run_dir": str(run_dir), + "submission": submission, + "summary": summary_paths, + }, indent=2, ) ) @@ -184,6 +228,21 @@ def _handle_submit_grid(args: argparse.Namespace) -> int: logger.info( "submitted grid run_id=%s run_dir=%s patches=%d", run_id, run_dir, len(patches) ) + summary_paths = _write_and_print_summary( + run_dir=run_dir, + config=config, + submission=submission, + request={ + "type": "grid", + "lat1": args.lat1, + "lon1": args.lon1, + "lat2": args.lat2, + "lon2": args.lon2, + "planned_patch_count": len(patches), + }, + start_date=args.start_date, + end_date=args.end_date, + ) print( json.dumps( { @@ -191,6 +250,7 @@ def _handle_submit_grid(args: argparse.Namespace) -> int: "run_dir": str(run_dir), "patches": len(patches), "submission": submission, + "summary": summary_paths, }, indent=2, ) @@ -244,6 +304,19 @@ def _handle_submit_aoi(args: argparse.Namespace) -> int: } if selection.aoi_layer is not None: payload["aoi_layer"] = selection.aoi_layer + payload["summary"] = _write_and_print_summary( + run_dir=run_dir, + config=config, + submission=submission, + request={ + "type": "aoi", + "aoi_path": str(selection.aoi_path), + "aoi_layer": selection.aoi_layer, + "planned_patch_count": len(selection.patches), + }, + start_date=args.start_date, + end_date=args.end_date, + ) print(json.dumps(payload, indent=2)) return 0 diff --git a/deployment/srgan_hpc/staging.py b/deployment/srgan_hpc/staging.py index 3bf92ed..9d0b526 100644 --- a/deployment/srgan_hpc/staging.py +++ b/deployment/srgan_hpc/staging.py @@ -135,6 +135,19 @@ def _point_geometry(longitude: float, latitude: float) -> dict[str, object]: return {"type": "Point", "coordinates": [longitude, latitude]} +def _log_label(patch_id: str | None, product_name: str | None) -> str: + return f"{patch_id or 'patch'} {product_name or 'product'}" + + +def _format_cloud_cover(value: object) -> str: + if value is None: + return "n/a" + try: + return f"{float(value):.3f}%" + except (TypeError, ValueError): + return str(value) + + def _search_kwargs_from_config(config: StagingConfig) -> dict[str, Any]: search_kwargs: dict[str, Any] = {} if config.search_query: @@ -153,7 +166,9 @@ def _auto_select_item_ids( start_date: str, end_date: str, config: StagingConfig, -) -> list[str]: + patch_id: str | None = None, + product_name: str | None = None, +) -> tuple[list[str], list[dict[str, Any]]]: try: import pystac_client except ImportError as exc: # pragma: no cover @@ -180,13 +195,20 @@ def _auto_select_item_ids( ) selected = items[0] + selected_report = { + "id": selected.id, + "tile": selected.properties.get("s2:mgrs_tile"), + "cloud_cover": selected.properties.get("eo:cloud_cover"), + "datetime": selected.properties.get("datetime"), + } LOGGER.info( - "auto-selected STAC item id=%s tile=%s cloud_cover=%s", + "[stac] %s tile=%s cloud=%s item=%s", + _log_label(patch_id, product_name), + selected_report["tile"], + _format_cloud_cover(selected_report["cloud_cover"]), selected.id, - selected.properties.get("s2:mgrs_tile"), - selected.properties.get("eo:cloud_cover"), ) - return [selected.id] + return [selected.id], [selected_report] def create_cube_with_retry( @@ -199,6 +221,8 @@ def create_cube_with_retry( bands: list[str], edge_size: int, resolution: int, + patch_id: str | None = None, + product_name: str | None = None, ): try: import cubo @@ -209,20 +233,24 @@ def create_cube_with_retry( ) from exc search_kwargs = _search_kwargs_from_config(config) + selected_item_reports: list[dict[str, Any]] = [] if config.auto_select_item: - search_kwargs["ids"] = _auto_select_item_ids( + selected_item_ids, selected_item_reports = _auto_select_item_ids( latitude=latitude, longitude=longitude, start_date=start_date, end_date=end_date, config=config, + patch_id=patch_id, + product_name=product_name, ) + search_kwargs["ids"] = selected_item_ids search_kwargs.setdefault("max_items", 1) search_kwargs.setdefault("limit", 1) for attempt, delay in enumerate(config.rate_limit_retry_delays_seconds, start=1): try: - return cubo.create( + cube = cubo.create( lat=latitude, lon=longitude, collection=config.collection, @@ -233,6 +261,7 @@ def create_cube_with_retry( resolution=resolution, **search_kwargs, ) + return cube, selected_item_reports except Exception as exc: # pragma: no cover if not (config.retry_on_rate_limit and is_retryable_staging_error(exc)): raise @@ -247,7 +276,7 @@ def create_cube_with_retry( ) time.sleep(delay) - return cubo.create( + cube = cubo.create( lat=latitude, lon=longitude, collection=config.collection, @@ -258,6 +287,7 @@ def create_cube_with_retry( resolution=resolution, **search_kwargs, ) + return cube, selected_item_reports def _cube_item_label(cube, index: int) -> str: @@ -360,17 +390,21 @@ def stage_cutout( resolution: int, output_path: Path, metadata_path: Path | None = None, + patch_id: str | None = None, + product_name: str | None = None, ) -> Path: ensure_proj_env() LOGGER.info( - "staging cubo cutout lat=%s lon=%s start_date=%s end_date=%s output=%s", + "[stage] %s lat=%.6f lon=%.6f date=%s/%s edge=%s res=%sm", + _log_label(patch_id, product_name), latitude, longitude, start_date, end_date, - output_path, + edge_size, + resolution, ) - cube = create_cube_with_retry( + cube, auto_selected_items = create_cube_with_retry( latitude=latitude, longitude=longitude, start_date=start_date, @@ -379,23 +413,22 @@ def stage_cutout( bands=bands, edge_size=edge_size, resolution=resolution, + patch_id=patch_id, + product_name=product_name, ) cube, diagnostics = _select_or_mosaic_time_items(cube, config) + if auto_selected_items: + diagnostics["auto_selected_items"] = auto_selected_items cube = cube.transpose("band", "y", "x") stats = ensure_cube_has_valid_data(cube) diagnostics["validity_stats"] = stats LOGGER.info( - "validated staged cutout lat=%s lon=%s stats=%s staging=%s", - latitude, - longitude, - stats, - { - "item_strategy": diagnostics.get("item_strategy"), - "candidate_count": diagnostics.get("candidate_count"), - "selected_indices": diagnostics.get("selected_indices"), - "final_center_nonzero_fraction": diagnostics.get("final_center_nonzero_fraction"), - "final_full_nonzero_fraction": diagnostics.get("final_full_nonzero_fraction"), - }, + "[valid] %s full=%.3f center=%.3f nonzero=%s/%s", + _log_label(patch_id, product_name), + float(diagnostics.get("final_full_nonzero_fraction") or 0.0), + float(diagnostics.get("final_center_nonzero_fraction") or 0.0), + stats["nonzero_pixels"], + stats["total_pixels"], ) epsg_text = str(cube.attrs.get("epsg", "") or cube.coords.get("epsg", "")) @@ -417,7 +450,11 @@ def stage_cutout( ) if metadata_path is not None: write_json(metadata_path, diagnostics) + try: + output_label = str(output_path.relative_to(output_path.parent.parent)) + except ValueError: + output_label = str(output_path) LOGGER.info( - "wrote staged cutout lat=%s lon=%s output=%s", latitude, longitude, output_path + "[write] %s %s", _log_label(patch_id, product_name), output_label ) return output_path.resolve() diff --git a/deployment/srgan_hpc/submit.py b/deployment/srgan_hpc/submit.py index e5486b4..f29af4b 100644 --- a/deployment/srgan_hpc/submit.py +++ b/deployment/srgan_hpc/submit.py @@ -100,6 +100,8 @@ def _stage_patch_inputs( resolution=product.resolution, output_path=input_tif, metadata_path=patch_root / "metadata" / f"{product_name}_staging.json", + patch_id=patch.patch_id, + product_name=product_name, ) From f679e88994ed47aef5e17b113dccae42441a9eb7 Mon Sep 17 00:00:00 2001 From: Davide Date: Tue, 26 May 2026 15:16:53 +0200 Subject: [PATCH 3/7] Add submission summary generation and reporting functionality --- deployment/srgan_hpc/submission_summary.py | 261 ++++++++++++++++++ .../test_deployment_submission_summary.py | 117 ++++++++ 2 files changed, 378 insertions(+) create mode 100644 deployment/srgan_hpc/submission_summary.py create mode 100644 tests/test_deployment/test_deployment_submission_summary.py diff --git a/deployment/srgan_hpc/submission_summary.py b/deployment/srgan_hpc/submission_summary.py new file mode 100644 index 0000000..59f782c --- /dev/null +++ b/deployment/srgan_hpc/submission_summary.py @@ -0,0 +1,261 @@ +from __future__ import annotations + +import json +from pathlib import Path +from statistics import mean +from typing import Any, Mapping + +from deployment.srgan_hpc.config import RuntimeConfig, enabled_product_names +from deployment.srgan_hpc.manifests import read_yaml, write_json + + +def _read_json(path: Path) -> dict[str, Any]: + with path.open("r", encoding="utf-8") as handle: + data = json.load(handle) + if not isinstance(data, dict): + raise ValueError(f"Expected mapping in {path}") + return data + + +def _as_float(value: object) -> float | None: + if value is None: + return None + try: + return float(value) + except (TypeError, ValueError): + return None + + +def _stats(values: list[float]) -> dict[str, float] | None: + if not values: + return None + return {"min": min(values), "mean": mean(values), "max": max(values)} + + +def _command_array_range(command: str | None) -> str | None: + if not command: + return None + for token in command.split(): + if token.startswith("--array="): + return token.split("=", 1)[1] + return None + + +def _collect_staging(run_dir: Path) -> dict[str, Any]: + product_records: dict[str, list[dict[str, Any]]] = {} + for metadata_path in sorted(run_dir.glob("patches/*/metadata/*_staging.json")): + product_name = metadata_path.stem.removesuffix("_staging") + payload = _read_json(metadata_path) + payload["_patch_id"] = metadata_path.parent.parent.name + product_records.setdefault(product_name, []).append(payload) + + summary: dict[str, Any] = {} + for product_name, records in sorted(product_records.items()): + center = [ + value + for record in records + if (value := _as_float(record.get("final_center_nonzero_fraction"))) + is not None + ] + full = [ + value + for record in records + if (value := _as_float(record.get("final_full_nonzero_fraction"))) + is not None + ] + clouds: list[float] = [] + tiles: set[str] = set() + item_ids: set[str] = set() + valid_pixels = 0 + nonzero_pixels = 0 + total_pixels = 0 + for record in records: + for item in record.get("auto_selected_items", []) or []: + if item.get("id"): + item_ids.add(str(item["id"])) + if item.get("tile"): + tiles.add(str(item["tile"])) + cloud = _as_float(item.get("cloud_cover")) + if cloud is not None: + clouds.append(cloud) + validity = record.get("validity_stats", {}) or {} + valid_pixels += int(validity.get("valid_pixels", 0) or 0) + nonzero_pixels += int(validity.get("nonzero_pixels", 0) or 0) + total_pixels += int(validity.get("total_pixels", 0) or 0) + + summary[product_name] = { + "patches": len(records), + "tiles": sorted(tiles), + "item_ids": sorted(item_ids), + "cloud_cover": _stats(clouds), + "center_nonzero_fraction": _stats(center), + "full_nonzero_fraction": _stats(full), + "validity": { + "valid_pixels": valid_pixels, + "nonzero_pixels": nonzero_pixels, + "total_pixels": total_pixels, + }, + } + return summary + + +def build_submission_summary( + *, + run_dir: Path, + config: RuntimeConfig, + submission: Mapping[str, object], + request: Mapping[str, object], + start_date: str, + end_date: str, +) -> dict[str, Any]: + run_manifest = read_yaml(run_dir / "run_manifest.yaml") + skipped = list(run_manifest.get("skipped", []) or []) + patch_count = int(run_manifest.get("patch_count", 0) or 0) + planned_patch_count = int(request.get("planned_patch_count", patch_count) or 0) + staging = _collect_staging(run_dir) + warnings: list[str] = [] + if skipped: + warnings.append(f"{len(skipped)} patch(es) were skipped during staging") + for product_name, product_summary in staging.items(): + full = product_summary.get("full_nonzero_fraction") + min_full = full.get("min") if isinstance(full, dict) else None + if min_full is not None and min_full < config.staging.min_full_nonzero_fraction: + warnings.append( + f"{product_name} minimum full coverage {min_full:.3f} is below " + f"{config.staging.min_full_nonzero_fraction:.3f}" + ) + + command = str(submission.get("command", "")) if submission.get("command") else None + summary = { + "run": { + "run_id": run_manifest.get("run_id"), + "run_dir": str(run_dir), + "mode": run_manifest.get("mode"), + "product_mode": config.mode, + "start_date": start_date, + "end_date": end_date, + "config_path": str(config.config_path) if config.config_path else None, + "project_name": config.project_name, + "output_root": str(config.output_root), + }, + "request": dict(request), + "patches": { + "planned": planned_patch_count, + "submitted": patch_count, + "skipped": len(skipped), + }, + "products": { + "enabled": enabled_product_names(config), + "bands": { + "rgbnir": list(config.rgbnir.bands), + "swir": list(config.swir.bands), + }, + }, + "staging": staging, + "slurm": { + **dict(submission), + "array": _command_array_range(command), + }, + "paths": { + "logs": str(run_dir / "logs"), + "submission": str(run_dir / "submission"), + "run_manifest": str(run_dir / "run_manifest.yaml"), + "resolved_config": str(run_dir / "resolved_config.yaml"), + }, + "warnings": warnings, + } + return summary + + +def format_submission_summary(summary: Mapping[str, Any]) -> str: + run = summary["run"] + patches = summary["patches"] + slurm = summary["slurm"] + paths = summary["paths"] + lines = [ + "", + "SRGAN submission summary", + f"Run: {run['run_id']}", + f"Mode: {run['mode']} | {run['product_mode']}", + f"Dates: {run['start_date']} to {run['end_date']}", + f"Run dir: {run['run_dir']}", + ( + "Patches: " + f"{patches['submitted']} submitted, {patches['skipped']} skipped " + f"({patches['planned']} planned)" + ), + ] + if slurm.get("mode") == "dry-run": + lines.append(f"Slurm: dry-run, array {slurm.get('array') or 'none'}") + else: + lines.append( + f"Slurm: job {slurm.get('job_id', 'unknown')}, " + f"array {slurm.get('array') or 'none'}" + ) + + staging = summary.get("staging", {}) + if staging: + lines.append("Staging:") + for product_name, product_summary in staging.items(): + cloud = product_summary.get("cloud_cover") or {} + full = product_summary.get("full_nonzero_fraction") or {} + center = product_summary.get("center_nonzero_fraction") or {} + validity = product_summary.get("validity") or {} + cloud_text = "n/a" + if cloud: + cloud_text = ( + f"{cloud['min']:.3f}-{cloud['max']:.3f}% " + f"(mean {cloud['mean']:.3f}%)" + ) + tiles = ",".join(product_summary.get("tiles") or []) or "n/a" + lines.append( + f" {product_name}: patches={product_summary['patches']} " + f"full_min={full.get('min', 0.0):.3f} " + f"center_min={center.get('min', 0.0):.3f} " + f"nonzero={validity.get('nonzero_pixels', 0)}/" + f"{validity.get('total_pixels', 0)} " + f"cloud={cloud_text} tiles={tiles}" + ) + else: + lines.append("Staging: dry-run or no staging metadata available") + + warnings = list(summary.get("warnings") or []) + if warnings: + lines.append("Warnings:") + lines.extend(f" - {warning}" for warning in warnings) + + lines.extend( + [ + "Files:", + f" Summary JSON: {paths.get('summary_json')}", + f" Summary TXT: {paths.get('summary_txt')}", + f" Logs: {paths.get('logs')}", + ] + ) + return "\n".join(lines) + + +def write_submission_summary( + *, + run_dir: Path, + config: RuntimeConfig, + submission: Mapping[str, object], + request: Mapping[str, object], + start_date: str, + end_date: str, +) -> tuple[dict[str, Any], Path, Path]: + summary = build_submission_summary( + run_dir=run_dir, + config=config, + submission=submission, + request=request, + start_date=start_date, + end_date=end_date, + ) + summary_json = run_dir / "submission" / "summary.json" + summary_txt = run_dir / "submission" / "summary.txt" + summary["paths"]["summary_json"] = str(summary_json) + summary["paths"]["summary_txt"] = str(summary_txt) + write_json(summary_json, summary) + summary_txt.write_text(format_submission_summary(summary) + "\n", encoding="utf-8") + return summary, summary_json, summary_txt diff --git a/tests/test_deployment/test_deployment_submission_summary.py b/tests/test_deployment/test_deployment_submission_summary.py new file mode 100644 index 0000000..b361970 --- /dev/null +++ b/tests/test_deployment/test_deployment_submission_summary.py @@ -0,0 +1,117 @@ +from __future__ import annotations + +from pathlib import Path + +from deployment.srgan_hpc.config import RuntimeConfig +from deployment.srgan_hpc.manifests import read_yaml, write_json, write_yaml +from deployment.srgan_hpc.submission_summary import write_submission_summary + + +def test_submission_summary_writes_dry_run_report(tmp_path: Path) -> None: + run_dir = tmp_path / "runs" / "srgan_001" + write_yaml( + run_dir / "run_manifest.yaml", + { + "run_id": "srgan_001", + "mode": "grid", + "patch_count": 1, + "skipped_count": 0, + "tasks": [{"patch_id": "patch_000001", "manifest": "patches/patch_000001/manifest.yaml"}], + "skipped": [], + }, + ) + config = RuntimeConfig(output_root=tmp_path / "runs", project_name="srgan") + + summary, summary_json, summary_txt = write_submission_summary( + run_dir=run_dir, + config=config, + submission={"mode": "dry-run", "command": "sbatch --array=0-0 script manifest"}, + request={"type": "grid", "planned_patch_count": 1}, + start_date="2025-01-01", + end_date="2025-01-02", + ) + + assert summary["patches"] == {"planned": 1, "submitted": 1, "skipped": 0} + assert summary["slurm"]["array"] == "0-0" + assert summary_json.exists() + assert summary_txt.exists() + assert "SRGAN submission summary" in summary_txt.read_text(encoding="utf-8") + assert "Staging: dry-run or no staging metadata available" in summary_txt.read_text( + encoding="utf-8" + ) + + +def test_submission_summary_aggregates_staging_metadata(tmp_path: Path) -> None: + run_dir = tmp_path / "runs" / "srgan_001" + write_yaml( + run_dir / "run_manifest.yaml", + { + "run_id": "srgan_001", + "mode": "grid", + "patch_count": 1, + "skipped_count": 0, + "tasks": [{"patch_id": "patch_000001", "manifest": "patches/patch_000001/manifest.yaml"}], + "skipped": [], + }, + ) + metadata_dir = run_dir / "patches" / "patch_000001" / "metadata" + write_json( + metadata_dir / "rgbnir_staging.json", + { + "auto_selected_items": [ + { + "id": "S2_ITEM", + "tile": "33UUU", + "cloud_cover": 1.5, + "datetime": "2025-01-01T10:00:00Z", + } + ], + "final_center_nonzero_fraction": 1.0, + "final_full_nonzero_fraction": 0.997, + "validity_stats": { + "total_pixels": 100, + "valid_pixels": 100, + "nonzero_pixels": 99, + }, + }, + ) + write_json( + metadata_dir / "swir_staging.json", + { + "auto_selected_items": [ + { + "id": "S2_ITEM", + "tile": "33UUU", + "cloud_cover": 1.5, + "datetime": "2025-01-01T10:00:00Z", + } + ], + "final_center_nonzero_fraction": 1.0, + "final_full_nonzero_fraction": 1.0, + "validity_stats": { + "total_pixels": 25, + "valid_pixels": 25, + "nonzero_pixels": 25, + }, + }, + ) + config = RuntimeConfig(output_root=tmp_path / "runs", project_name="srgan") + + summary, summary_json, summary_txt = write_submission_summary( + run_dir=run_dir, + config=config, + submission={"job_id": "12345", "stdout": "Submitted batch job 12345", "stderr": ""}, + request={"type": "grid", "planned_patch_count": 1}, + start_date="2025-01-01", + end_date="2025-01-02", + ) + + assert summary["staging"]["rgbnir"]["tiles"] == ["33UUU"] + assert summary["staging"]["rgbnir"]["cloud_cover"]["mean"] == 1.5 + assert summary["staging"]["rgbnir"]["full_nonzero_fraction"]["min"] == 0.997 + assert summary["staging"]["rgbnir"]["validity"]["nonzero_pixels"] == 99 + assert read_yaml(run_dir / "run_manifest.yaml")["patch_count"] == 1 + assert summary_json.exists() + text = summary_txt.read_text(encoding="utf-8") + assert "rgbnir: patches=1" in text + assert "cloud=1.500-1.500%" in text From 6d098b7e8c09111c897d0b8a8b06209ebda0c771 Mon Sep 17 00:00:00 2001 From: Davide Date: Wed, 27 May 2026 10:49:13 +0200 Subject: [PATCH 4/7] Refactor cloud cover formatting and enhance report selection logic now it ccreates a union with the selected region tiles --- deployment/srgan_hpc/staging.py | 59 +++++++---- .../test_deployment_staging.py | 98 +++++++++++++++++++ 2 files changed, 137 insertions(+), 20 deletions(-) diff --git a/deployment/srgan_hpc/staging.py b/deployment/srgan_hpc/staging.py index 9d0b526..b3017c2 100644 --- a/deployment/srgan_hpc/staging.py +++ b/deployment/srgan_hpc/staging.py @@ -139,13 +139,22 @@ def _log_label(patch_id: str | None, product_name: str | None) -> str: return f"{patch_id or 'patch'} {product_name or 'product'}" -def _format_cloud_cover(value: object) -> str: - if value is None: +def _format_cloud_cover_range(reports: list[dict[str, Any]]) -> str: + values: list[float] = [] + for report in reports: + value = report.get("cloud_cover") + if value is None: + continue + try: + values.append(float(value)) + except (TypeError, ValueError): + continue + + if not values: return "n/a" - try: - return f"{float(value):.3f}%" - except (TypeError, ValueError): - return str(value) + if len(values) == 1: + return f"{values[0]:.3f}%" + return f"{min(values):.3f}-{max(values):.3f}%" def _search_kwargs_from_config(config: StagingConfig) -> dict[str, Any]: @@ -194,21 +203,31 @@ def _auto_select_item_ids( details={"latitude": int(latitude * 1_000_000), "longitude": int(longitude * 1_000_000)}, ) - selected = items[0] - selected_report = { - "id": selected.id, - "tile": selected.properties.get("s2:mgrs_tile"), - "cloud_cover": selected.properties.get("eo:cloud_cover"), - "datetime": selected.properties.get("datetime"), - } + selected_reports = [ + { + "id": item.id, + "tile": item.properties.get("s2:mgrs_tile"), + "cloud_cover": item.properties.get("eo:cloud_cover"), + "datetime": item.properties.get("datetime"), + } + for item in items + ] + tiles = sorted( + { + str(report["tile"]) + for report in selected_reports + if report.get("tile") is not None + } + ) LOGGER.info( - "[stac] %s tile=%s cloud=%s item=%s", + "[stac] %s candidates=%s tiles=%s cloud_range=%s first_item=%s", _log_label(patch_id, product_name), - selected_report["tile"], - _format_cloud_cover(selected_report["cloud_cover"]), - selected.id, + len(selected_reports), + ",".join(tiles) if tiles else "n/a", + _format_cloud_cover_range(selected_reports), + selected_reports[0]["id"], ) - return [selected.id], [selected_report] + return [item.id for item in items], selected_reports def create_cube_with_retry( @@ -245,8 +264,8 @@ def create_cube_with_retry( product_name=product_name, ) search_kwargs["ids"] = selected_item_ids - search_kwargs.setdefault("max_items", 1) - search_kwargs.setdefault("limit", 1) + search_kwargs.setdefault("max_items", len(selected_item_ids)) + search_kwargs.setdefault("limit", len(selected_item_ids)) for attempt, delay in enumerate(config.rate_limit_retry_delays_seconds, start=1): try: diff --git a/tests/test_deployment/test_deployment_staging.py b/tests/test_deployment/test_deployment_staging.py index d48ac6d..d3d5afc 100644 --- a/tests/test_deployment/test_deployment_staging.py +++ b/tests/test_deployment/test_deployment_staging.py @@ -1,10 +1,16 @@ from __future__ import annotations +import sys +import types + import numpy as np import pytest +from deployment.srgan_hpc.config import StagingConfig from deployment.srgan_hpc.staging import ( + _auto_select_item_ids, candidate_coverage_report, + create_cube_with_retry, is_retryable_staging_error, mosaic_candidates_by_valid_pixels, order_candidates_by_coverage, @@ -17,6 +23,29 @@ def __init__(self, status_code: int) -> None: super().__init__(f"HTTP {status_code}") +class FakeItem: + def __init__(self, item_id: str, tile: str, cloud_cover: float) -> None: + self.id = item_id + self.properties = { + "s2:mgrs_tile": tile, + "eo:cloud_cover": cloud_cover, + "datetime": "2024-08-13T10:15:59Z", + } + + +def install_fake_stac(monkeypatch: pytest.MonkeyPatch, items: list[FakeItem]) -> None: + class FakeSearch: + def items(self) -> list[FakeItem]: + return items + + class FakeCatalog: + def search(self, **_kwargs): + return FakeSearch() + + client = types.SimpleNamespace(open=lambda _url: FakeCatalog()) + monkeypatch.setitem(sys.modules, "pystac_client", types.SimpleNamespace(Client=client)) + + def test_retryable_staging_error_detects_rate_limit_status() -> None: assert is_retryable_staging_error(ResponseError(429)) @@ -55,3 +84,72 @@ def test_mosaic_valid_strategy_patches_missing_pixels_from_second_candidate() -> assert np.all(mosaic[:, :, :4] == 10) assert np.all(mosaic[:, :, 4:] == 20) assert candidate_coverage_report(mosaic)["full_nonzero_fraction"] == pytest.approx(1.0) + + +def test_auto_select_item_ids_returns_all_stac_candidates( + monkeypatch: pytest.MonkeyPatch, +) -> None: + install_fake_stac( + monkeypatch, + [ + FakeItem("item-a", "32UPA", 1.2), + FakeItem("item-b", "32UPB", 0.3), + FakeItem("item-c", "33UUU", 2.4), + ], + ) + + item_ids, reports = _auto_select_item_ids( + latitude=50.1, + longitude=15.1, + start_date="2024-08-12", + end_date="2024-08-14", + config=StagingConfig(auto_select_item=True, auto_select_item_limit=3), + patch_id="patch_000001", + product_name="rgbnir", + ) + + assert item_ids == ["item-a", "item-b", "item-c"] + assert [report["tile"] for report in reports] == ["32UPA", "32UPB", "33UUU"] + + +def test_create_cube_with_retry_passes_all_auto_selected_ids_to_cubo( + monkeypatch: pytest.MonkeyPatch, +) -> None: + install_fake_stac( + monkeypatch, + [ + FakeItem("item-a", "32UPA", 1.2), + FakeItem("item-b", "32UPB", 0.3), + ], + ) + captured_kwargs: dict[str, object] = {} + + def fake_create(**kwargs): + captured_kwargs.update(kwargs) + return object() + + monkeypatch.setitem(sys.modules, "cubo", types.SimpleNamespace(create=fake_create)) + monkeypatch.setitem(sys.modules, "rioxarray", types.SimpleNamespace()) + + cube, reports = create_cube_with_retry( + latitude=50.1, + longitude=15.1, + start_date="2024-08-12", + end_date="2024-08-14", + config=StagingConfig( + auto_select_item=True, + auto_select_item_limit=2, + rate_limit_retry_delays_seconds=[], + ), + bands=["B04", "B03"], + edge_size=4096, + resolution=10, + patch_id="patch_000001", + product_name="rgbnir", + ) + + assert cube is not None + assert captured_kwargs["ids"] == ["item-a", "item-b"] + assert captured_kwargs["max_items"] == 2 + assert captured_kwargs["limit"] == 2 + assert [report["id"] for report in reports] == ["item-a", "item-b"] From 409d1464ae9c58efba758643ac07fa4dfc50ac71 Mon Sep 17 00:00:00 2001 From: Davide Date: Wed, 27 May 2026 10:59:51 +0200 Subject: [PATCH 5/7] Add collect job submission functionality and enhance SLURM command building --- deployment/srgan_hpc/__init__.py | 4 ++ deployment/srgan_hpc/slurm.py | 8 +++- deployment/srgan_hpc/submission_summary.py | 9 ++++ deployment/srgan_hpc/submit.py | 36 ++++++++++++++- .../test_deployment/test_deployment_slurm.py | 41 +++++++++++++++++ .../test_deployment/test_deployment_submit.py | 45 ++++++++++++++++++- 6 files changed, 139 insertions(+), 4 deletions(-) create mode 100644 tests/test_deployment/test_deployment_slurm.py diff --git a/deployment/srgan_hpc/__init__.py b/deployment/srgan_hpc/__init__.py index ff42710..2b63b57 100644 --- a/deployment/srgan_hpc/__init__.py +++ b/deployment/srgan_hpc/__init__.py @@ -15,4 +15,8 @@ def bundled_slurm_entrypoint() -> Path: return Path(__file__).resolve().parent / "slurm" / "slurm_task_entrypoint.sh" +def bundled_slurm_collect_entrypoint() -> Path: + return Path(__file__).resolve().parent / "slurm" / "slurm_collect_entrypoint.sh" + + __version__ = get_version() diff --git a/deployment/srgan_hpc/slurm.py b/deployment/srgan_hpc/slurm.py index f7c783d..bee3262 100644 --- a/deployment/srgan_hpc/slurm.py +++ b/deployment/srgan_hpc/slurm.py @@ -19,6 +19,8 @@ class SlurmJobSpec: slurm: SlurmConfig environment: EnvironmentConfig array: str | None = None + dependency: str | None = None + request_gpus: bool = True def build_sbatch_command(spec: SlurmJobSpec) -> list[str]: @@ -40,9 +42,9 @@ def build_sbatch_command(spec: SlurmJobSpec) -> list[str]: ] if spec.slurm.partition: cmd.append(f"--partition={spec.slurm.partition}") - if spec.slurm.gres: + if spec.request_gpus and spec.slurm.gres: cmd.append(f"--gres={spec.slurm.gres}") - elif spec.slurm.gpus: + elif spec.request_gpus and spec.slurm.gpus: if spec.slurm.gpu_type: cmd.append(f"--gpus={spec.slurm.gpu_type}:{spec.slurm.gpus}") else: @@ -53,6 +55,8 @@ def build_sbatch_command(spec: SlurmJobSpec) -> list[str]: cmd.append(f"--qos={spec.slurm.qos}") if spec.array: cmd.append(f"--array={spec.array}") + if spec.dependency: + cmd.append(f"--dependency={spec.dependency}") cmd.extend(spec.slurm.extra_args) cmd.append(str(spec.script_path)) cmd.append(str(spec.manifest_path)) diff --git a/deployment/srgan_hpc/submission_summary.py b/deployment/srgan_hpc/submission_summary.py index 59f782c..2877cf4 100644 --- a/deployment/srgan_hpc/submission_summary.py +++ b/deployment/srgan_hpc/submission_summary.py @@ -192,6 +192,15 @@ def format_submission_summary(summary: Mapping[str, Any]) -> str: f"Slurm: job {slurm.get('job_id', 'unknown')}, " f"array {slurm.get('array') or 'none'}" ) + collect = slurm.get("collect") + if isinstance(collect, Mapping): + if collect.get("mode") == "dry-run": + lines.append("Collect: dry-run follow-up job prepared") + else: + lines.append( + f"Collect: job {collect.get('job_id', 'unknown')} " + "after array success" + ) staging = summary.get("staging", {}) if staging: diff --git a/deployment/srgan_hpc/submit.py b/deployment/srgan_hpc/submit.py index f29af4b..305e9f3 100644 --- a/deployment/srgan_hpc/submit.py +++ b/deployment/srgan_hpc/submit.py @@ -10,6 +10,7 @@ product_edge_size, runtime_config_to_dict, ) +from deployment.srgan_hpc import bundled_slurm_collect_entrypoint from deployment.srgan_hpc.manifests import new_run_id, write_json, write_yaml from deployment.srgan_hpc.naming import patch_dir, resolve_run_dir from deployment.srgan_hpc.patching import Patch @@ -216,6 +217,30 @@ def submit_patch_run( return run_id, run_dir, submission +def _submit_collect_job( + *, + run_id: str, + run_dir: Path, + logs_dir: Path, + config: RuntimeConfig, + dependency_job_id: str | None, + dry_run: bool, +) -> Mapping[str, object]: + dependency = f"afterok:{dependency_job_id}" if dependency_job_id else None + spec = SlurmJobSpec( + job_name=f"srgan_collect_{run_id}", + script_path=bundled_slurm_collect_entrypoint().resolve(), + manifest_path=run_dir / "run_manifest.yaml", + output_path=logs_dir / "slurm_collect_%j.out", + error_path=logs_dir / "slurm_collect_%j.err", + slurm=config.slurm, + environment=config.environment, + dependency=dependency, + request_gpus=False, + ) + return submit_job(spec, run_dir / "submission" / "collect", dry_run=dry_run) + + def _submit_patch_collection( *, mode: str, @@ -338,7 +363,16 @@ def _submit_patch_collection( environment=config.environment, array=f"0-{len(tasks) - 1}" if tasks else None, ) - submission = submit_job(spec, run_dir / "submission", dry_run=dry_run) + submission = dict(submit_job(spec, run_dir / "submission", dry_run=dry_run)) + collect_submission = _submit_collect_job( + run_id=run_id, + run_dir=run_dir, + logs_dir=logs_dir, + config=config, + dependency_job_id=str(submission["job_id"]) if submission.get("job_id") else None, + dry_run=dry_run, + ) + submission["collect"] = dict(collect_submission) return run_id, run_dir, submission diff --git a/tests/test_deployment/test_deployment_slurm.py b/tests/test_deployment/test_deployment_slurm.py new file mode 100644 index 0000000..dbd1b0d --- /dev/null +++ b/tests/test_deployment/test_deployment_slurm.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +from pathlib import Path + +from deployment.srgan_hpc.config import EnvironmentConfig, SlurmConfig +from deployment.srgan_hpc.slurm import SlurmJobSpec, build_sbatch_command + + +def test_build_sbatch_command_includes_dependency() -> None: + spec = SlurmJobSpec( + job_name="collect", + script_path=Path("/tmp/collect.sh"), + manifest_path=Path("/tmp/run_manifest.yaml"), + output_path=Path("/tmp/out.log"), + error_path=Path("/tmp/err.log"), + slurm=SlurmConfig(gpus=0), + environment=EnvironmentConfig(python_executable="python"), + dependency="afterok:12345", + ) + + command = build_sbatch_command(spec) + + assert "--dependency=afterok:12345" in command + + +def test_build_sbatch_command_can_skip_gpu_request() -> None: + spec = SlurmJobSpec( + job_name="collect", + script_path=Path("/tmp/collect.sh"), + manifest_path=Path("/tmp/run_manifest.yaml"), + output_path=Path("/tmp/out.log"), + error_path=Path("/tmp/err.log"), + slurm=SlurmConfig(gpu_type="A100", gpus=1), + environment=EnvironmentConfig(python_executable="python"), + request_gpus=False, + ) + + command = build_sbatch_command(spec) + + assert not any(part.startswith("--gpus=") for part in command) + assert not any(part.startswith("--gres=") for part in command) diff --git a/tests/test_deployment/test_deployment_submit.py b/tests/test_deployment/test_deployment_submit.py index 60ad6ec..51550dc 100644 --- a/tests/test_deployment/test_deployment_submit.py +++ b/tests/test_deployment/test_deployment_submit.py @@ -5,7 +5,7 @@ from deployment.srgan_hpc.config import RuntimeConfig from deployment.srgan_hpc.manifests import read_yaml from deployment.srgan_hpc.patching import Patch -from deployment.srgan_hpc.submit import submit_patch_run +from deployment.srgan_hpc.submit import submit_grid_run, submit_patch_run def test_submit_patch_dry_run_writes_product_inputs_and_manifest( @@ -41,3 +41,46 @@ def test_submit_patch_dry_run_writes_product_inputs_and_manifest( } assert (run_dir / "patches" / "patch_000001" / "inputs").is_dir() assert submission["mode"] == "dry-run" + + +def test_submit_grid_dry_run_prepares_collect_followup_job(tmp_path: Path) -> None: + config = RuntimeConfig(output_root=tmp_path / "runs", mode="fused") + patches = [ + Patch( + patch_id="patch_000001", + latitude=45.0, + longitude=9.0, + edge_size=512, + row_index=0, + row_count=1, + column_index=0, + column_count=1, + ), + Patch( + patch_id="patch_000002", + latitude=45.1, + longitude=9.1, + edge_size=512, + row_index=0, + row_count=1, + column_index=1, + column_count=2, + ), + ] + + _, run_dir, submission = submit_grid_run( + config=config, + patches=patches, + start_date="2024-01-01", + end_date="2024-01-31", + script_path=Path("/tmp/slurm_task.sh"), + dry_run=True, + ) + + assert submission["mode"] == "dry-run" + assert "--array=0-1" in str(submission["command"]) + collect = submission["collect"] + assert collect["mode"] == "dry-run" + assert "slurm_collect_entrypoint.sh" in str(collect["command"]) + assert "--array=" not in str(collect["command"]) + assert (run_dir / "submission" / "collect" / "sbatch_command.txt").is_file() From e2f923f0f266b741c71888252a40617486b9f695 Mon Sep 17 00:00:00 2001 From: Davide Date: Wed, 27 May 2026 11:05:28 +0200 Subject: [PATCH 6/7] Refactor collect_outputs to move files instead of copying and update tests accordingly --- deployment/srgan_hpc/collect.py | 8 ++++---- tests/test_deployment/test_deployment_collect.py | 6 ++++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/deployment/srgan_hpc/collect.py b/deployment/srgan_hpc/collect.py index adecb9d..5e8b4f3 100644 --- a/deployment/srgan_hpc/collect.py +++ b/deployment/srgan_hpc/collect.py @@ -7,11 +7,11 @@ def collect_outputs(run_dir: Path, destination: Path | None = None) -> tuple[Path, int]: destination = destination or run_dir / "collected" destination.mkdir(parents=True, exist_ok=True) - copied = 0 + moved = 0 for tif_path in sorted(run_dir.glob("patches/*/outputs/*.tif")): patch_id = tif_path.parent.parent.name patch_destination = destination / patch_id patch_destination.mkdir(parents=True, exist_ok=True) - shutil.copy2(tif_path, patch_destination / tif_path.name) - copied += 1 - return destination, copied + shutil.move(str(tif_path), patch_destination / tif_path.name) + moved += 1 + return destination, moved diff --git a/tests/test_deployment/test_deployment_collect.py b/tests/test_deployment/test_deployment_collect.py index 32a5b89..590a69e 100644 --- a/tests/test_deployment/test_deployment_collect.py +++ b/tests/test_deployment/test_deployment_collect.py @@ -14,9 +14,11 @@ def test_collect_outputs_preserves_patch_identity_for_duplicate_product_names( output_dir.mkdir(parents=True) (output_dir / "fused_sr.tif").write_bytes(marker) - destination, copied = collect_outputs(run_dir) + destination, moved = collect_outputs(run_dir) assert destination == run_dir / "collected" - assert copied == 2 + assert moved == 2 assert (destination / "patch_000001" / "fused_sr.tif").read_bytes() == b"first" assert (destination / "patch_000002" / "fused_sr.tif").read_bytes() == b"second" + assert not (run_dir / "patches" / "patch_000001" / "outputs" / "fused_sr.tif").exists() + assert not (run_dir / "patches" / "patch_000002" / "outputs" / "fused_sr.tif").exists() From 63a9a950ce0b403eb2fabc30db96925cd2841f07 Mon Sep 17 00:00:00 2001 From: Davide Date: Wed, 27 May 2026 11:14:44 +0200 Subject: [PATCH 7/7] Add SLURM entrypoint script for HPC job collection --- .../slurm/slurm_collect_entrypoint.sh | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100755 deployment/srgan_hpc/slurm/slurm_collect_entrypoint.sh diff --git a/deployment/srgan_hpc/slurm/slurm_collect_entrypoint.sh b/deployment/srgan_hpc/slurm/slurm_collect_entrypoint.sh new file mode 100755 index 0000000..767c22b --- /dev/null +++ b/deployment/srgan_hpc/slurm/slurm_collect_entrypoint.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env bash + +set -euo pipefail + +MANIFEST_PATH="${1:?manifest path required}" +PYTHON_BIN="${SRGAN_HPC_PYTHON:-python}" +RUN_DIR="$(dirname "${MANIFEST_PATH}")" + +if [[ -n "${SRGAN_HPC_MODULES:-}" ]] && command -v module >/dev/null 2>&1; then + IFS=',' read -r -a MODULE_LIST <<< "${SRGAN_HPC_MODULES}" + for module_name in "${MODULE_LIST[@]}"; do + module load "${module_name}" + done +fi + +if [[ -n "${SRGAN_HPC_CONDA_ENV:-}" ]]; then + if command -v conda >/dev/null 2>&1; then + eval "$(conda shell.bash hook)" + conda activate "${SRGAN_HPC_CONDA_ENV}" + else + source activate "${SRGAN_HPC_CONDA_ENV}" + fi +fi + +exec "${PYTHON_BIN}" -m deployment.srgan_hpc.cli collect --run-dir "${RUN_DIR}"