From 986470ab1228d702124b16184443eacbb2d11a02 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Mon, 16 Mar 2026 15:39:22 -0700 Subject: [PATCH 01/42] Initial commit for adding NERSC IRI-API support alongside SFAPI for job submission --- orchestration/flows/bl832/nersc.py | 128 +++++++++++++++- orchestration/globus/token.py | 235 +++++++++++++++++++++++++++++ 2 files changed, 362 insertions(+), 1 deletion(-) create mode 100644 orchestration/globus/token.py diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index f4ffc9fb..a87c6c53 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -1,6 +1,7 @@ from dataclasses import dataclass, field import datetime from dotenv import load_dotenv +import enum import json import logging import os @@ -16,6 +17,7 @@ from typing import Any, Optional from orchestration.flows.bl832.config import Config832 + from orchestration.flows.bl832.job_controller import get_controller, HPC, TomographyHPCController from orchestration.mlflow import get_checkpoint_info from orchestration.prune_controller import get_prune_controller, PruneMethod @@ -23,6 +25,7 @@ from orchestration.flows.bl832.streaming_mixin import ( NerscStreamingMixin, SlurmJobBlock, cancellation_hook, monitor_streaming_job, save_block ) +from orchestration.globus.token import get_access_token_confidential, DEFAULT_TOKEN_FILE from orchestration.prefect import schedule_prefect_flow logger = logging.getLogger(__name__) @@ -142,6 +145,37 @@ def _load_job_options( return {**opts, **overrides} +class NERSCLoginMethod(enum.Enum): + """Selects which NERSC API login method to use when creating a NERSC client. + + Each method corresponds to a different set of credentials and API base URL. + """ + + SFAPI = "sfapi" + """Standard Superfacility API via Iris-registered OAuth2 credentials.""" + + IRIAPI = "iriapi" + """Integrated Research Infrastructure API via IRI-registered OAuth2 credentials.""" + + +# Applies only to NERSCLoginMethod.IRIAPI +_IRIAPI_GLOBUS_CLIENT_ID_ENV: str = "GLOBUS_CLIENT_ID" +_IRIAPI_GLOBUS_CLIENT_SECRET_ENV: str = "GLOBUS_CLIENT_SECRET" # set → confidential client +_IRIAPI_TOKEN_FILE_ENV: str = "PATH_GLOBUS_TOKEN_FILE" +_IRIAPI_GLOBUS_RESOURCE_SERVER: str = "auth.globus.org" +_IRIAPI_GLOBUS_REQUIRED_SCOPES: frozenset[str] = frozenset({ + "openid", + "profile", + "email", + "urn:globus:auth:scope:auth.globus.org:view_identities", +}) + +_API_BASE_URLS: dict[NERSCLoginMethod, str] = { + NERSCLoginMethod.SFAPI: "https://api.nersc.gov/api/v1.2", + NERSCLoginMethod.IRIAPI: "https://api.iri.nersc.gov", +} + + class NERSCTomographyHPCController(TomographyHPCController, NerscStreamingMixin): """ Implementation for a NERSC-based tomography HPC controller. @@ -158,7 +192,99 @@ def __init__( self.client = client @staticmethod - def create_sfapi_client() -> Client: + def create_nersc_client( + login_method: NERSCLoginMethod = NERSCLoginMethod.SFAPI, + ) -> Client: + """Create and return a NERSC client for the requested login method. + + Two fundamentally different auth strategies are supported: + + - :attr:`NERSCLoginMethod.SFAPI`: uses an Iris-registered OAuth2 + client ID + private key (NERSC OIDC flow). Set ``PATH_NERSC_CLIENT_ID`` + and ``PATH_NERSC_PRI_KEY`` to the paths of those files. + + - :attr:`NERSCLoginMethod.IRIAPI`: uses a Globus bearer token written + by ``globus_token.py``. Set ``PATH_GLOBUS_TOKEN_FILE`` to the token + file path, or rely on the default (``~/.globus/auth_tokens.json``). + + Args: + login_method: Which NERSC API to authenticate against. + Defaults to :attr:`NERSCLoginMethod.SFAPI`. + + Returns: + An authenticated :class:`sfapi_client.Client` instance. + + Raises: + ValueError: If SFAPI credential environment variables are unset. + FileNotFoundError: If credential or token files are absent. + RuntimeError: If the Globus token is expired. + Exception: If the underlying client construction fails. + """ + logger.info(f"Creating NERSC client using login method: {login_method.value}") + api_url = _API_BASE_URLS[login_method] + logger.info(f"Targeting API base URL: {api_url}") + + if login_method is NERSCLoginMethod.SFAPI: + client = NERSCTomographyHPCController._create_sfapi_client() + + elif login_method is NERSCLoginMethod.IRIAPI: + client = NERSCTomographyHPCController._create_iriapi_client() + + else: + raise ValueError(f"Unhandled NERSCLoginMethod: {login_method}") + + logger.info( + f"NERSC client created successfully " + f"(method={login_method.value}, api_url={api_url})." + ) + return client + + @staticmethod + def _create_iriapi_client() -> Client: + """Create a NERSC client for the IRI API using a Globus bearer token. + + Requires ``GLOBUS_CLIENT_ID`` and ``GLOBUS_CLIENT_SECRET`` in the + environment. Reuses a cached token if valid; otherwise mints a new one + via the client credentials grant. No browser or user interaction. + + Returns: + An authenticated :class:`sfapi_client.Client` targeting the IRI API. + + Raises: + ValueError: If ``GLOBUS_CLIENT_ID`` or ``GLOBUS_CLIENT_SECRET`` are unset. + RuntimeError: If the acquired token is missing required scopes. + """ + client_id = os.getenv(_IRIAPI_GLOBUS_CLIENT_ID_ENV) + client_secret = os.getenv(_IRIAPI_GLOBUS_CLIENT_SECRET_ENV) + + if not client_id: + raise ValueError( + f"Globus client ID is unset. Set {_IRIAPI_GLOBUS_CLIENT_ID_ENV}." + ) + if not client_secret: + raise ValueError( + f"Globus client secret is unset. Set {_IRIAPI_GLOBUS_CLIENT_SECRET_ENV}. " + "A Globus Confidential App client is required for automated IRI API auth." + ) + + token_file_env = os.getenv(_IRIAPI_TOKEN_FILE_ENV) + token_file = Path(token_file_env) if token_file_env else DEFAULT_TOKEN_FILE + + access_token = get_access_token_confidential( + client_id=client_id, + client_secret=client_secret, + required_scopes=_IRIAPI_GLOBUS_REQUIRED_SCOPES, + resource_server=_IRIAPI_GLOBUS_RESOURCE_SERVER, + token_file=token_file, + ) + + return Client( + token=access_token, + api_url=_API_BASE_URLS[NERSCLoginMethod.IRIAPI], + ) + + @staticmethod + def _create_sfapi_client() -> Client: """Create and return an NERSC client instance""" # When generating the SFAPI Key in Iris, make sure to select "asldev" as the user! diff --git a/orchestration/globus/token.py b/orchestration/globus/token.py new file mode 100644 index 00000000..81b5438f --- /dev/null +++ b/orchestration/globus/token.py @@ -0,0 +1,235 @@ +import json +import logging +import os +from pathlib import Path +import stat +import time + +import globus_sdk +from globus_sdk.exc import GlobusAPIError + +logger = logging.getLogger(__name__) + +# Default token file location, matching the Globus SDK convention. +DEFAULT_TOKEN_FILE: Path = Path.home() / ".globus" / "auth_tokens.json" +GLOBUS_OIDC_TOKEN_URL: str = "https://auth.globus.org/v2/oauth2/token" + + +def get_access_token_confidential( + client_id: str, + client_secret: str, + required_scopes: frozenset[str], + resource_server: str, + token_file: Path | None = None, +) -> str: + """Get a valid Globus access token using a Confidential Client (machine-to-machine). + + No browser or user interaction required. If a valid unexpired token exists + on disk it is reused; otherwise a new one is minted via the client + credentials grant and saved. + + Args: + client_id: Globus Confidential App client ID. + client_secret: Globus Confidential App client secret. + required_scopes: Set of OAuth2 scopes that must be present on the token. + resource_server: Resource server key to extract from the token response. + token_file: Path to the JSON token cache file. Defaults to + ``~/.globus/auth_tokens.json``. + + Returns: + A valid Globus access token string. + + Raises: + RuntimeError: If the acquired token is missing required scopes. + KeyError: If ``access_token`` is absent from the token response. + """ + resolved_token_file = token_file or DEFAULT_TOKEN_FILE + + # 1. Do we already have a valid token? + stored = load_token_file(resolved_token_file) + if stored: + expires_at = stored.get("expires_at_seconds") + if expires_at and time.time() < expires_at: + logger.info("Using cached Globus token (still valid).") + return stored["access_token"] + logger.info("Cached Globus token is expired; minting a new one.") + else: + logger.info("No cached Globus token found; minting a new one.") + + # 2. Mint a new token — same call whether first time or expired. + globus_client = globus_sdk.ConfidentialAppAuthClient(client_id, client_secret) + token_response = globus_client.oauth2_client_credentials_tokens( + requested_scopes=" ".join(sorted(required_scopes)) + ) + auth_data = token_response.by_resource_server[resource_server] + + granted = set(auth_data.get("scope", "").split()) + missing = required_scopes - granted + if missing: + raise RuntimeError( + f"Globus token is missing required scopes: {sorted(missing)}" + ) + + save_token_file(resolved_token_file, auth_data) + logger.info(f"New Globus token saved to {resolved_token_file}.") + + return auth_data["access_token"] + + +def load_token_file(token_file: Path) -> dict | None: + """Load saved Globus token data from disk. + + Args: + token_file: Path to the JSON token file. + + Returns: + Parsed token dict, or None if the file does not exist. + """ + if not token_file.exists(): + return None + with token_file.open("r", encoding="utf-8") as f: + return json.load(f) + + +def save_token_file(token_file: Path, tokens: dict) -> None: + """Atomically save Globus token data to disk with owner-only permissions. + + Writes to a temporary file then renames to avoid partial writes. + + Args: + token_file: Destination path for the JSON token file. + tokens: Token dict to serialise. + """ + _ensure_private_parent_dir(token_file) + tmp = token_file.with_suffix(".tmp") + with os.fdopen( + os.open(tmp, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600), + "w", + encoding="utf-8", + ) as f: + json.dump(tokens, f, indent=2) + os.replace(tmp, token_file) + os.chmod(token_file, stat.S_IRUSR | stat.S_IWUSR) + + +def interactive_login( + client: globus_sdk.NativeAppAuthClient, + required_scopes: frozenset[str], + resource_server: str, +) -> dict: + """Run an interactive browser-based Globus login flow. + + Prints an authorization URL, waits for the user to paste an auth code, + and exchanges it for tokens. + + Args: + client: Globus NativeAppAuthClient to drive the flow. + required_scopes: Set of OAuth2 scopes to request. + resource_server: Resource server key to extract from the token response + (e.g. ``"auth.globus.org"``). + + Returns: + Token dict for the given resource server. + """ + client.oauth2_start_flow( + requested_scopes=" ".join(sorted(required_scopes)), + refresh_tokens=True, + ) + logger.info("Open this URL in your browser to authenticate with Globus:") + logger.info(client.oauth2_get_authorize_url()) + code = input("\nEnter authorization code: ").strip() + token_response = client.oauth2_exchange_code_for_tokens(code) + return token_response.by_resource_server[resource_server] + + +def refresh_tokens( + client: globus_sdk.NativeAppAuthClient, + refresh_token: str, + resource_server: str, +) -> dict | None: + """Attempt a silent Globus token refresh. + + Args: + client: Globus NativeAppAuthClient to drive the refresh. + refresh_token: The stored refresh token. + resource_server: Resource server key to extract from the token response. + + Returns: + Fresh token dict for the given resource server, or None if refresh failed. + """ + try: + token_response = client.oauth2_refresh_token(refresh_token) + return token_response.by_resource_server[resource_server] + except GlobusAPIError as e: + logger.warning( + f"Globus token refresh failed ({e.http_status}); " + "falling back to interactive login." + ) + return None + + +def get_access_token( + client_id: str, + required_scopes: frozenset[str], + resource_server: str, + token_file: Path | None = None, + force_login: bool = False, +) -> str: + """Get a valid Globus access token, refreshing or logging in as needed. + + Attempts a silent refresh from the saved token file first. Falls back to + interactive browser login if no saved tokens exist, the refresh token is + absent, or the refresh fails. Saves the resulting tokens back to disk. + + Args: + client_id: Globus NativeApp client ID. + required_scopes: Set of OAuth2 scopes that must be present on the token. + resource_server: Resource server key to extract from the token response. + token_file: Path to the JSON token file. Defaults to + ``~/.globus/auth_tokens.json``. + force_login: If True, skip refresh and force interactive login. + + Returns: + A valid Globus access token string. + + Raises: + RuntimeError: If the acquired token is missing required scopes. + KeyError: If ``access_token`` is absent from the token response. + """ + resolved_token_file = token_file or DEFAULT_TOKEN_FILE + globus_client = globus_sdk.NativeAppAuthClient(client_id) + + auth_data: dict | None = None + + if not force_login: + stored = load_token_file(resolved_token_file) + if stored and stored.get("refresh_token"): + auth_data = refresh_tokens( + globus_client, stored["refresh_token"], resource_server + ) + + if auth_data is None: + logger.info("Initiating interactive Globus login.") + auth_data = interactive_login(globus_client, required_scopes, resource_server) + + granted = set(auth_data.get("scope", "").split()) + missing = required_scopes - granted + if missing: + raise RuntimeError( + f"Globus token is missing required scopes: {sorted(missing)}" + ) + + save_token_file(resolved_token_file, auth_data) + logger.info(f"Globus token saved to {resolved_token_file}.") + + return auth_data["access_token"] + + +def _ensure_private_parent_dir(path: Path) -> None: + """Create parent directories for path with owner-only permissions. + + Args: + path: The file path whose parent directory should be created. + """ + path.parent.mkdir(parents=True, exist_ok=True) + os.chmod(path.parent, 0o700) From 0512e587894efb5b9c1b2b05af65adb1ef442e06 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Mon, 16 Mar 2026 15:55:45 -0700 Subject: [PATCH 02/42] Adding an abstraction for _submit_job() and _wait_for_job() that use the correct mechanism based on IRI/SF-API --- orchestration/flows/bl832/nersc.py | 68 +++++++++++++++++++++++++++++- 1 file changed, 67 insertions(+), 1 deletion(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index a87c6c53..52ffcc25 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -186,10 +186,12 @@ class NERSCTomographyHPCController(TomographyHPCController, NerscStreamingMixin) def __init__( self, client: Client, - config: Config832 + config: Config832, + login_method: NERSCLoginMethod = NERSCLoginMethod.SFAPI, ) -> None: TomographyHPCController.__init__(self, config) self.client = client + self.login_method = login_method @staticmethod def create_nersc_client( @@ -353,6 +355,70 @@ def _get_segmentation_spec(self, model: str, project: str) -> SegmentationModelS ) return registry[key] + def _submit_job(self, job_script: str) -> str: + """Submit a Slurm job script and return the job ID. + + Dispatches to the appropriate submission mechanism based on + ``self.login_method``. + + Args: + job_script: The full Slurm batch script to submit. + + Returns: + The submitted job ID as a string. + + Raises: + RuntimeError: If job submission fails. + """ + if self.login_method is NERSCLoginMethod.SFAPI: + perlmutter = self.client.compute(Machine.perlmutter) + job = perlmutter.submit_job(job_script) + return str(job.jobid) + + elif self.login_method is NERSCLoginMethod.IRIAPI: + response = self.client.post( + "/api/v1/compute/job/perlmutter", + json={"script": job_script}, + ) + response.raise_for_status() + return str(response.json()["job_id"]) + + else: + raise ValueError(f"Unhandled NERSCLoginMethod: {self.login_method}") + + def _wait_for_job(self, job_id: str) -> bool: + """Block until a submitted job completes. + + Dispatches to the appropriate polling mechanism based on + ``self.login_method``. + + Args: + job_id: The job ID returned by :meth:`_submit_job`. + + Returns: + True if the job completed successfully, False otherwise. + """ + if self.login_method is NERSCLoginMethod.SFAPI: + perlmutter = self.client.compute(Machine.perlmutter) + job = perlmutter.job(jobid=job_id) + job.complete() + return True + + elif self.login_method is NERSCLoginMethod.IRIAPI: + while True: + response = self.client.get( + f"/api/v1/compute/status/perlmutter/{job_id}" + ) + response.raise_for_status() + state = response.json().get("state") + logger.info(f"Job {job_id} state: {state}") + if state in ("COMPLETED", "FAILED", "CANCELLED", "TIMEOUT"): + return state == "COMPLETED" + time.sleep(60) + + else: + raise ValueError(f"Unhandled NERSCLoginMethod: {self.login_method}") + def reconstruct( self, file_path: str = "", From fe275199eb25288f2e4eae0aeb0f1929a3d99750 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Tue, 17 Mar 2026 10:22:49 -0700 Subject: [PATCH 03/42] moving NERSCLoginMethod(Enum) to the job_controller.py module --- orchestration/flows/bl832/job_controller.py | 26 +++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/orchestration/flows/bl832/job_controller.py b/orchestration/flows/bl832/job_controller.py index b2ff064b..1a23d02a 100644 --- a/orchestration/flows/bl832/job_controller.py +++ b/orchestration/flows/bl832/job_controller.py @@ -10,6 +10,19 @@ load_dotenv() +class NERSCLoginMethod(Enum): + """Selects which NERSC API login method to use when creating a NERSC client. + + Each method corresponds to a different set of credentials and API base URL. + """ + + SFAPI = "sfapi" + """Standard Superfacility API via Iris-registered OAuth2 credentials.""" + + IRIAPI = "iriapi" + """Integrated Research Infrastructure API via IRI-registered OAuth2 credentials.""" + + class TomographyHPCController(ABC): """ Abstract class for tomography HPC controllers. @@ -65,7 +78,8 @@ class HPC(Enum): def get_controller( hpc_type: HPC, - config: Config832 + config: Config832, + login_method: "NERSCLoginMethod | None" = None, ) -> TomographyHPCController: """ Factory function that returns an HPC controller instance for the given HPC environment. @@ -86,10 +100,14 @@ def get_controller( config=config ) elif hpc_type == HPC.NERSC: - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod + resolved_login_method = login_method if isinstance(login_method, NERSCLoginMethod) else NERSCLoginMethod.SFAPI return NERSCTomographyHPCController( - client=NERSCTomographyHPCController.create_sfapi_client(), - config=config + client=NERSCTomographyHPCController.create_nersc_client( + login_method=resolved_login_method + ), + config=config, + login_method=resolved_login_method, ) elif hpc_type == HPC.OLCF: # TODO: Implement OLCF controller From eaf02fe8c83a87a22970a7bbe6d6e91f5f8885e3 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Tue, 17 Mar 2026 10:25:00 -0700 Subject: [PATCH 04/42] Removed NERSCLoginMethod(Enum) from nersc.py. Created a temporary test flow for reconstruction to test job submission. In reconstruct(), replaced the SFAPI-specific job submission/polling code with the general _submit_job() and _wait_for_job() methods. --- orchestration/flows/bl832/nersc.py | 165 ++++++++++++++--------------- 1 file changed, 82 insertions(+), 83 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 52ffcc25..4e5e1c0e 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -2,6 +2,7 @@ import datetime from dotenv import load_dotenv import enum +import httpx import json import logging import os @@ -17,21 +18,37 @@ from typing import Any, Optional from orchestration.flows.bl832.config import Config832 - -from orchestration.flows.bl832.job_controller import get_controller, HPC, TomographyHPCController -from orchestration.mlflow import get_checkpoint_info -from orchestration.prune_controller import get_prune_controller, PruneMethod -from orchestration.transfer_controller import globus_transfer_task +from orchestration.flows.bl832.job_controller import get_controller, HPC, NERSCLoginMethod, TomographyHPCController from orchestration.flows.bl832.streaming_mixin import ( NerscStreamingMixin, SlurmJobBlock, cancellation_hook, monitor_streaming_job, save_block ) from orchestration.globus.token import get_access_token_confidential, DEFAULT_TOKEN_FILE +from orchestration.mlflow import get_checkpoint_info from orchestration.prefect import schedule_prefect_flow +from orchestration.prune_controller import get_prune_controller, PruneMethod +from orchestration.transfer_controller import globus_transfer_task logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) load_dotenv() +# Applies only to NERSCLoginMethod.IRIAPI +_IRIAPI_GLOBUS_CLIENT_ID_ENV: str = "GLOBUS_CLIENT_ID" +_IRIAPI_GLOBUS_CLIENT_SECRET_ENV: str = "GLOBUS_CLIENT_SECRET" # set → confidential client +_IRIAPI_TOKEN_FILE_ENV: str = "PATH_GLOBUS_TOKEN_FILE" +_IRIAPI_GLOBUS_RESOURCE_SERVER: str = "auth.globus.org" +_IRIAPI_GLOBUS_REQUIRED_SCOPES: frozenset[str] = frozenset({ + "openid", + "profile", + "email", + "urn:globus:auth:scope:auth.globus.org:view_identities", +}) + +_API_BASE_URLS: dict[NERSCLoginMethod, str] = { + NERSCLoginMethod.SFAPI: "https://api.nersc.gov/api/v1.2", + NERSCLoginMethod.IRIAPI: "https://api.iri.nersc.gov", +} + @dataclass class SegmentationModelSpec: @@ -158,24 +175,6 @@ class NERSCLoginMethod(enum.Enum): """Integrated Research Infrastructure API via IRI-registered OAuth2 credentials.""" -# Applies only to NERSCLoginMethod.IRIAPI -_IRIAPI_GLOBUS_CLIENT_ID_ENV: str = "GLOBUS_CLIENT_ID" -_IRIAPI_GLOBUS_CLIENT_SECRET_ENV: str = "GLOBUS_CLIENT_SECRET" # set → confidential client -_IRIAPI_TOKEN_FILE_ENV: str = "PATH_GLOBUS_TOKEN_FILE" -_IRIAPI_GLOBUS_RESOURCE_SERVER: str = "auth.globus.org" -_IRIAPI_GLOBUS_REQUIRED_SCOPES: frozenset[str] = frozenset({ - "openid", - "profile", - "email", - "urn:globus:auth:scope:auth.globus.org:view_identities", -}) - -_API_BASE_URLS: dict[NERSCLoginMethod, str] = { - NERSCLoginMethod.SFAPI: "https://api.nersc.gov/api/v1.2", - NERSCLoginMethod.IRIAPI: "https://api.iri.nersc.gov", -} - - class NERSCTomographyHPCController(TomographyHPCController, NerscStreamingMixin): """ Implementation for a NERSC-based tomography HPC controller. @@ -185,8 +184,8 @@ class NERSCTomographyHPCController(TomographyHPCController, NerscStreamingMixin) def __init__( self, - client: Client, config: Config832, + client: Client | httpx.Client | None = None, login_method: NERSCLoginMethod = NERSCLoginMethod.SFAPI, ) -> None: TomographyHPCController.__init__(self, config) @@ -280,9 +279,9 @@ def _create_iriapi_client() -> Client: token_file=token_file, ) - return Client( - token=access_token, - api_url=_API_BASE_URLS[NERSCLoginMethod.IRIAPI], + return httpx.Client( + base_url=_API_BASE_URLS[NERSCLoginMethod.IRIAPI], + headers={"Authorization": f"Bearer {access_token}"}, ) @staticmethod @@ -355,6 +354,28 @@ def _get_segmentation_spec(self, model: str, project: str) -> SegmentationModelS ) return registry[key] + def _get_nersc_username(self) -> str: + """Get the NERSC username for constructing pscratch paths. + + Uses the sfapi_client user endpoint for SFAPI, or reads + ``NERSC_USERNAME`` from the environment for IRIAPI. + + Returns: + NERSC username string. + + Raises: + ValueError: If IRIAPI is selected and NERSC_USERNAME is unset. + """ + if self.login_method is NERSCLoginMethod.SFAPI: + return self.client.user().name + else: + username = os.getenv("NERSC_USERNAME") + if not username: + raise ValueError( + "NERSC_USERNAME must be set in the environment when using IRIAPI." + ) + return username + def _submit_job(self, job_script: str) -> str: """Submit a Slurm job script and return the job ID. @@ -393,7 +414,7 @@ def _wait_for_job(self, job_id: str) -> bool: ``self.login_method``. Args: - job_id: The job ID returned by :meth:`_submit_job`. + job_id: The job ID returned by `_submit_job`. Returns: True if the job completed successfully, False otherwise. @@ -433,7 +454,8 @@ def reconstruct( """ logger.info("Starting NERSC reconstruction process.") - user = self.client.user() + # user = self.client.user() + username = self._get_nersc_username() raw_path = self.config.nersc832_alsdev_raw.root_path logger.info(f"{raw_path=}") @@ -447,7 +469,7 @@ def reconstruct( scratch_path = self.config.nersc832_alsdev_scratch.root_path logger.info(f"{scratch_path=}") - pscratch_path = f"/pscratch/sd/{user.name[0]}/{user.name}" + pscratch_path = f"/pscratch/sd/{username[0]}/{username}" logger.info(f"{pscratch_path=}") path = Path(file_path) @@ -596,55 +618,23 @@ def reconstruct( echo "JOB_STATUS=SUCCESS" >> $TIMING_FILE echo "JOB_END=$(date +%s)" >> $TIMING_FILE """ + job_id = None try: - logger.info("Submitting reconstruction job script to Perlmutter.") - perlmutter = self.client.compute(Machine.perlmutter) - job = perlmutter.submit_job(job_script) - logger.info(f"Submitted job ID: {job.jobid}") - - try: - job.update() - except Exception as update_err: - logger.warning(f"Initial job update failed, continuing: {update_err}") - + logger.info("Submitting reconstruction job to Perlmutter.") + job_id = self._submit_job(job_script) + logger.info(f"Submitted job ID: {job_id}") time.sleep(60) - logger.info(f"Job {job.jobid} current state: {job.state}") - - job.complete() # Wait until the job completes - logger.info("Reconstruction job completed successfully.") - # Fetch timing data - timing = self._fetch_timing_data(perlmutter, pscratch_path, job.jobid) - - return { - "success": True, - "job_id": job.jobid, - "timing": timing - } - + success = self._wait_for_job(job_id) + timing = self._fetch_timing_data(pscratch_path, job_id) if success else None + return {"success": success, "job_id": job_id, "timing": timing} except Exception as e: - logger.info(f"Error during job submission or completion: {e}") - match = re.search(r"Job not found:\s*(\d+)", str(e)) + logger.error(f"Error during reconstruction job submission or completion: {e}") + return {"success": False, "job_id": job_id, "timing": None} - if match: - jobid = match.group(1) - logger.info(f"Attempting to recover job {jobid}.") - try: - job = self.client.perlmutter.job(jobid=jobid) - time.sleep(30) - job.complete() - logger.info("Reconstruction job completed successfully after recovery.") - return True - except Exception as recovery_err: - logger.error(f"Failed to recover job {jobid}: {recovery_err}") - return False - else: - return False - - def _fetch_timing_data(self, perlmutter, pscratch_path: str, job_id: str) -> dict: + def _fetch_timing_data(self, pscratch_path: str, job_id: str) -> dict: """ Fetch and parse timing data from the SLURM job. - :param perlmutter: SFAPI compute object for Perlmutter :param pscratch_path: Path to the user's pscratch directory :param job_id: SLURM job ID :return: Dictionary with timing breakdown @@ -653,17 +643,26 @@ def _fetch_timing_data(self, perlmutter, pscratch_path: str, job_id: str) -> dic try: # Use SFAPI to read the timing file - result = perlmutter.run(f"cat {timing_file}") - - # result might be a string directly, or an object with .output - if isinstance(result, str): - output = result - elif hasattr(result, 'output'): - output = result.output - elif hasattr(result, 'stdout'): - output = result.stdout - else: - output = str(result) + if self.login_method is NERSCLoginMethod.SFAPI: + perlmutter = self.client.compute(Machine.perlmutter) + result = perlmutter.run(f"cat {timing_file}") + + # result might be a string directly, or an object with .output + if isinstance(result, str): + output = result + elif hasattr(result, 'output'): + output = result.output + elif hasattr(result, 'stdout'): + output = result.stdout + else: + output = str(result) + elif self.login_method is NERSCLoginMethod.IRIAPI: + response = self.client.get( + "/api/v1/filesystem/file/perlmutter", + params={"path": timing_file}, + ) + response.raise_for_status() + output = response.text logger.info(f"Timing file contents:\n{output}") From be2c5716a28c8de382f57318c324cfac58ee9195 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Tue, 17 Mar 2026 10:50:35 -0700 Subject: [PATCH 05/42] Updating pytests --- orchestration/_tests/test_bl832/test_nersc.py | 309 +++++++++++++--- orchestration/_tests/test_sfapi_flow.py | 332 ++---------------- 2 files changed, 292 insertions(+), 349 deletions(-) diff --git a/orchestration/_tests/test_bl832/test_nersc.py b/orchestration/_tests/test_bl832/test_nersc.py index 8d7056a8..7a8ca07a 100644 --- a/orchestration/_tests/test_bl832/test_nersc.py +++ b/orchestration/_tests/test_bl832/test_nersc.py @@ -1,5 +1,4 @@ -# orchestration/_tests/bl832/test_nersc.py - +# orchestration/_tests/test_bl832/test_nersc.py import pytest from uuid import uuid4 @@ -20,18 +19,28 @@ def prefect_test_fixture(): yield -# ────────────────────────────────────────────────────────────────────────────── +# --------------------------------------------------------------------------- # Shared fixtures -# ────────────────────────────────────────────────────────────────────────────── +# --------------------------------------------------------------------------- + +@pytest.fixture +def mock_config(mocker): + config = mocker.MagicMock() + config.ghcr_images832 = { + "recon_image": "mock_recon_image", + "multires_image": "mock_multires_image", + } + return config + @pytest.fixture def mock_sfapi_client(mocker): - """Mock sfapi_client.Client with a completed job on Perlmutter.""" - mock_client = mocker.MagicMock() + """sfapi_client.Client mock with user, compute, submit_job, and job chained.""" + client = mocker.MagicMock() mock_user = mocker.MagicMock() mock_user.name = "testuser" - mock_client.user.return_value = mock_user + client.user.return_value = mock_user mock_job = mocker.MagicMock() mock_job.jobid = "12345" @@ -39,10 +48,9 @@ def mock_sfapi_client(mocker): mock_compute = mocker.MagicMock() mock_compute.submit_job.return_value = mock_job - mock_client.compute.return_value = mock_compute - - mocker.patch("orchestration.flows.bl832.nersc.Client", return_value=mock_client) - return mock_client + client.compute.return_value = mock_compute + mocker.patch("orchestration.flows.bl832.nersc.Client", return_value=client) + return client @pytest.fixture @@ -167,11 +175,28 @@ def _make_future(mocker, value): return f -# ────────────────────────────────────────────────────────────────────────────── -# create_sfapi_client -# ────────────────────────────────────────────────────────────────────────────── +@pytest.fixture +def mock_iriapi_client(mocker): + """httpx.Client mock for IRI API responses.""" + client = mocker.MagicMock() + + submit_response = mocker.MagicMock() + submit_response.json.return_value = {"job_id": "99999"} + client.post.return_value = submit_response + + status_response = mocker.MagicMock() + status_response.json.return_value = {"state": "COMPLETED"} + client.get.return_value = status_response + + return client + + +# --------------------------------------------------------------------------- +# _create_sfapi_client +# --------------------------------------------------------------------------- def test_create_sfapi_client_success(mocker): + """Valid credentials produce a Client instance.""" from orchestration.flows.bl832.nersc import NERSCTomographyHPCController mocker.patch("orchestration.flows.bl832.nersc.os.getenv", side_effect=lambda x: { @@ -179,29 +204,34 @@ def test_create_sfapi_client_success(mocker): "PATH_NERSC_PRI_KEY": "/path/to/client_secret", }.get(x)) mocker.patch("orchestration.flows.bl832.nersc.os.path.isfile", return_value=True) - mocker.patch("builtins.open", side_effect=[ - mocker.mock_open(read_data="client_id_value")(), - mocker.mock_open(read_data='{"key": "value"}')(), - ]) + mocker.patch( + "builtins.open", + side_effect=[ + mocker.mock_open(read_data="my-client-id")(), + mocker.mock_open(read_data='{"kty": "RSA", "n": "x", "e": "y"}')(), + ] + ) mocker.patch("orchestration.flows.bl832.nersc.JsonWebKey.import_key", return_value="mock_secret") mock_client_cls = mocker.patch("orchestration.flows.bl832.nersc.Client") - client = NERSCTomographyHPCController.create_sfapi_client() + client = NERSCTomographyHPCController._create_sfapi_client() - mock_client_cls.assert_called_once_with("client_id_value", "mock_secret") - assert client == mock_client_cls.return_value + mock_client_cls.assert_called_once_with("my-client-id", "mock_secret") + assert client is mock_client_cls.return_value def test_create_sfapi_client_missing_paths(mocker): + """Unset env vars raise ValueError.""" from orchestration.flows.bl832.nersc import NERSCTomographyHPCController mocker.patch("orchestration.flows.bl832.nersc.os.getenv", return_value=None) with pytest.raises(ValueError, match="Missing NERSC credentials paths."): - NERSCTomographyHPCController.create_sfapi_client() + NERSCTomographyHPCController._create_sfapi_client() def test_create_sfapi_client_missing_files(mocker): + """Env vars set but files absent raise FileNotFoundError.""" from orchestration.flows.bl832.nersc import NERSCTomographyHPCController mocker.patch("orchestration.flows.bl832.nersc.os.getenv", side_effect=lambda x: { @@ -211,40 +241,7 @@ def test_create_sfapi_client_missing_files(mocker): mocker.patch("orchestration.flows.bl832.nersc.os.path.isfile", return_value=False) with pytest.raises(FileNotFoundError, match="NERSC credential files are missing."): - NERSCTomographyHPCController.create_sfapi_client() - - -# ────────────────────────────────────────────────────────────────────────────── -# reconstruct -# ────────────────────────────────────────────────────────────────────────────── - -def test_reconstruct_success(mocker, mock_sfapi_client, mock_config832): - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - from sfapi_client.compute import Machine - - mocker.patch("orchestration.flows.bl832.nersc.time.sleep") - controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) - - result = controller.reconstruct(file_path="folder/file.h5") - - mock_sfapi_client.compute.assert_called_once_with(Machine.perlmutter) - mock_sfapi_client.compute.return_value.submit_job.assert_called_once() - mock_sfapi_client.compute.return_value.submit_job.return_value.complete.assert_called_once() - assert isinstance(result, dict) - assert result["success"] is True - assert result["job_id"] == "12345" - - -def test_reconstruct_submission_failure(mocker, mock_sfapi_client, mock_config832): - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - - mocker.patch("orchestration.flows.bl832.nersc.time.sleep") - mock_sfapi_client.compute.return_value.submit_job.side_effect = Exception("Submission failed") - controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) - - result = controller.reconstruct(file_path="folder/file.h5") - - assert result is False + NERSCTomographyHPCController._create_sfapi_client() # ────────────────────────────────────────────────────────────────────────────── @@ -386,6 +383,162 @@ def test_segmentation_dinov3_submission_failure(mocker, mock_sfapi_client, mock_ controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) result = controller.segmentation_dinov3(recon_folder_path="folder/recfile") + assert result is False + +# --------------------------------------------------------------------------- +# reconstruct — SFAPI +# --------------------------------------------------------------------------- + + +def test_reconstruct_sfapi_success(mocker, mock_sfapi_client, mock_config832): + """SFAPI reconstruct submits a job and waits for completion.""" + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod + from sfapi_client.compute import Machine + + mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + + controller = NERSCTomographyHPCController( + client=mock_sfapi_client, + config=mock_config832, + login_method=NERSCLoginMethod.SFAPI, + ) + + result = controller.reconstruct(file_path="folder/scan.h5") + + assert result["success"] is True + assert result["job_id"] == "12345" + assert mock_sfapi_client.compute.call_count == 3 # 1 _submit_job() + 1 _wait_for_job() + 1 _fetch_timing_data() + mock_sfapi_client.compute.assert_called_with(Machine.perlmutter) + mock_sfapi_client.compute.return_value.submit_job.assert_called_once() + mock_sfapi_client.compute.return_value.job.assert_called_once_with(jobid="12345") + mock_sfapi_client.compute.return_value.job.return_value.complete.assert_called_once() + + +def test_reconstruct_sfapi_submission_failure(mocker, mock_sfapi_client, mock_config832): + """SFAPI reconstruct returns False when submission raises.""" + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod + + mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + mock_sfapi_client.compute.return_value.submit_job.side_effect = Exception("SFAPI error") + + controller = NERSCTomographyHPCController( + client=mock_sfapi_client, + config=mock_config832, + login_method=NERSCLoginMethod.SFAPI, + ) + + result = controller.reconstruct(file_path="folder/scan.h5") + + assert result["success"] is False + + +# --------------------------------------------------------------------------- +# reconstruct — IRIAPI +# --------------------------------------------------------------------------- + +def test_reconstruct_iriapi_success(mocker, mock_iriapi_client, mock_config832, monkeypatch): + """IRIAPI reconstruct POSTs a job and polls for COMPLETED state.""" + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod + + monkeypatch.setenv("NERSC_USERNAME", "alsdev") + mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + + controller = NERSCTomographyHPCController( + client=mock_iriapi_client, + config=mock_config832, + login_method=NERSCLoginMethod.IRIAPI, + ) + + result = controller.reconstruct(file_path="folder/scan.h5") + + assert result["success"] is True + assert result["job_id"] == "99999" + mock_iriapi_client.post.assert_called_once() + assert mock_iriapi_client.post.call_args.args[0] == "/api/v1/compute/job/perlmutter" + assert "script" in mock_iriapi_client.post.call_args.kwargs["json"] + assert mock_iriapi_client.get.call_count == 2 + mock_iriapi_client.get.assert_any_call( + "/api/v1/compute/status/perlmutter/99999" + ) + mock_iriapi_client.get.assert_any_call( + "/api/v1/filesystem/file/perlmutter", + params={"path": mocker.ANY}, + ) + + +def test_reconstruct_iriapi_job_failed(mocker, mock_iriapi_client, mock_config832, monkeypatch): + """IRIAPI reconstruct returns False when job state is FAILED.""" + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod + + monkeypatch.setenv("NERSC_USERNAME", "alsdev") + mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + mock_iriapi_client.get.return_value.json.return_value = {"state": "FAILED"} + + controller = NERSCTomographyHPCController( + client=mock_iriapi_client, + config=mock_config832, + login_method=NERSCLoginMethod.IRIAPI, + ) + + result = controller.reconstruct(file_path="folder/scan.h5") + + assert result["success"] is False + + +def test_reconstruct_iriapi_missing_username(mocker, mock_iriapi_client, mock_config832, monkeypatch): + """IRIAPI reconstruct raises ValueError when NERSC_USERNAME is unset.""" + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod + + monkeypatch.delenv("NERSC_USERNAME", raising=False) + + controller = NERSCTomographyHPCController( + client=mock_iriapi_client, + config=mock_config832, + login_method=NERSCLoginMethod.IRIAPI, + ) + + with pytest.raises(ValueError, match="NERSC_USERNAME"): + controller.reconstruct(file_path="folder/scan.h5") + + +# --------------------------------------------------------------------------- +# build_multi_resolution — SFAPI +# --------------------------------------------------------------------------- + +def test_build_multi_resolution_sfapi_success(mocker, mock_sfapi_client, mock_config832): + """SFAPI build_multi_resolution submits and waits successfully.""" + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod + from sfapi_client.compute import Machine + + mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + + controller = NERSCTomographyHPCController( + client=mock_sfapi_client, + config=mock_config832, + login_method=NERSCLoginMethod.SFAPI, + ) + + result = controller.build_multi_resolution(file_path="folder/scan.h5") + + assert result is True + assert mock_sfapi_client.compute.call_count == 2 + mock_sfapi_client.compute.assert_called_with(Machine.perlmutter) + + +def test_build_multi_resolution_sfapi_failure(mocker, mock_sfapi_client, mock_config832): + """SFAPI build_multi_resolution returns False when submission raises.""" + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod + + mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + mock_sfapi_client.compute.return_value.submit_job.side_effect = Exception("error") + + controller = NERSCTomographyHPCController( + client=mock_sfapi_client, + config=mock_config832, + login_method=NERSCLoginMethod.SFAPI, + ) + + result = controller.build_multi_resolution(file_path="folder/scan.h5") assert result is False @@ -794,3 +947,47 @@ def test_moon_segment_flow_no_sam3_no_combine(mocker, mock_config832, mock_recon mock_sam3_task.submit.assert_not_called() mock_combine_task.submit.assert_not_called() +# --------------------------------------------------------------------------- +# build_multi_resolution — IRIAPI +# --------------------------------------------------------------------------- + + +def test_build_multi_resolution_iriapi_success(mocker, mock_iriapi_client, mock_config, monkeypatch): + """IRIAPI build_multi_resolution POSTs and polls successfully.""" + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod + + monkeypatch.setenv("NERSC_USERNAME", "alsdev") + mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + + controller = NERSCTomographyHPCController( + client=mock_iriapi_client, + config=mock_config, + login_method=NERSCLoginMethod.IRIAPI, + ) + + result = controller.build_multi_resolution(file_path="folder/scan.h5") + + assert result is True + mock_iriapi_client.post.assert_called_once() + mock_iriapi_client.get.assert_called_once_with( + "/api/v1/compute/status/perlmutter/99999" + ) + + +def test_build_multi_resolution_iriapi_failure(mocker, mock_iriapi_client, mock_config, monkeypatch): + """IRIAPI build_multi_resolution returns False when job state is FAILED.""" + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod + + monkeypatch.setenv("NERSC_USERNAME", "alsdev") + mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + mock_iriapi_client.get.return_value.json.return_value = {"state": "FAILED"} + + controller = NERSCTomographyHPCController( + client=mock_iriapi_client, + config=mock_config, + login_method=NERSCLoginMethod.IRIAPI, + ) + + result = controller.build_multi_resolution(file_path="folder/scan.h5") + + assert result is False diff --git a/orchestration/_tests/test_sfapi_flow.py b/orchestration/_tests/test_sfapi_flow.py index 6e9bf225..e0d4a854 100644 --- a/orchestration/_tests/test_sfapi_flow.py +++ b/orchestration/_tests/test_sfapi_flow.py @@ -1,8 +1,5 @@ # orchestration/_tests/test_sfapi_flow.py - -from pathlib import Path import pytest -from unittest.mock import MagicMock, patch, mock_open from uuid import uuid4 from prefect.blocks.system import Secret @@ -11,307 +8,56 @@ @pytest.fixture(autouse=True, scope="session") def prefect_test_fixture(): - """ - A pytest fixture that automatically sets up and tears down the Prefect test harness - for the entire test session. It creates and saves test secrets and configurations - required for Globus integration. - - Yields: - None - """ with prefect_test_harness(): - globus_client_id = Secret(value=str(uuid4())) - globus_client_id.save(name="globus-client-id", overwrite=True) - globus_client_secret = Secret(value=str(uuid4())) - globus_client_secret.save(name="globus-client-secret", overwrite=True) - + Secret(value=str(uuid4())).save(name="globus-client-id", overwrite=True) + Secret(value=str(uuid4())).save(name="globus-client-secret", overwrite=True) yield -# ---------------------------- -# Tests for create_sfapi_client -# ---------------------------- - - -def test_create_sfapi_client_success(): - """ - Test successful creation of the SFAPI client. - """ - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - - # Mock data for client_id and client_secret files - mock_client_id = 'value' - mock_client_secret = '{"key": "value"}' - - # Create separate mock_open instances for each file - mock_open_client_id = mock_open(read_data=mock_client_id) - mock_open_client_secret = mock_open(read_data=mock_client_secret) - - with patch("orchestration.flows.bl832.nersc.os.getenv") as mock_getenv, \ - patch("orchestration.flows.bl832.nersc.os.path.isfile") as mock_isfile, \ - patch("builtins.open", side_effect=[ - mock_open_client_id.return_value, - mock_open_client_secret.return_value - ]), \ - patch("orchestration.flows.bl832.nersc.JsonWebKey.import_key") as mock_import_key, \ - patch("orchestration.flows.bl832.nersc.Client") as MockClient: - - # Mock environment variables - mock_getenv.side_effect = lambda x: { - "PATH_NERSC_CLIENT_ID": "/path/to/client_id", - "PATH_NERSC_PRI_KEY": "/path/to/client_secret" - }.get(x, None) - - # Mock file existence - mock_isfile.return_value = True - - # Mock JsonWebKey.import_key to return a mock secret - mock_import_key.return_value = "mock_secret" - - # Create the client - client = NERSCTomographyHPCController.create_sfapi_client() - - # Assert that Client was instantiated with 'value' and 'mock_secret' - MockClient.assert_called_once_with("value", "mock_secret") - - # Assert that the returned client is the mocked client - assert client == MockClient.return_value, "Client should be the mocked sfapi_client.Client instance" - - -def test_create_sfapi_client_missing_paths(): - """ - Test creation of the SFAPI client with missing credential paths. - """ - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - - with patch("orchestration.flows.bl832.nersc.os.getenv", return_value=None): - with pytest.raises(ValueError, match="Missing NERSC credentials paths."): - NERSCTomographyHPCController.create_sfapi_client() - - -def test_create_sfapi_client_missing_files(): - """ - Test creation of the SFAPI client with missing credential files. - """ - with ( - # Mock environment variables - patch( - "orchestration.flows.bl832.nersc.os.getenv", - side_effect=lambda x: { - "PATH_NERSC_CLIENT_ID": "/path/to/client_id", - "PATH_NERSC_PRI_KEY": "/path/to/client_secret" - }.get(x, None) - ), - - # Mock file existence to simulate missing files - patch("orchestration.flows.bl832.nersc.os.path.isfile", return_value=False) - ): - # Import the module after applying patches to ensure mocks are in place - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - - # Expect a FileNotFoundError due to missing credential files - with pytest.raises(FileNotFoundError, match="NERSC credential files are missing."): - NERSCTomographyHPCController.create_sfapi_client() - -# ---------------------------- -# Fixture for Mocking SFAPI Client -# ---------------------------- - - -@pytest.fixture -def mock_sfapi_client(): - """ - Mock the sfapi_client.Client class with necessary methods. - """ - with patch("orchestration.flows.bl832.nersc.Client") as MockClient: - mock_client_instance = MockClient.return_value - - # Mock the user method - mock_user = MagicMock() - mock_user.name = "testuser" - mock_client_instance.user.return_value = mock_user - - # Mock the compute method to return a mocked compute object - mock_compute = MagicMock() - mock_job = MagicMock() - mock_job.jobid = "12345" - mock_job.state = "COMPLETED" - mock_compute.submit_job.return_value = mock_job - mock_client_instance.compute.return_value = mock_compute - - yield mock_client_instance - - -# ---------------------------- -# Fixture for Mocking Config832 -# ---------------------------- - -@pytest.fixture -def mock_config832(): - """ - Mock the Config832 class to provide necessary configurations. - - All settings dicts must be fully populated to match the config YAML schema, - because _load_job_options() passes config_settings directly as the defaults - dict and then accesses keys by name. - """ - with patch("orchestration.flows.bl832.nersc.Config832") as MockConfig: - mock_config = MockConfig.return_value - mock_config.ghcr_images832 = { - "recon_image": "mock_recon_image", - "multires_image": "mock_multires_image", - } - mock_config.nersc_recon_settings = { - "qos": "realtime", - "account": "mock_account", - "reservation": "", - "num_nodes": 4, - "cpus-per-task": 128, - "walltime": "0:30:00", - } - mock_config.nersc_multiresolution_settings = { - "qos": "realtime", - "account": "mock_account", - "reservation": "", - "cpus-per-task": 128, - "walltime": "0:15:00", - } - mock_config.apps = {"als_transfer": "some_config"} - yield mock_config - - -# ---------------------------- -# Tests for NERSCTomographyHPCController -# ---------------------------- - -def test_reconstruct_success(mock_sfapi_client, mock_config832): - """ - Test successful reconstruction job submission. - """ +def test_create_sfapi_client_success(mocker): + """Valid credentials produce a Client instance.""" from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - from sfapi_client.compute import Machine - - controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) - file_path = "path/to/file.h5" - - with patch("orchestration.flows.bl832.nersc.time.sleep", return_value=None): - result = controller.reconstruct(file_path=file_path) - - # Verify that compute was called with Machine.perlmutter - mock_sfapi_client.compute.assert_called_once_with(Machine.perlmutter) - - # Verify that submit_job was called once - mock_sfapi_client.compute.return_value.submit_job.assert_called_once() - - # Verify that complete was called on the job - mock_sfapi_client.compute.return_value.submit_job.return_value.complete.assert_called_once() - - # Assert that the method returns True - assert isinstance(result, dict) - assert result["success"] is True - assert result["job_id"] == "12345" - -def test_reconstruct_submission_failure(mock_sfapi_client, mock_config832): - """ - Test reconstruction job submission failure. - """ + mocker.patch("orchestration.flows.bl832.nersc.os.getenv", side_effect=lambda x: { + "PATH_NERSC_CLIENT_ID": "/path/to/client_id", + "PATH_NERSC_PRI_KEY": "/path/to/client_secret", + }.get(x)) + mocker.patch("orchestration.flows.bl832.nersc.os.path.isfile", return_value=True) + mocker.patch( + "builtins.open", + side_effect=[ + mocker.mock_open(read_data="my-client-id")(), + mocker.mock_open(read_data='{"kty": "RSA", "n": "x", "e": "y"}')(), + ] + ) + mocker.patch("orchestration.flows.bl832.nersc.JsonWebKey.import_key", return_value="mock_secret") + mock_client_cls = mocker.patch("orchestration.flows.bl832.nersc.Client") + + client = NERSCTomographyHPCController._create_sfapi_client() + + mock_client_cls.assert_called_once_with("my-client-id", "mock_secret") + assert client is mock_client_cls.return_value + + +def test_create_sfapi_client_missing_paths(mocker): + """Unset env vars raise ValueError.""" from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) - file_path = "path/to/file.h5" - - # Simulate submission failure - mock_sfapi_client.compute.return_value.submit_job.side_effect = Exception("Submission failed") + mocker.patch("orchestration.flows.bl832.nersc.os.getenv", return_value=None) - with patch("orchestration.flows.bl832.nersc.time.sleep", return_value=None): - result = controller.reconstruct(file_path=file_path) + with pytest.raises(ValueError, match="Missing NERSC credentials paths."): + NERSCTomographyHPCController._create_sfapi_client() - # Assert that the method returns False - assert result is False, "reconstruct should return False on submission failure." - -def test_build_multi_resolution_success(mock_sfapi_client, mock_config832): - """ - Test successful multi-resolution job submission. - """ +def test_create_sfapi_client_missing_files(mocker): + """Env vars set but files absent raise FileNotFoundError.""" from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - from sfapi_client.compute import Machine - - controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) - file_path = "path/to/file.h5" - - with patch("orchestration.flows.bl832.nersc.time.sleep", return_value=None): - result = controller.build_multi_resolution(file_path=file_path) - - # Verify that compute was called with Machine.perlmutter - mock_sfapi_client.compute.assert_called_once_with(Machine.perlmutter) - - # Verify that submit_job was called once - mock_sfapi_client.compute.return_value.submit_job.assert_called_once() - - # Verify that complete was called on the job - mock_sfapi_client.compute.return_value.submit_job.return_value.complete.assert_called_once() - - # Assert that the method returns True - assert result is True, "build_multi_resolution should return True on successful job completion." - - -def test_build_multi_resolution_submission_failure(mock_sfapi_client, mock_config832): - """ - Test multi-resolution job submission failure. - """ - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - - controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) - file_path = "path/to/file.h5" - - # Simulate submission failure - mock_sfapi_client.compute.return_value.submit_job.side_effect = Exception("Submission failed") - - with patch("orchestration.flows.bl832.nersc.time.sleep", return_value=None): - result = controller.build_multi_resolution(file_path=file_path) - - # Assert that the method returns False - assert result is False, "build_multi_resolution should return False on submission failure." - - -def test_job_submission(mock_sfapi_client): - """ - Test job submission and status updates. - """ - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - from sfapi_client.compute import Machine - - mock_config = MagicMock() - mock_config.nersc_recon_settings = { - "qos": "realtime", - "account": "mock_account", - "reservation": "", - "num_nodes": 4, - "cpus-per-task": 128, - "walltime": "0:30:00", - } - - controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config) - file_path = "path/to/file.h5" - - # Mock Path to extract file and folder names - with patch.object(Path, 'parent', new_callable=MagicMock) as mock_parent, \ - patch.object(Path, 'stem', new_callable=MagicMock) as mock_stem: - mock_parent.name = "to" - mock_stem.return_value = "file" - - with patch("orchestration.flows.bl832.nersc.time.sleep", return_value=None): - controller.reconstruct(file_path=file_path) - - # Verify that compute was called with Machine.perlmutter - mock_sfapi_client.compute.assert_called_once_with(Machine.perlmutter) - # Verify that submit_job was called once - mock_sfapi_client.compute.return_value.submit_job.assert_called_once() + mocker.patch("orchestration.flows.bl832.nersc.os.getenv", side_effect=lambda x: { + "PATH_NERSC_CLIENT_ID": "/path/to/client_id", + "PATH_NERSC_PRI_KEY": "/path/to/client_secret", + }.get(x)) + mocker.patch("orchestration.flows.bl832.nersc.os.path.isfile", return_value=False) - # Verify the returned job has the expected attributes - submitted_job = mock_sfapi_client.compute.return_value.submit_job.return_value - assert submitted_job.jobid == "12345", "Job ID should match the mock job ID." - assert submitted_job.state == "COMPLETED", "Job state should be COMPLETED." + with pytest.raises(FileNotFoundError, match="NERSC credential files are missing."): + NERSCTomographyHPCController._create_sfapi_client() From cf15c2041d7985f9fc2fcdfa748ea75530b8f49d Mon Sep 17 00:00:00 2001 From: David Abramov Date: Tue, 17 Mar 2026 10:51:14 -0700 Subject: [PATCH 06/42] Updating multires() method to use the generic _submit_job() and _wait_for_job() helpers --- orchestration/flows/bl832/nersc.py | 85 +++++++++++++++++------------- 1 file changed, 48 insertions(+), 37 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 4e5e1c0e..79a4bdb0 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -7,7 +7,6 @@ import logging import os from pathlib import Path -import re import time from authlib.jose import JsonWebKey @@ -718,7 +717,8 @@ def build_multi_resolution( logger.info("Starting NERSC multiresolution process.") - user = self.client.user() + # user = self.client.user() + username = self._get_nersc_username() multires_image = self.config.ghcr_images832["multires_image"] logger.info(f"{multires_image=}") @@ -729,7 +729,7 @@ def build_multi_resolution( scratch_path = self.config.nersc832_alsdev_scratch.root_path logger.info(f"{scratch_path=}") - pscratch_path = f"/pscratch/sd/{user.name[0]}/{user.name}" + pscratch_path = f"/pscratch/sd/{username[0]}/{username}" logger.info(f"{pscratch_path=}") path = Path(file_path) @@ -784,42 +784,53 @@ def build_multi_resolution( date """ try: - logger.info("Submitting Tiff to Zarr job script to Perlmutter.") - perlmutter = self.client.compute(Machine.perlmutter) - job = perlmutter.submit_job(job_script) - logger.info(f"Submitted job ID: {job.jobid}") - - try: - job.update() - except Exception as update_err: - logger.warning(f"Initial job update failed, continuing: {update_err}") - + logger.info("Submitting Tiff to Zarr job to Perlmutter.") + job_id = self._submit_job(job_script) + logger.info(f"Submitted job ID: {job_id}") time.sleep(60) - logger.info(f"Job {job.jobid} current state: {job.state}") - - job.complete() # Wait until the job completes - logger.info("Reconstruction job completed successfully.") - - return True - + success = self._wait_for_job(job_id) + logger.info(f"Multiresolution job {'completed' if success else 'failed'}.") + return success except Exception as e: - logger.warning(f"Error during job submission or completion: {e}") - match = re.search(r"Job not found:\s*(\d+)", str(e)) - - if match: - jobid = match.group(1) - logger.info(f"Attempting to recover job {jobid}.") - try: - job = self.client.perlmutter.job(jobid=jobid) - time.sleep(30) - job.complete() - logger.info("Reconstruction job completed successfully after recovery.") - return True - except Exception as recovery_err: - logger.error(f"Failed to recover job {jobid}: {recovery_err}") - return False - else: - return False + logger.error(f"Error during multiresolution job submission or completion: {e}") + return False + # try: + # logger.info("Submitting Tiff to Zarr job script to Perlmutter.") + # perlmutter = self.client.compute(Machine.perlmutter) + # job = perlmutter.submit_job(job_script) + # logger.info(f"Submitted job ID: {job.jobid}") + + # try: + # job.update() + # except Exception as update_err: + # logger.warning(f"Initial job update failed, continuing: {update_err}") + + # time.sleep(60) + # logger.info(f"Job {job.jobid} current state: {job.state}") + + # job.complete() # Wait until the job completes + # logger.info("Reconstruction job completed successfully.") + + # return True + + # except Exception as e: + # logger.warning(f"Error during job submission or completion: {e}") + # match = re.search(r"Job not found:\s*(\d+)", str(e)) + + # if match: + # jobid = match.group(1) + # logger.info(f"Attempting to recover job {jobid}.") + # try: + # job = self.client.perlmutter.job(jobid=jobid) + # time.sleep(30) + # job.complete() + # logger.info("Reconstruction job completed successfully after recovery.") + # return True + # except Exception as recovery_err: + # logger.error(f"Failed to recover job {jobid}: {recovery_err}") + # return False + # else: + # return False def segmentation_sam3( self, From d0e80683737337f3dee3b74b5c44e5b6ba29b405 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Mon, 30 Mar 2026 14:53:53 -0700 Subject: [PATCH 07/42] successfully ran reconstruction using the IRI-API --- orchestration/flows/bl832/nersc.py | 78 ++++-- orchestration/globus/token.py | 390 +++++++++++++++++++++-------- scripts/get_globus_token.py | 337 +++++++++++++++++++++++++ 3 files changed, 681 insertions(+), 124 deletions(-) create mode 100644 scripts/get_globus_token.py diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 79a4bdb0..2aad35de 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -7,6 +7,7 @@ import logging import os from pathlib import Path +import re import time from authlib.jose import JsonWebKey @@ -21,8 +22,12 @@ from orchestration.flows.bl832.streaming_mixin import ( NerscStreamingMixin, SlurmJobBlock, cancellation_hook, monitor_streaming_job, save_block ) -from orchestration.globus.token import get_access_token_confidential, DEFAULT_TOKEN_FILE from orchestration.mlflow import get_checkpoint_info +from orchestration.globus.token import ( + get_access_token, + DEFAULT_TOKEN_FILE, + IRI_SCOPE, +) from orchestration.prefect import schedule_prefect_flow from orchestration.prune_controller import get_prune_controller, PruneMethod from orchestration.transfer_controller import globus_transfer_task @@ -33,7 +38,9 @@ # Applies only to NERSCLoginMethod.IRIAPI _IRIAPI_GLOBUS_CLIENT_ID_ENV: str = "GLOBUS_CLIENT_ID" -_IRIAPI_GLOBUS_CLIENT_SECRET_ENV: str = "GLOBUS_CLIENT_SECRET" # set → confidential client +_IRI_COMPUTE_RESOURCE: str = "compute" +_IRI_SCRATCH_RESOURCE: str = "scratch" +# _IRIAPI_GLOBUS_CLIENT_SECRET_ENV: str = "GLOBUS_CLIENT_SECRET" # set → confidential client _IRIAPI_TOKEN_FILE_ENV: str = "PATH_GLOBUS_TOKEN_FILE" _IRIAPI_GLOBUS_RESOURCE_SERVER: str = "auth.globus.org" _IRIAPI_GLOBUS_REQUIRED_SCOPES: frozenset[str] = frozenset({ @@ -41,6 +48,7 @@ "profile", "email", "urn:globus:auth:scope:auth.globus.org:view_identities", + IRI_SCOPE, }) _API_BASE_URLS: dict[NERSCLoginMethod, str] = { @@ -254,33 +262,33 @@ def _create_iriapi_client() -> Client: ValueError: If ``GLOBUS_CLIENT_ID`` or ``GLOBUS_CLIENT_SECRET`` are unset. RuntimeError: If the acquired token is missing required scopes. """ - client_id = os.getenv(_IRIAPI_GLOBUS_CLIENT_ID_ENV) - client_secret = os.getenv(_IRIAPI_GLOBUS_CLIENT_SECRET_ENV) + client_id = "fae5c579-490a-4d76-b6eb-d78f65caeb63" # os.getenv(_IRIAPI_GLOBUS_CLIENT_ID_ENV) + # client_secret = os.getenv(_IRIAPI_GLOBUS_CLIENT_SECRET_ENV) if not client_id: raise ValueError( f"Globus client ID is unset. Set {_IRIAPI_GLOBUS_CLIENT_ID_ENV}." ) - if not client_secret: - raise ValueError( - f"Globus client secret is unset. Set {_IRIAPI_GLOBUS_CLIENT_SECRET_ENV}. " - "A Globus Confidential App client is required for automated IRI API auth." - ) + # if not client_secret: + # raise ValueError( + # f"Globus client secret is unset. Set {_IRIAPI_GLOBUS_CLIENT_SECRET_ENV}. " + # "A Globus Confidential App client is required for automated IRI API auth." + # ) token_file_env = os.getenv(_IRIAPI_TOKEN_FILE_ENV) token_file = Path(token_file_env) if token_file_env else DEFAULT_TOKEN_FILE - access_token = get_access_token_confidential( + access_token = get_access_token( client_id=client_id, - client_secret=client_secret, - required_scopes=_IRIAPI_GLOBUS_REQUIRED_SCOPES, - resource_server=_IRIAPI_GLOBUS_RESOURCE_SERVER, + requested_scopes=_IRIAPI_GLOBUS_REQUIRED_SCOPES, token_file=token_file, + force_login=False, ) return httpx.Client( base_url=_API_BASE_URLS[NERSCLoginMethod.IRIAPI], headers={"Authorization": f"Bearer {access_token}"}, + timeout=httpx.Timeout(connect=10.0, read=60.0, write=10.0, pool=10.0), ) @staticmethod @@ -396,12 +404,39 @@ def _submit_job(self, job_script: str) -> str: return str(job.jobid) elif self.login_method is NERSCLoginMethod.IRIAPI: + username = self._get_nersc_username() + pscratch_path = f"/pscratch/sd/{username[0]}/{username}" + + script_body = "\n".join( + line for line in job_script.splitlines() + if not line.startswith("#SBATCH") and not line.startswith("#!/") + ).strip() + + job_spec = { + "executable": "/bin/bash", + "arguments": ["-c", script_body], + "stdout_path": f"{pscratch_path}/tomo_recon_logs/iri_job.out", + "stderr_path": f"{pscratch_path}/tomo_recon_logs/iri_job.err", + "resources": { + "node_count": 1, + "processes_per_node": 1, + "cpu_cores_per_process": 64, + "exclusive_node_use": True, + }, + "attributes": { + "duration": 1800, + "queue_name": "realtime", + "account": "als", + "custom_attributes": {"constraint": "cpu"}, + }, + } + response = self.client.post( - "/api/v1/compute/job/perlmutter", - json={"script": job_script}, + f"/api/v1/compute/job/{_IRI_COMPUTE_RESOURCE}", + json=job_spec, ) response.raise_for_status() - return str(response.json()["job_id"]) + return str(response.json()["id"]) else: raise ValueError(f"Unhandled NERSCLoginMethod: {self.login_method}") @@ -427,13 +462,16 @@ def _wait_for_job(self, job_id: str) -> bool: elif self.login_method is NERSCLoginMethod.IRIAPI: while True: response = self.client.get( - f"/api/v1/compute/status/perlmutter/{job_id}" + f"/api/v1/compute/status/{_IRI_COMPUTE_RESOURCE}/{job_id}" # ← was "perlmutter" ) response.raise_for_status() - state = response.json().get("state") + state = response.json().get("status", {}).get("state") logger.info(f"Job {job_id} state: {state}") - if state in ("COMPLETED", "FAILED", "CANCELLED", "TIMEOUT"): - return state == "COMPLETED" + if state == "completed": + return True + if state in ("failed", "canceled", "timeout"): + logger.error(f"Job {job_id} ended with state: {state}") + return False time.sleep(60) else: diff --git a/orchestration/globus/token.py b/orchestration/globus/token.py index 81b5438f..4970eaa7 100644 --- a/orchestration/globus/token.py +++ b/orchestration/globus/token.py @@ -1,3 +1,4 @@ +# orchestration/globus/token.py import json import logging import os @@ -12,69 +13,20 @@ # Default token file location, matching the Globus SDK convention. DEFAULT_TOKEN_FILE: Path = Path.home() / ".globus" / "auth_tokens.json" -GLOBUS_OIDC_TOKEN_URL: str = "https://auth.globus.org/v2/oauth2/token" +# IRI API Globus scope and resource server. +# The IRI access token lives in other_tokens under this scope, not at the +# top level of the auth.globus.org response. +IRI_SCOPE: str = ( + "https://auth.globus.org/scopes/" + "ed3e577d-f7f3-4639-b96e-ff5a8445d699/iri_api" +) +IRI_RESOURCE_SERVER: str = "ed3e577d-f7f3-4639-b96e-ff5a8445d699" -def get_access_token_confidential( - client_id: str, - client_secret: str, - required_scopes: frozenset[str], - resource_server: str, - token_file: Path | None = None, -) -> str: - """Get a valid Globus access token using a Confidential Client (machine-to-machine). - - No browser or user interaction required. If a valid unexpired token exists - on disk it is reused; otherwise a new one is minted via the client - credentials grant and saved. - - Args: - client_id: Globus Confidential App client ID. - client_secret: Globus Confidential App client secret. - required_scopes: Set of OAuth2 scopes that must be present on the token. - resource_server: Resource server key to extract from the token response. - token_file: Path to the JSON token cache file. Defaults to - ``~/.globus/auth_tokens.json``. - - Returns: - A valid Globus access token string. - - Raises: - RuntimeError: If the acquired token is missing required scopes. - KeyError: If ``access_token`` is absent from the token response. - """ - resolved_token_file = token_file or DEFAULT_TOKEN_FILE - - # 1. Do we already have a valid token? - stored = load_token_file(resolved_token_file) - if stored: - expires_at = stored.get("expires_at_seconds") - if expires_at and time.time() < expires_at: - logger.info("Using cached Globus token (still valid).") - return stored["access_token"] - logger.info("Cached Globus token is expired; minting a new one.") - else: - logger.info("No cached Globus token found; minting a new one.") - - # 2. Mint a new token — same call whether first time or expired. - globus_client = globus_sdk.ConfidentialAppAuthClient(client_id, client_secret) - token_response = globus_client.oauth2_client_credentials_tokens( - requested_scopes=" ".join(sorted(required_scopes)) - ) - auth_data = token_response.by_resource_server[resource_server] - - granted = set(auth_data.get("scope", "").split()) - missing = required_scopes - granted - if missing: - raise RuntimeError( - f"Globus token is missing required scopes: {sorted(missing)}" - ) - - save_token_file(resolved_token_file, auth_data) - logger.info(f"New Globus token saved to {resolved_token_file}.") - - return auth_data["access_token"] +# --------------------------------------------------------------------------- +# File I/O +# --------------------------------------------------------------------------- def load_token_file(token_file: Path) -> dict | None: """Load saved Globus token data from disk. @@ -112,105 +64,345 @@ def save_token_file(token_file: Path, tokens: dict) -> None: os.chmod(token_file, stat.S_IRUSR | stat.S_IWUSR) +def _ensure_private_parent_dir(path: Path) -> None: + """Create parent directories for path with owner-only permissions. + + Args: + path: The file path whose parent directory should be created. + """ + path.parent.mkdir(parents=True, exist_ok=True) + os.chmod(path.parent, 0o700) + + +# --------------------------------------------------------------------------- +# IRI token helpers +# --------------------------------------------------------------------------- + +def _parse_scope_string(scope_string: str) -> set[str]: + """Split a space-separated scope string into a set. + + Args: + scope_string: Space-separated OAuth2 scope string. + + Returns: + Set of individual scope strings. + """ + return set(scope_string.split()) if scope_string else set() + + +def extract_iri_token(token_response_data: dict) -> dict: + """Extract the IRI access token entry from a Globus token response. + + The IRI token is not returned at the top level — it lives inside + ``other_tokens``, identified by :data:`IRI_SCOPE`. + + Args: + token_response_data: Full token response dict as returned by the + Globus SDK (i.e. ``token_response.data``). + + Returns: + Token dict for the IRI resource server. + + Raises: + RuntimeError: If no token matching the IRI scope is found. + """ + for token_data in token_response_data.get("other_tokens", []): + if IRI_SCOPE in _parse_scope_string(token_data.get("scope", "")): + return token_data + raise RuntimeError( + f"Missing token for required IRI scope: {IRI_SCOPE}. " + "Re-run with --force-login and ensure consent is granted for the IRI scope." + ) + + +def _replace_iri_token(token_response_data: dict, iri_token_data: dict) -> dict: + """Return a copy of token_response_data with the IRI entry replaced. + + Args: + token_response_data: Full stored token response dict. + iri_token_data: Updated IRI token dict to splice in. + + Returns: + Updated token response dict. + """ + merged = dict(token_response_data) + other_tokens = list(merged.get("other_tokens", [])) + for i, token_data in enumerate(other_tokens): + if IRI_SCOPE in _parse_scope_string(token_data.get("scope", "")): + other_tokens[i] = iri_token_data + break + else: + other_tokens.append(iri_token_data) + merged["other_tokens"] = other_tokens + return merged + + +def _get_iri_refresh_token(stored_tokens: dict) -> str | None: + """Extract the IRI refresh token from stored token data, if present. + + Args: + stored_tokens: Full stored token response dict. + + Returns: + The IRI refresh token string, or None if absent. + """ + try: + return extract_iri_token(stored_tokens).get("refresh_token") + except RuntimeError: + return None + + +def _get_auth_refresh_token(stored_tokens: dict) -> str | None: + """Extract the top-level Globus Auth refresh token from stored data. + + Args: + stored_tokens: Full stored token response dict. + + Returns: + The auth refresh token string, or None if absent. + """ + if "refresh_token" in stored_tokens: + return stored_tokens["refresh_token"] + auth_tokens = stored_tokens.get("auth.globus.org") + if isinstance(auth_tokens, dict): + return auth_tokens.get("refresh_token") + return None + + +# --------------------------------------------------------------------------- +# NativeApp flow (interactive) +# --------------------------------------------------------------------------- + def interactive_login( client: globus_sdk.NativeAppAuthClient, - required_scopes: frozenset[str], - resource_server: str, + requested_scopes: frozenset[str], + prompt_login: bool = False, ) -> dict: """Run an interactive browser-based Globus login flow. Prints an authorization URL, waits for the user to paste an auth code, - and exchanges it for tokens. + and returns the full token response data including ``other_tokens``. Args: client: Globus NativeAppAuthClient to drive the flow. - required_scopes: Set of OAuth2 scopes to request. - resource_server: Resource server key to extract from the token response - (e.g. ``"auth.globus.org"``). + requested_scopes: Set of OAuth2 scopes to request. Should include + :data:`IRI_SCOPE` to obtain an IRI API token. + prompt_login: If True, add ``prompt=login`` to the authorize URL to + force a fresh identity-provider login. Returns: - Token dict for the given resource server. + Full token response dict (``token_response.data``), including + ``other_tokens``. + + Raises: + RuntimeError: If no authorization code is entered, or if the code + exchange fails. """ client.oauth2_start_flow( - requested_scopes=" ".join(sorted(required_scopes)), + requested_scopes=" ".join(sorted(requested_scopes)), refresh_tokens=True, ) logger.info("Open this URL in your browser to authenticate with Globus:") - logger.info(client.oauth2_get_authorize_url()) + prompt = "login" if prompt_login else globus_sdk.MISSING + logger.info(client.oauth2_get_authorize_url(prompt=prompt)) code = input("\nEnter authorization code: ").strip() - token_response = client.oauth2_exchange_code_for_tokens(code) - return token_response.by_resource_server[resource_server] + if not code: + raise RuntimeError( + "No authorization code entered. Re-run the script and paste the " + "code shown by Globus after login." + ) + try: + token_response = client.oauth2_exchange_code_for_tokens(code) + except GlobusAPIError as e: + if e.http_status == 400: + raise RuntimeError( + "Authorization code exchange failed — the code was empty, " + "invalid, expired, or already used. Re-run and try again." + ) from e + raise RuntimeError( + f"Authorization code exchange failed with HTTP {e.http_status}." + ) from e + return token_response.data -def refresh_tokens( +def _refresh_single_token( client: globus_sdk.NativeAppAuthClient, refresh_token: str, - resource_server: str, ) -> dict | None: - """Attempt a silent Globus token refresh. + """Attempt a single Globus token refresh, returning raw response data. Args: - client: Globus NativeAppAuthClient to drive the refresh. + client: NativeAppAuthClient to drive the refresh. refresh_token: The stored refresh token. - resource_server: Resource server key to extract from the token response. Returns: - Fresh token dict for the given resource server, or None if refresh failed. + Raw token response data dict, or None if the refresh failed. """ try: token_response = client.oauth2_refresh_token(refresh_token) - return token_response.by_resource_server[resource_server] + return token_response.data except GlobusAPIError as e: logger.warning( f"Globus token refresh failed ({e.http_status}); " - "falling back to interactive login." + "will fall back to interactive login." ) return None +def _refresh_stored_tokens( + client: globus_sdk.NativeAppAuthClient, + stored_tokens: dict, +) -> tuple[dict | None, bool]: + """Try to refresh stored tokens, preferring the IRI refresh token. + + Attempts the IRI-specific refresh token first, then falls back to the + top-level Globus Auth refresh token. + + Args: + client: NativeAppAuthClient to drive the refresh. + stored_tokens: Full stored token response dict. + + Returns: + Tuple of ``(updated_token_data, success)``. On failure both values + are ``(None, False)``. + """ + iri_refresh = _get_iri_refresh_token(stored_tokens) + if iri_refresh: + iri_token_data = _refresh_single_token(client, iri_refresh) + if iri_token_data is not None: + return _replace_iri_token(stored_tokens, iri_token_data), True + + auth_refresh = _get_auth_refresh_token(stored_tokens) + if auth_refresh: + auth_data = _refresh_single_token(client, auth_refresh) + if auth_data is not None: + return auth_data, True + + return None, False + + def get_access_token( client_id: str, - required_scopes: frozenset[str], - resource_server: str, + requested_scopes: frozenset[str], token_file: Path | None = None, force_login: bool = False, + prompt_login: bool = False, ) -> str: - """Get a valid Globus access token, refreshing or logging in as needed. + """Get a valid IRI API access token via the NativeApp interactive flow. Attempts a silent refresh from the saved token file first. Falls back to interactive browser login if no saved tokens exist, the refresh token is absent, or the refresh fails. Saves the resulting tokens back to disk. + The IRI token is extracted from ``other_tokens`` in the response — it is + not the top-level Globus Auth token. + Args: client_id: Globus NativeApp client ID. - required_scopes: Set of OAuth2 scopes that must be present on the token. - resource_server: Resource server key to extract from the token response. + requested_scopes: Set of OAuth2 scopes to request. Must include + :data:`IRI_SCOPE` to obtain a usable IRI API token. token_file: Path to the JSON token file. Defaults to ``~/.globus/auth_tokens.json``. force_login: If True, skip refresh and force interactive login. + prompt_login: If True, add ``prompt=login`` to the authorize URL. Returns: - A valid Globus access token string. + A valid IRI API access token string. Raises: - RuntimeError: If the acquired token is missing required scopes. - KeyError: If ``access_token`` is absent from the token response. + RuntimeError: If the IRI scope token is missing from the response. """ resolved_token_file = token_file or DEFAULT_TOKEN_FILE globus_client = globus_sdk.NativeAppAuthClient(client_id) - auth_data: dict | None = None + token_response_data: dict | None = None + used_refresh = False if not force_login: stored = load_token_file(resolved_token_file) - if stored and stored.get("refresh_token"): - auth_data = refresh_tokens( - globus_client, stored["refresh_token"], resource_server + if stored: + token_response_data, used_refresh = _refresh_stored_tokens( + globus_client, stored ) - if auth_data is None: + if token_response_data is None: logger.info("Initiating interactive Globus login.") - auth_data = interactive_login(globus_client, required_scopes, resource_server) + token_response_data = interactive_login( + globus_client, requested_scopes, prompt_login=prompt_login + ) + + # Extract IRI token — if a refresh ran but didn't return the IRI token, + # fall back to interactive login before raising. + try: + iri_token = extract_iri_token(token_response_data) + except RuntimeError: + if used_refresh: + logger.warning( + "Refreshed tokens did not include the IRI token; " + "falling back to interactive login." + ) + token_response_data = interactive_login( + globus_client, requested_scopes, prompt_login=prompt_login + ) + iri_token = extract_iri_token(token_response_data) + else: + raise + + save_token_file(resolved_token_file, token_response_data) + logger.info(f"Globus token saved to {resolved_token_file}.") + + return iri_token["access_token"] + + +# --------------------------------------------------------------------------- +# Confidential Client flow (machine-to-machine) +# --------------------------------------------------------------------------- + +def get_access_token_confidential( + client_id: str, + client_secret: str, + required_scopes: frozenset[str], + resource_server: str, + token_file: Path | None = None, +) -> str: + """Get a valid Globus access token using a Confidential Client. + + No browser or user interaction required. If a valid unexpired token exists + on disk it is reused; otherwise a new one is minted via the client + credentials grant and saved. + + Args: + client_id: Globus Confidential App client ID. + client_secret: Globus Confidential App client secret. + required_scopes: Set of OAuth2 scopes that must be present on the token. + resource_server: Resource server key to extract from the token response. + token_file: Path to the JSON token cache file. Defaults to + ``~/.globus/auth_tokens.json``. + + Returns: + A valid Globus access token string. + + Raises: + RuntimeError: If the acquired token is missing required scopes. + KeyError: If ``access_token`` is absent from the token response. + """ + resolved_token_file = token_file or DEFAULT_TOKEN_FILE + + stored = load_token_file(resolved_token_file) + if stored: + expires_at = stored.get("expires_at_seconds") + if expires_at and time.time() < expires_at: + logger.info("Using cached Globus token (still valid).") + return stored["access_token"] + logger.info("Cached Globus token is expired; minting a new one.") + else: + logger.info("No cached Globus token found; minting a new one.") + + globus_client = globus_sdk.ConfidentialAppAuthClient(client_id, client_secret) + token_response = globus_client.oauth2_client_credentials_tokens( + requested_scopes=" ".join(sorted(required_scopes)) + ) + auth_data = token_response.by_resource_server[resource_server] granted = set(auth_data.get("scope", "").split()) missing = required_scopes - granted @@ -220,16 +412,6 @@ def get_access_token( ) save_token_file(resolved_token_file, auth_data) - logger.info(f"Globus token saved to {resolved_token_file}.") + logger.info(f"New Globus token saved to {resolved_token_file}.") return auth_data["access_token"] - - -def _ensure_private_parent_dir(path: Path) -> None: - """Create parent directories for path with owner-only permissions. - - Args: - path: The file path whose parent directory should be created. - """ - path.parent.mkdir(parents=True, exist_ok=True) - os.chmod(path.parent, 0o700) diff --git a/scripts/get_globus_token.py b/scripts/get_globus_token.py new file mode 100644 index 00000000..6b615378 --- /dev/null +++ b/scripts/get_globus_token.py @@ -0,0 +1,337 @@ +#!/usr/bin/env python3 +import argparse +import json +import os +import stat +import time +import urllib.error +import urllib.request +from pathlib import Path + +import globus_sdk +from globus_sdk.exc import GlobusAPIError + +CLIENT_ID = "fae5c579-490a-4d76-b6eb-d78f65caeb63" +RESOURCE_SERVER = "auth.globus.org" +IRI_SCOPE = ( + "https://auth.globus.org/scopes/" + "ed3e577d-f7f3-4639-b96e-ff5a8445d699/iri_api" +) +REQUIRED_SCOPES = { + "openid", + "profile", + "email", + "urn:globus:auth:scope:auth.globus.org:view_identities", +} +REQUESTED_SCOPES = REQUIRED_SCOPES | {IRI_SCOPE} +DEFAULT_IRI_VALIDATE_URL = "https://api.iri.nersc.gov/api/v1/account/projects" + + +def parse_args() -> argparse.Namespace: + default_token_file = Path.home() / ".globus" / "auth_tokens.json" + parser = argparse.ArgumentParser( + description=( + "Get Globus Auth tokens with required scopes. " + "Tokens are saved to a secure local file by default." + ) + ) + parser.add_argument( + "--token-file", + type=Path, + default=default_token_file, + help=f"Path for saved token JSON (default: {default_token_file})", + ) + parser.add_argument( + "--print-token", + action="store_true", + help="Print the access token to stdout (off by default).", + ) + parser.add_argument( + "--force-login", + action="store_true", + help="Skip refresh and force interactive browser login.", + ) + parser.add_argument( + "--refresh-only", + action="store_true", + help="Refresh saved tokens only; do not fall back to interactive login.", + ) + parser.add_argument( + "--prompt-login", + action="store_true", + help="Add prompt=login to the Globus authorize URL to force re-authentication.", + ) + parser.add_argument( + "--validate-iri", + action="store_true", + help="Validate the IRI token by calling the IRI account/projects endpoint.", + ) + parser.add_argument( + "--iri-validate-url", + default=DEFAULT_IRI_VALIDATE_URL, + help=( + "IRI endpoint used by --validate-iri " + f"(default: {DEFAULT_IRI_VALIDATE_URL})" + ), + ) + return parser.parse_args() + + +def parse_scope_string(scope_string: str) -> set[str]: + return set(scope_string.split()) if scope_string else set() + + +def ensure_private_parent_dir(path: Path) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + os.chmod(path.parent, 0o700) + + +def load_tokens(token_file: Path) -> dict | None: + if not token_file.exists(): + return None + with token_file.open("r", encoding="utf-8") as f: + return json.load(f) + + +def save_tokens(token_file: Path, tokens: dict) -> None: + ensure_private_parent_dir(token_file) + tmp = token_file.with_suffix(".tmp") + with os.fdopen( + os.open(tmp, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600), + "w", + encoding="utf-8", + ) as f: + json.dump(tokens, f, indent=2) + os.replace(tmp, token_file) + os.chmod(token_file, stat.S_IRUSR | stat.S_IWUSR) + + +def get_refresh_token(stored_tokens: dict) -> str | None: + if "refresh_token" in stored_tokens: + return stored_tokens.get("refresh_token") + + auth_tokens = stored_tokens.get(RESOURCE_SERVER) + if isinstance(auth_tokens, dict): + return auth_tokens.get("refresh_token") + + return None + + +def get_iri_token(token_response_data: dict) -> dict: + for token_data in token_response_data.get("other_tokens", []): + if IRI_SCOPE in parse_scope_string(token_data.get("scope", "")): + return token_data + raise RuntimeError(f"Missing token for required IRI scope: {IRI_SCOPE}") + + +def get_iri_refresh_token(stored_tokens: dict) -> str | None: + try: + return get_iri_token(stored_tokens).get("refresh_token") + except RuntimeError: + return None + + +def replace_iri_token(token_response_data: dict, iri_token_data: dict) -> dict: + merged = dict(token_response_data) + other_tokens = list(merged.get("other_tokens", [])) + for index, token_data in enumerate(other_tokens): + if IRI_SCOPE in parse_scope_string(token_data.get("scope", "")): + other_tokens[index] = iri_token_data + break + else: + other_tokens.append(iri_token_data) + merged["other_tokens"] = other_tokens + return merged + + +def validate_auth_data(auth_data: dict) -> dict: + if auth_data.get("resource_server") != RESOURCE_SERVER: + raise RuntimeError( + f"Missing token for required resource server: {RESOURCE_SERVER}" + ) + + granted = parse_scope_string(auth_data.get("scope", "")) + missing = REQUIRED_SCOPES - granted + if missing: + raise RuntimeError(f"Missing required scopes: {sorted(missing)}") + + return get_iri_token(auth_data) + + +def validate_iri_token(iri_token_data: dict, validate_url: str) -> dict | list: + request = urllib.request.Request( + validate_url, + headers={ + "accept": "application/json", + "Authorization": f"Bearer {iri_token_data['access_token']}", + }, + method="GET", + ) + try: + with urllib.request.urlopen(request) as response: + body = response.read().decode("utf-8") + data = json.loads(body) if body else {} + except urllib.error.HTTPError as exc: + body = exc.read().decode("utf-8") + details = body.strip() or exc.reason + raise RuntimeError( + f"IRI validation failed with HTTP {exc.code} from {validate_url}: {details}" + ) from exc + except urllib.error.URLError as exc: + raise RuntimeError( + f"IRI validation request failed for {validate_url}: {exc.reason}" + ) from exc + except json.JSONDecodeError as exc: + raise RuntimeError( + f"IRI validation returned non-JSON data from {validate_url}" + ) from exc + + if isinstance(data, dict): + session_info = data.get("session_info") + if isinstance(session_info, dict): + authentications = session_info.get("authentications") + if isinstance(authentications, dict) and not authentications: + raise RuntimeError( + "IRI validation succeeded but session_info.authentications is empty. " + "Re-run with --force-login --prompt-login and use a Chrome incognito window." + ) + + return data + + +def interactive_login( + client: globus_sdk.NativeAppAuthClient, + *, + prompt_login: bool = False, +) -> dict: + client.oauth2_start_flow( + requested_scopes=" ".join(sorted(REQUESTED_SCOPES)), + refresh_tokens=True, + ) + print("Open this URL, login, and consent:") + prompt = "login" if prompt_login else globus_sdk.MISSING + print(client.oauth2_get_authorize_url(prompt=prompt)) + code = input("\nEnter authorization code: ").strip() + if not code: + raise RuntimeError( + "No authorization code entered. Re-run the script and paste the code " + "shown by Globus after login." + ) + try: + token_response = client.oauth2_exchange_code_for_tokens(code) + except GlobusAPIError as exc: + if exc.http_status == 400: + raise RuntimeError( + "Authorization code exchange failed. The code was empty, invalid, " + "expired, or already used. Re-run the script and complete the " + "Globus login flow again." + ) from exc + raise RuntimeError( + f"Authorization code exchange failed with HTTP {exc.http_status}. " + "Re-run the script and try again." + ) from exc + return token_response.data + + +def refresh_tokens( + client: globus_sdk.NativeAppAuthClient, refresh_token: str +) -> dict | None: + try: + token_response = client.oauth2_refresh_token(refresh_token) + return token_response.data + except GlobusAPIError as exc: + print( + f"Refresh failed ({exc.http_status}); switching to interactive login." + ) + return None + + +def refresh_stored_tokens( + client: globus_sdk.NativeAppAuthClient, stored_tokens: dict +) -> tuple[dict | None, bool]: + iri_refresh_token = get_iri_refresh_token(stored_tokens) + if iri_refresh_token: + iri_token_data = refresh_tokens(client, iri_refresh_token) + if iri_token_data is not None: + return replace_iri_token(stored_tokens, iri_token_data), True + + auth_refresh_token = get_refresh_token(stored_tokens) + if auth_refresh_token: + auth_data = refresh_tokens(client, auth_refresh_token) + if auth_data is not None: + return auth_data, True + + return None, False + + +def main() -> None: + args = parse_args() + if args.force_login and args.refresh_only: + raise RuntimeError("Choose only one of --force-login or --refresh-only") + + client = globus_sdk.NativeAppAuthClient(CLIENT_ID) + + auth_data = None + used_refresh = False + if not args.force_login: + stored = load_tokens(args.token_file) + if stored: + auth_data, used_refresh = refresh_stored_tokens(client, stored) + + if auth_data is None: + if args.refresh_only: + raise RuntimeError( + "Refresh-only mode failed. No usable saved refresh token was found " + "or token refresh did not return the required IRI token." + ) + auth_data = interactive_login(client, prompt_login=args.prompt_login) + + try: + iri_token_data = validate_auth_data(auth_data) + except RuntimeError as exc: + if used_refresh and "Missing token for required IRI scope" in str(exc): + print( + "Refreshed tokens did not include the IRI token; " + "switching to interactive login." + ) + auth_data = interactive_login(client, prompt_login=args.prompt_login) + iri_token_data = validate_auth_data(auth_data) + else: + raise + + save_tokens(args.token_file, auth_data) + + if args.validate_iri: + validation_data = validate_iri_token(iri_token_data, args.iri_validate_url) + print(f"IRI validation succeeded against {args.iri_validate_url}") + if isinstance(validation_data, dict): + session_info = validation_data.get("session_info") + if isinstance(session_info, dict): + session_id = session_info.get("session_id") + if session_id: + print(f"IRI session_id: {session_id}") + elif isinstance(validation_data, list): + print(f"IRI validation response items: {len(validation_data)}") + + expires_at = iri_token_data.get("expires_at_seconds") + if expires_at: + ttl = int(expires_at - time.time()) + print(f"\nIRI access token valid for ~{max(ttl, 0)} seconds.") + + print(f"Saved token data to {args.token_file}") + print(f"Granted Globus Auth scopes: {auth_data.get('scope', '')}") + print(f"IRI token resource server: {iri_token_data.get('resource_server')}") + print(f"IRI token scopes: {iri_token_data.get('scope', '')}") + + if args.print_token: + print("\nIRI access token:") + print(iri_token_data["access_token"]) + else: + print( + "IRI access token not printed " + "(use --print-token to display it for the NERSC IRI API)." + ) + + +if __name__ == "__main__": + main() From 6b8c843424bdfd4995416f5ef8cc32bb69a3a467 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Wed, 1 Apr 2026 16:21:14 -0700 Subject: [PATCH 08/42] removing token.py and moving the logic to get_globus_token.py --- orchestration/globus/token.py | 417 ---------------------------------- 1 file changed, 417 deletions(-) delete mode 100644 orchestration/globus/token.py diff --git a/orchestration/globus/token.py b/orchestration/globus/token.py deleted file mode 100644 index 4970eaa7..00000000 --- a/orchestration/globus/token.py +++ /dev/null @@ -1,417 +0,0 @@ -# orchestration/globus/token.py -import json -import logging -import os -from pathlib import Path -import stat -import time - -import globus_sdk -from globus_sdk.exc import GlobusAPIError - -logger = logging.getLogger(__name__) - -# Default token file location, matching the Globus SDK convention. -DEFAULT_TOKEN_FILE: Path = Path.home() / ".globus" / "auth_tokens.json" - -# IRI API Globus scope and resource server. -# The IRI access token lives in other_tokens under this scope, not at the -# top level of the auth.globus.org response. -IRI_SCOPE: str = ( - "https://auth.globus.org/scopes/" - "ed3e577d-f7f3-4639-b96e-ff5a8445d699/iri_api" -) -IRI_RESOURCE_SERVER: str = "ed3e577d-f7f3-4639-b96e-ff5a8445d699" - - -# --------------------------------------------------------------------------- -# File I/O -# --------------------------------------------------------------------------- - -def load_token_file(token_file: Path) -> dict | None: - """Load saved Globus token data from disk. - - Args: - token_file: Path to the JSON token file. - - Returns: - Parsed token dict, or None if the file does not exist. - """ - if not token_file.exists(): - return None - with token_file.open("r", encoding="utf-8") as f: - return json.load(f) - - -def save_token_file(token_file: Path, tokens: dict) -> None: - """Atomically save Globus token data to disk with owner-only permissions. - - Writes to a temporary file then renames to avoid partial writes. - - Args: - token_file: Destination path for the JSON token file. - tokens: Token dict to serialise. - """ - _ensure_private_parent_dir(token_file) - tmp = token_file.with_suffix(".tmp") - with os.fdopen( - os.open(tmp, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600), - "w", - encoding="utf-8", - ) as f: - json.dump(tokens, f, indent=2) - os.replace(tmp, token_file) - os.chmod(token_file, stat.S_IRUSR | stat.S_IWUSR) - - -def _ensure_private_parent_dir(path: Path) -> None: - """Create parent directories for path with owner-only permissions. - - Args: - path: The file path whose parent directory should be created. - """ - path.parent.mkdir(parents=True, exist_ok=True) - os.chmod(path.parent, 0o700) - - -# --------------------------------------------------------------------------- -# IRI token helpers -# --------------------------------------------------------------------------- - -def _parse_scope_string(scope_string: str) -> set[str]: - """Split a space-separated scope string into a set. - - Args: - scope_string: Space-separated OAuth2 scope string. - - Returns: - Set of individual scope strings. - """ - return set(scope_string.split()) if scope_string else set() - - -def extract_iri_token(token_response_data: dict) -> dict: - """Extract the IRI access token entry from a Globus token response. - - The IRI token is not returned at the top level — it lives inside - ``other_tokens``, identified by :data:`IRI_SCOPE`. - - Args: - token_response_data: Full token response dict as returned by the - Globus SDK (i.e. ``token_response.data``). - - Returns: - Token dict for the IRI resource server. - - Raises: - RuntimeError: If no token matching the IRI scope is found. - """ - for token_data in token_response_data.get("other_tokens", []): - if IRI_SCOPE in _parse_scope_string(token_data.get("scope", "")): - return token_data - raise RuntimeError( - f"Missing token for required IRI scope: {IRI_SCOPE}. " - "Re-run with --force-login and ensure consent is granted for the IRI scope." - ) - - -def _replace_iri_token(token_response_data: dict, iri_token_data: dict) -> dict: - """Return a copy of token_response_data with the IRI entry replaced. - - Args: - token_response_data: Full stored token response dict. - iri_token_data: Updated IRI token dict to splice in. - - Returns: - Updated token response dict. - """ - merged = dict(token_response_data) - other_tokens = list(merged.get("other_tokens", [])) - for i, token_data in enumerate(other_tokens): - if IRI_SCOPE in _parse_scope_string(token_data.get("scope", "")): - other_tokens[i] = iri_token_data - break - else: - other_tokens.append(iri_token_data) - merged["other_tokens"] = other_tokens - return merged - - -def _get_iri_refresh_token(stored_tokens: dict) -> str | None: - """Extract the IRI refresh token from stored token data, if present. - - Args: - stored_tokens: Full stored token response dict. - - Returns: - The IRI refresh token string, or None if absent. - """ - try: - return extract_iri_token(stored_tokens).get("refresh_token") - except RuntimeError: - return None - - -def _get_auth_refresh_token(stored_tokens: dict) -> str | None: - """Extract the top-level Globus Auth refresh token from stored data. - - Args: - stored_tokens: Full stored token response dict. - - Returns: - The auth refresh token string, or None if absent. - """ - if "refresh_token" in stored_tokens: - return stored_tokens["refresh_token"] - auth_tokens = stored_tokens.get("auth.globus.org") - if isinstance(auth_tokens, dict): - return auth_tokens.get("refresh_token") - return None - - -# --------------------------------------------------------------------------- -# NativeApp flow (interactive) -# --------------------------------------------------------------------------- - -def interactive_login( - client: globus_sdk.NativeAppAuthClient, - requested_scopes: frozenset[str], - prompt_login: bool = False, -) -> dict: - """Run an interactive browser-based Globus login flow. - - Prints an authorization URL, waits for the user to paste an auth code, - and returns the full token response data including ``other_tokens``. - - Args: - client: Globus NativeAppAuthClient to drive the flow. - requested_scopes: Set of OAuth2 scopes to request. Should include - :data:`IRI_SCOPE` to obtain an IRI API token. - prompt_login: If True, add ``prompt=login`` to the authorize URL to - force a fresh identity-provider login. - - Returns: - Full token response dict (``token_response.data``), including - ``other_tokens``. - - Raises: - RuntimeError: If no authorization code is entered, or if the code - exchange fails. - """ - client.oauth2_start_flow( - requested_scopes=" ".join(sorted(requested_scopes)), - refresh_tokens=True, - ) - logger.info("Open this URL in your browser to authenticate with Globus:") - prompt = "login" if prompt_login else globus_sdk.MISSING - logger.info(client.oauth2_get_authorize_url(prompt=prompt)) - code = input("\nEnter authorization code: ").strip() - if not code: - raise RuntimeError( - "No authorization code entered. Re-run the script and paste the " - "code shown by Globus after login." - ) - try: - token_response = client.oauth2_exchange_code_for_tokens(code) - except GlobusAPIError as e: - if e.http_status == 400: - raise RuntimeError( - "Authorization code exchange failed — the code was empty, " - "invalid, expired, or already used. Re-run and try again." - ) from e - raise RuntimeError( - f"Authorization code exchange failed with HTTP {e.http_status}." - ) from e - return token_response.data - - -def _refresh_single_token( - client: globus_sdk.NativeAppAuthClient, - refresh_token: str, -) -> dict | None: - """Attempt a single Globus token refresh, returning raw response data. - - Args: - client: NativeAppAuthClient to drive the refresh. - refresh_token: The stored refresh token. - - Returns: - Raw token response data dict, or None if the refresh failed. - """ - try: - token_response = client.oauth2_refresh_token(refresh_token) - return token_response.data - except GlobusAPIError as e: - logger.warning( - f"Globus token refresh failed ({e.http_status}); " - "will fall back to interactive login." - ) - return None - - -def _refresh_stored_tokens( - client: globus_sdk.NativeAppAuthClient, - stored_tokens: dict, -) -> tuple[dict | None, bool]: - """Try to refresh stored tokens, preferring the IRI refresh token. - - Attempts the IRI-specific refresh token first, then falls back to the - top-level Globus Auth refresh token. - - Args: - client: NativeAppAuthClient to drive the refresh. - stored_tokens: Full stored token response dict. - - Returns: - Tuple of ``(updated_token_data, success)``. On failure both values - are ``(None, False)``. - """ - iri_refresh = _get_iri_refresh_token(stored_tokens) - if iri_refresh: - iri_token_data = _refresh_single_token(client, iri_refresh) - if iri_token_data is not None: - return _replace_iri_token(stored_tokens, iri_token_data), True - - auth_refresh = _get_auth_refresh_token(stored_tokens) - if auth_refresh: - auth_data = _refresh_single_token(client, auth_refresh) - if auth_data is not None: - return auth_data, True - - return None, False - - -def get_access_token( - client_id: str, - requested_scopes: frozenset[str], - token_file: Path | None = None, - force_login: bool = False, - prompt_login: bool = False, -) -> str: - """Get a valid IRI API access token via the NativeApp interactive flow. - - Attempts a silent refresh from the saved token file first. Falls back to - interactive browser login if no saved tokens exist, the refresh token is - absent, or the refresh fails. Saves the resulting tokens back to disk. - - The IRI token is extracted from ``other_tokens`` in the response — it is - not the top-level Globus Auth token. - - Args: - client_id: Globus NativeApp client ID. - requested_scopes: Set of OAuth2 scopes to request. Must include - :data:`IRI_SCOPE` to obtain a usable IRI API token. - token_file: Path to the JSON token file. Defaults to - ``~/.globus/auth_tokens.json``. - force_login: If True, skip refresh and force interactive login. - prompt_login: If True, add ``prompt=login`` to the authorize URL. - - Returns: - A valid IRI API access token string. - - Raises: - RuntimeError: If the IRI scope token is missing from the response. - """ - resolved_token_file = token_file or DEFAULT_TOKEN_FILE - globus_client = globus_sdk.NativeAppAuthClient(client_id) - - token_response_data: dict | None = None - used_refresh = False - - if not force_login: - stored = load_token_file(resolved_token_file) - if stored: - token_response_data, used_refresh = _refresh_stored_tokens( - globus_client, stored - ) - - if token_response_data is None: - logger.info("Initiating interactive Globus login.") - token_response_data = interactive_login( - globus_client, requested_scopes, prompt_login=prompt_login - ) - - # Extract IRI token — if a refresh ran but didn't return the IRI token, - # fall back to interactive login before raising. - try: - iri_token = extract_iri_token(token_response_data) - except RuntimeError: - if used_refresh: - logger.warning( - "Refreshed tokens did not include the IRI token; " - "falling back to interactive login." - ) - token_response_data = interactive_login( - globus_client, requested_scopes, prompt_login=prompt_login - ) - iri_token = extract_iri_token(token_response_data) - else: - raise - - save_token_file(resolved_token_file, token_response_data) - logger.info(f"Globus token saved to {resolved_token_file}.") - - return iri_token["access_token"] - - -# --------------------------------------------------------------------------- -# Confidential Client flow (machine-to-machine) -# --------------------------------------------------------------------------- - -def get_access_token_confidential( - client_id: str, - client_secret: str, - required_scopes: frozenset[str], - resource_server: str, - token_file: Path | None = None, -) -> str: - """Get a valid Globus access token using a Confidential Client. - - No browser or user interaction required. If a valid unexpired token exists - on disk it is reused; otherwise a new one is minted via the client - credentials grant and saved. - - Args: - client_id: Globus Confidential App client ID. - client_secret: Globus Confidential App client secret. - required_scopes: Set of OAuth2 scopes that must be present on the token. - resource_server: Resource server key to extract from the token response. - token_file: Path to the JSON token cache file. Defaults to - ``~/.globus/auth_tokens.json``. - - Returns: - A valid Globus access token string. - - Raises: - RuntimeError: If the acquired token is missing required scopes. - KeyError: If ``access_token`` is absent from the token response. - """ - resolved_token_file = token_file or DEFAULT_TOKEN_FILE - - stored = load_token_file(resolved_token_file) - if stored: - expires_at = stored.get("expires_at_seconds") - if expires_at and time.time() < expires_at: - logger.info("Using cached Globus token (still valid).") - return stored["access_token"] - logger.info("Cached Globus token is expired; minting a new one.") - else: - logger.info("No cached Globus token found; minting a new one.") - - globus_client = globus_sdk.ConfidentialAppAuthClient(client_id, client_secret) - token_response = globus_client.oauth2_client_credentials_tokens( - requested_scopes=" ".join(sorted(required_scopes)) - ) - auth_data = token_response.by_resource_server[resource_server] - - granted = set(auth_data.get("scope", "").split()) - missing = required_scopes - granted - if missing: - raise RuntimeError( - f"Globus token is missing required scopes: {sorted(missing)}" - ) - - save_token_file(resolved_token_file, auth_data) - logger.info(f"New Globus token saved to {resolved_token_file}.") - - return auth_data["access_token"] From 27ea5b2f0a5cc97e9d901507cbc8e35f1283f2f1 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Wed, 1 Apr 2026 16:21:55 -0700 Subject: [PATCH 09/42] moving get_globus_token.py to orchestration/globus/ to be used as a module --- .../globus}/get_globus_token.py | 47 +++++++++++++++++++ 1 file changed, 47 insertions(+) rename {scripts => orchestration/globus}/get_globus_token.py (84%) diff --git a/scripts/get_globus_token.py b/orchestration/globus/get_globus_token.py similarity index 84% rename from scripts/get_globus_token.py rename to orchestration/globus/get_globus_token.py index 6b615378..c47057e8 100644 --- a/scripts/get_globus_token.py +++ b/orchestration/globus/get_globus_token.py @@ -11,6 +11,7 @@ import globus_sdk from globus_sdk.exc import GlobusAPIError +DEFAULT_TOKEN_FILE: Path = Path.home() / ".globus" / "auth_tokens.json" CLIENT_ID = "fae5c579-490a-4d76-b6eb-d78f65caeb63" RESOURCE_SERVER = "auth.globus.org" IRI_SCOPE = ( @@ -264,6 +265,52 @@ def refresh_stored_tokens( return None, False +def get_iri_access_token( + token_file: Path = DEFAULT_TOKEN_FILE, + force_login: bool = False, + prompt_login: bool = False, +) -> str: + """ + Get a valid IRI access token, refreshing or prompting for login as needed. + Tokens are saved to the specified token_file path (default: ~/.globus/auth_tokens.json). + By default, the function will attempt to refresh saved tokens before falling back + to interactive login. Use force_login=True to skip refresh and require interactive login. + Use prompt_login=True to add prompt=login to the authorization URL, which forces + re-authentication even if the user has an active Globus session in their browser. + + Args: + token_file: Path to save and load token data (default: ~/.globus/auth_tokens.json) + force_login: If True, skip token refresh and require interactive login + prompt_login: If True, add prompt=login to the authorization URL to force re-authentication + + Returns: + A valid IRI access token string with the required scopes. + + Raises: + RuntimeError: If token refresh fails and interactive login is not allowed or fails, + or if the resulting tokens do not include a valid IRI access token. + """ + client = globus_sdk.NativeAppAuthClient(CLIENT_ID) + auth_data = None + used_refresh = False + if not force_login: + stored = load_tokens(token_file) + if stored: + auth_data, used_refresh = refresh_stored_tokens(client, stored) + if auth_data is None: + auth_data = interactive_login(client, prompt_login=prompt_login) + try: + iri_token_data = validate_auth_data(auth_data) + except RuntimeError as exc: + if used_refresh and "Missing token for required IRI scope" in str(exc): + auth_data = interactive_login(client, prompt_login=prompt_login) + iri_token_data = validate_auth_data(auth_data) + else: + raise + save_tokens(token_file, auth_data) + return iri_token_data["access_token"] + + def main() -> None: args = parse_args() if args.force_login and args.refresh_only: From bad1db503eed89a669ffc075d9e89a3bf0cbbf9c Mon Sep 17 00:00:00 2001 From: David Abramov Date: Wed, 1 Apr 2026 16:22:56 -0700 Subject: [PATCH 10/42] Cleaning up nersc.py --- orchestration/flows/bl832/nersc.py | 115 ++++++++++++++--------------- 1 file changed, 55 insertions(+), 60 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 2aad35de..629fac1b 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -23,10 +23,9 @@ NerscStreamingMixin, SlurmJobBlock, cancellation_hook, monitor_streaming_job, save_block ) from orchestration.mlflow import get_checkpoint_info -from orchestration.globus.token import ( - get_access_token, +from orchestration.globus.get_globus_token import ( + get_iri_access_token, DEFAULT_TOKEN_FILE, - IRI_SCOPE, ) from orchestration.prefect import schedule_prefect_flow from orchestration.prune_controller import get_prune_controller, PruneMethod @@ -39,17 +38,7 @@ # Applies only to NERSCLoginMethod.IRIAPI _IRIAPI_GLOBUS_CLIENT_ID_ENV: str = "GLOBUS_CLIENT_ID" _IRI_COMPUTE_RESOURCE: str = "compute" -_IRI_SCRATCH_RESOURCE: str = "scratch" -# _IRIAPI_GLOBUS_CLIENT_SECRET_ENV: str = "GLOBUS_CLIENT_SECRET" # set → confidential client _IRIAPI_TOKEN_FILE_ENV: str = "PATH_GLOBUS_TOKEN_FILE" -_IRIAPI_GLOBUS_RESOURCE_SERVER: str = "auth.globus.org" -_IRIAPI_GLOBUS_REQUIRED_SCOPES: frozenset[str] = frozenset({ - "openid", - "profile", - "email", - "urn:globus:auth:scope:auth.globus.org:view_identities", - IRI_SCOPE, -}) _API_BASE_URLS: dict[NERSCLoginMethod, str] = { NERSCLoginMethod.SFAPI: "https://api.nersc.gov/api/v1.2", @@ -263,26 +252,19 @@ def _create_iriapi_client() -> Client: RuntimeError: If the acquired token is missing required scopes. """ client_id = "fae5c579-490a-4d76-b6eb-d78f65caeb63" # os.getenv(_IRIAPI_GLOBUS_CLIENT_ID_ENV) - # client_secret = os.getenv(_IRIAPI_GLOBUS_CLIENT_SECRET_ENV) if not client_id: raise ValueError( f"Globus client ID is unset. Set {_IRIAPI_GLOBUS_CLIENT_ID_ENV}." ) - # if not client_secret: - # raise ValueError( - # f"Globus client secret is unset. Set {_IRIAPI_GLOBUS_CLIENT_SECRET_ENV}. " - # "A Globus Confidential App client is required for automated IRI API auth." - # ) token_file_env = os.getenv(_IRIAPI_TOKEN_FILE_ENV) token_file = Path(token_file_env) if token_file_env else DEFAULT_TOKEN_FILE - access_token = get_access_token( - client_id=client_id, - requested_scopes=_IRIAPI_GLOBUS_REQUIRED_SCOPES, + access_token = get_iri_access_token( token_file=token_file, force_login=False, + prompt_login=False ) return httpx.Client( @@ -832,43 +814,6 @@ def build_multi_resolution( except Exception as e: logger.error(f"Error during multiresolution job submission or completion: {e}") return False - # try: - # logger.info("Submitting Tiff to Zarr job script to Perlmutter.") - # perlmutter = self.client.compute(Machine.perlmutter) - # job = perlmutter.submit_job(job_script) - # logger.info(f"Submitted job ID: {job.jobid}") - - # try: - # job.update() - # except Exception as update_err: - # logger.warning(f"Initial job update failed, continuing: {update_err}") - - # time.sleep(60) - # logger.info(f"Job {job.jobid} current state: {job.state}") - - # job.complete() # Wait until the job completes - # logger.info("Reconstruction job completed successfully.") - - # return True - - # except Exception as e: - # logger.warning(f"Error during job submission or completion: {e}") - # match = re.search(r"Job not found:\s*(\d+)", str(e)) - - # if match: - # jobid = match.group(1) - # logger.info(f"Attempting to recover job {jobid}.") - # try: - # job = self.client.perlmutter.job(jobid=jobid) - # time.sleep(30) - # job.complete() - # logger.info("Reconstruction job completed successfully after recovery.") - # return True - # except Exception as recovery_err: - # logger.error(f"Failed to recover job {jobid}: {recovery_err}") - # return False - # else: - # return False def segmentation_sam3( self, @@ -1847,7 +1792,8 @@ def nersc_recon_flow( logger.info(f"Starting NERSC reconstruction flow for {file_path=}") controller = get_controller( hpc_type=HPC.NERSC, - config=config + config=config, + login_method=NERSCLoginMethod.SFAPI ) logger.info("NERSC reconstruction controller initialized") @@ -2433,6 +2379,55 @@ def nersc_moon_segment_flow( return False +@flow(name="nersc_recon_test_iriapi_flow", flow_run_name="nersc_recon-{file_path}") +def nersc_recon_test_iriapi_flow( + file_path: str, + config: Optional[Config832] = None, +) -> bool: + """ + Perform tomography reconstruction on NERSC. + + :param file_path: Path to the file to reconstruct. + :param config: Configuration object (if None, a default Config832 will be created). + :return: True if successful, False otherwise. + """ + logger.info(f"Starting NERSC reconstruction flow for {file_path=}") + controller = get_controller( + hpc_type=HPC.NERSC, + config=config, + login_method=NERSCLoginMethod.IRIAPI + ) + logger.info("NERSC reconstruction controller initialized") + + nersc_reconstruction_success = controller.reconstruct( + file_path=file_path, + ) + logger.info(f"NERSC reconstruction success: {nersc_reconstruction_success}") + + nersc_multi_res_success = controller.build_multi_resolution( + file_path=file_path, + ) + logger.info(f"NERSC multi-resolution success: {nersc_multi_res_success}") + + path = Path(file_path) + folder_name = path.parent.name + file_name = path.stem + + tiff_file_path = f"{folder_name}/rec{file_name}" + zarr_file_path = f"{folder_name}/rec{file_name}.zarr" + + logger.info(f"{tiff_file_path=}") + logger.info(f"{zarr_file_path=}") + + # Transfers and pruning omitted from test flow. + + # TODO: Ingest into SciCat + if nersc_reconstruction_success: + return True + else: + return False + + @flow(name="nersc_streaming_flow", on_cancellation=[cancellation_hook]) def nersc_streaming_flow( walltime: datetime.timedelta = datetime.timedelta(minutes=5), From da163416f2ed1236be56393210ed5f88d8e6dbb5 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Wed, 1 Apr 2026 16:28:37 -0700 Subject: [PATCH 11/42] cleaning up old commented code --- orchestration/flows/bl832/nersc.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 629fac1b..e1d49fa9 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -473,7 +473,6 @@ def reconstruct( """ logger.info("Starting NERSC reconstruction process.") - # user = self.client.user() username = self._get_nersc_username() raw_path = self.config.nersc832_alsdev_raw.root_path @@ -737,7 +736,6 @@ def build_multi_resolution( logger.info("Starting NERSC multiresolution process.") - # user = self.client.user() username = self._get_nersc_username() multires_image = self.config.ghcr_images832["multires_image"] From d1d65ad162516a39544bf47cc5bef1a0af0d638d Mon Sep 17 00:00:00 2001 From: David Abramov Date: Wed, 1 Apr 2026 16:34:53 -0700 Subject: [PATCH 12/42] Updating unit tests --- orchestration/_tests/test_bl832/test_nersc.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/orchestration/_tests/test_bl832/test_nersc.py b/orchestration/_tests/test_bl832/test_nersc.py index 7a8ca07a..5b0f1a3d 100644 --- a/orchestration/_tests/test_bl832/test_nersc.py +++ b/orchestration/_tests/test_bl832/test_nersc.py @@ -181,20 +181,20 @@ def mock_iriapi_client(mocker): client = mocker.MagicMock() submit_response = mocker.MagicMock() - submit_response.json.return_value = {"job_id": "99999"} + submit_response.json.return_value = {"id": "99999"} client.post.return_value = submit_response status_response = mocker.MagicMock() - status_response.json.return_value = {"state": "COMPLETED"} + status_response.json.return_value = {"status": {"state": "completed"}} client.get.return_value = status_response return client - # --------------------------------------------------------------------------- # _create_sfapi_client # --------------------------------------------------------------------------- + def test_create_sfapi_client_success(mocker): """Valid credentials produce a Client instance.""" from orchestration.flows.bl832.nersc import NERSCTomographyHPCController @@ -472,7 +472,7 @@ def test_reconstruct_iriapi_job_failed(mocker, mock_iriapi_client, mock_config83 monkeypatch.setenv("NERSC_USERNAME", "alsdev") mocker.patch("orchestration.flows.bl832.nersc.time.sleep") - mock_iriapi_client.get.return_value.json.return_value = {"state": "FAILED"} + mock_iriapi_client.get.return_value.json.return_value = {"status": {"state": "failed"}} # was {"state": "FAILED"} controller = NERSCTomographyHPCController( client=mock_iriapi_client, @@ -970,17 +970,17 @@ def test_build_multi_resolution_iriapi_success(mocker, mock_iriapi_client, mock_ assert result is True mock_iriapi_client.post.assert_called_once() mock_iriapi_client.get.assert_called_once_with( - "/api/v1/compute/status/perlmutter/99999" + "/api/v1/compute/status/compute/99999" ) def test_build_multi_resolution_iriapi_failure(mocker, mock_iriapi_client, mock_config, monkeypatch): - """IRIAPI build_multi_resolution returns False when job state is FAILED.""" + """IRIAPI build_multi_resolution returns False when job state is failed.""" from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod monkeypatch.setenv("NERSC_USERNAME", "alsdev") mocker.patch("orchestration.flows.bl832.nersc.time.sleep") - mock_iriapi_client.get.return_value.json.return_value = {"state": "FAILED"} + mock_iriapi_client.get.return_value.json.return_value = {"status": {"state": "failed"}} controller = NERSCTomographyHPCController( client=mock_iriapi_client, From dda78c534c35636eb40c8db8f12434ff1fa13e19 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Tue, 7 Apr 2026 11:20:19 -0700 Subject: [PATCH 13/42] updating login script --- scripts/login_to_globus_and_prefect.sh | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/scripts/login_to_globus_and_prefect.sh b/scripts/login_to_globus_and_prefect.sh index a38b629f..b8f60bde 100755 --- a/scripts/login_to_globus_and_prefect.sh +++ b/scripts/login_to_globus_and_prefect.sh @@ -18,4 +18,7 @@ export GLOBUS_CLI_CLIENT_SECRET="$GLOBUS_CLIENT_SECRET" export GLOBUS_COMPUTE_CLIENT_ID="$GLOBUS_CLIENT_ID" export GLOBUS_COMPUTE_CLIENT_SECRET="$GLOBUS_CLIENT_SECRET" export PREFECT_API_KEY="$PREFECT_API_KEY" -export PREFECT_API_URL="$PREFECT_API_URL" \ No newline at end of file +export PREFECT_API_URL="$PREFECT_API_URL" +export NERSC_USERNAME="$NERSC_USERNAME" +export PATH_NERSC_CLIENT_ID="$PATH_NERSC_CLIENT_ID" +export PATH_NERSC_PRI_KEY="$PATH_NERSC_PRI_KEY" \ No newline at end of file From 596106aad41d18ebffdef03cf5db6acc80801a42 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Tue, 7 Apr 2026 14:29:40 -0700 Subject: [PATCH 14/42] Rebasing and including segmentation flows as part of iri/sfapi abstraction --- orchestration/_tests/test_bl832/test_nersc.py | 39 ++- orchestration/flows/bl832/nersc.py | 286 +++++++----------- 2 files changed, 130 insertions(+), 195 deletions(-) diff --git a/orchestration/_tests/test_bl832/test_nersc.py b/orchestration/_tests/test_bl832/test_nersc.py index 5b0f1a3d..3ba9742f 100644 --- a/orchestration/_tests/test_bl832/test_nersc.py +++ b/orchestration/_tests/test_bl832/test_nersc.py @@ -23,16 +23,6 @@ def prefect_test_fixture(): # Shared fixtures # --------------------------------------------------------------------------- -@pytest.fixture -def mock_config(mocker): - config = mocker.MagicMock() - config.ghcr_images832 = { - "recon_image": "mock_recon_image", - "multires_image": "mock_multires_image", - } - return config - - @pytest.fixture def mock_sfapi_client(mocker): """sfapi_client.Client mock with user, compute, submit_job, and job chained.""" @@ -257,9 +247,10 @@ def test_build_multi_resolution_success(mocker, mock_sfapi_client, mock_config83 result = controller.build_multi_resolution(file_path="folder/file.h5") - mock_sfapi_client.compute.assert_called_once_with(Machine.perlmutter) + assert mock_sfapi_client.compute.call_count == 2 + mock_sfapi_client.compute.assert_called_with(Machine.perlmutter) mock_sfapi_client.compute.return_value.submit_job.assert_called_once() - mock_sfapi_client.compute.return_value.submit_job.return_value.complete.assert_called_once() + mock_sfapi_client.compute.return_value.job.return_value.complete.assert_called_once() assert result is True @@ -292,7 +283,7 @@ def test_segmentation_sam3_success(mocker, mock_sfapi_client, mock_config832): mock_sfapi_client.compute.assert_called_with(Machine.perlmutter) mock_sfapi_client.compute.return_value.submit_job.assert_called_once() - mock_sfapi_client.compute.return_value.submit_job.return_value.complete.assert_called_once() + mock_sfapi_client.compute.return_value.job.return_value.complete.assert_called_once() assert isinstance(result, dict) assert result["success"] is True assert result["job_id"] == "12345" @@ -370,7 +361,7 @@ def test_segmentation_dinov3_success(mocker, mock_sfapi_client, mock_config832): mock_sfapi_client.compute.assert_called_with(Machine.perlmutter) mock_sfapi_client.compute.return_value.submit_job.assert_called_once() - mock_sfapi_client.compute.return_value.submit_job.return_value.complete.assert_called_once() + mock_sfapi_client.compute.return_value.job.return_value.complete.assert_called_once() assert result is True @@ -454,11 +445,15 @@ def test_reconstruct_iriapi_success(mocker, mock_iriapi_client, mock_config832, assert result["success"] is True assert result["job_id"] == "99999" mock_iriapi_client.post.assert_called_once() - assert mock_iriapi_client.post.call_args.args[0] == "/api/v1/compute/job/perlmutter" - assert "script" in mock_iriapi_client.post.call_args.kwargs["json"] + assert mock_iriapi_client.post.call_args.args[0] == "/api/v1/compute/job/compute" + posted_json = mock_iriapi_client.post.call_args.kwargs["json"] + assert posted_json["executable"] == "/bin/bash" + assert posted_json["arguments"][0] == "-c" + assert isinstance(posted_json["arguments"][1], str) # the script body + assert "tomo_recon" in posted_json["arguments"][1] # sanity check it's the right script assert mock_iriapi_client.get.call_count == 2 mock_iriapi_client.get.assert_any_call( - "/api/v1/compute/status/perlmutter/99999" + "/api/v1/compute/status/compute/99999" ) mock_iriapi_client.get.assert_any_call( "/api/v1/filesystem/file/perlmutter", @@ -589,7 +584,7 @@ def test_combine_segmentations_success(mocker, mock_sfapi_client, mock_config832 mock_sfapi_client.compute.assert_called_with(Machine.perlmutter) mock_sfapi_client.compute.return_value.submit_job.assert_called_once() - mock_sfapi_client.compute.return_value.submit_job.return_value.complete.assert_called_once() + mock_sfapi_client.compute.return_value.job.return_value.complete.assert_called_once() assert result is True @@ -952,7 +947,7 @@ def test_moon_segment_flow_no_sam3_no_combine(mocker, mock_config832, mock_recon # --------------------------------------------------------------------------- -def test_build_multi_resolution_iriapi_success(mocker, mock_iriapi_client, mock_config, monkeypatch): +def test_build_multi_resolution_iriapi_success(mocker, mock_iriapi_client, mock_config832, monkeypatch): """IRIAPI build_multi_resolution POSTs and polls successfully.""" from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod @@ -961,7 +956,7 @@ def test_build_multi_resolution_iriapi_success(mocker, mock_iriapi_client, mock_ controller = NERSCTomographyHPCController( client=mock_iriapi_client, - config=mock_config, + config=mock_config832, login_method=NERSCLoginMethod.IRIAPI, ) @@ -974,7 +969,7 @@ def test_build_multi_resolution_iriapi_success(mocker, mock_iriapi_client, mock_ ) -def test_build_multi_resolution_iriapi_failure(mocker, mock_iriapi_client, mock_config, monkeypatch): +def test_build_multi_resolution_iriapi_failure(mocker, mock_iriapi_client, mock_config832, monkeypatch): """IRIAPI build_multi_resolution returns False when job state is failed.""" from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod @@ -984,7 +979,7 @@ def test_build_multi_resolution_iriapi_failure(mocker, mock_iriapi_client, mock_ controller = NERSCTomographyHPCController( client=mock_iriapi_client, - config=mock_config, + config=mock_config832, login_method=NERSCLoginMethod.IRIAPI, ) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index e1d49fa9..c2d5f031 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -459,6 +459,55 @@ def _wait_for_job(self, job_id: str) -> bool: else: raise ValueError(f"Unhandled NERSCLoginMethod: {self.login_method}") + def _mkdir_remote(self, path: str) -> None: + """Create a directory on Perlmutter remotely. + + Args: + path: Absolute path to create. + """ + if self.login_method is NERSCLoginMethod.SFAPI: + perlmutter = self.client.compute(Machine.perlmutter) + perlmutter.run(f"mkdir -p {path}") + elif self.login_method is NERSCLoginMethod.IRIAPI: + response = self.client.post( + "/api/v1/filesystem/mkdir/perlmutter", + json={"path": path, "parents": True}, + ) + response.raise_for_status() + else: + raise ValueError(f"Unhandled NERSCLoginMethod: {self.login_method}") + + def _read_remote_file(self, path: str) -> str: + """Read a remote file on Perlmutter and return its contents. + + Args: + path: Absolute path to the file on Perlmutter. + + Returns: + File contents as a string. + """ + if self.login_method is NERSCLoginMethod.SFAPI: + perlmutter = self.client.compute(Machine.perlmutter) + result = perlmutter.run(f"cat {path}") + if isinstance(result, str): + return result + elif hasattr(result, 'output'): + return result.output + elif hasattr(result, 'stdout'): + return result.stdout + return str(result) + + elif self.login_method is NERSCLoginMethod.IRIAPI: + response = self.client.get( + "/api/v1/filesystem/file/perlmutter", + params={"path": path}, + ) + response.raise_for_status() + return response.text + + else: + raise ValueError(f"Unhandled NERSCLoginMethod: {self.login_method}") + def reconstruct( self, file_path: str = "", @@ -660,27 +709,7 @@ def _fetch_timing_data(self, pscratch_path: str, job_id: str) -> dict: timing_file = f"{pscratch_path}/tomo_recon_logs/timing_{job_id}.txt" try: - # Use SFAPI to read the timing file - if self.login_method is NERSCLoginMethod.SFAPI: - perlmutter = self.client.compute(Machine.perlmutter) - result = perlmutter.run(f"cat {timing_file}") - - # result might be a string directly, or an object with .output - if isinstance(result, str): - output = result - elif hasattr(result, 'output'): - output = result.output - elif hasattr(result, 'stdout'): - output = result.stdout - else: - output = str(result) - elif self.login_method is NERSCLoginMethod.IRIAPI: - response = self.client.get( - "/api/v1/filesystem/file/perlmutter", - params={"path": timing_file}, - ) - response.raise_for_status() - output = response.text + output = self._read_remote_file(timing_file) logger.info(f"Timing file contents:\n{output}") @@ -823,8 +852,8 @@ def segmentation_sam3( """ logger.info("Starting NERSC segmentation process (inference_v6).") - user = self.client.user() - pscratch_path = f"/pscratch/sd/{user.name[0]}/{user.name}" + username = self._get_nersc_username() + pscratch_path = f"/pscratch/sd/{username[0]}/{username}" opts = _load_job_options( variable_name="nersc-segmentation-options", @@ -1014,34 +1043,21 @@ def segmentation_sam3( """ try: - logger.info("Submitting segmentation job to Perlmutter (v6).") - perlmutter = self.client.compute(Machine.perlmutter) + logger.info("Submitting segmentation job to Perlmutter.") # Ensure directories exist logger.info("Creating necessary directories...") - perlmutter.run(f"mkdir -p {pscratch_path}/tomo_seg_logs") - perlmutter.run(f"mkdir -p {output_dir}") + self._mkdir_remote(f"{pscratch_path}/tomo_seg_logs") + self._mkdir_remote(output_dir) # Submit job - job = perlmutter.submit_job(job_script) - logger.info(f"Submitted job ID: {job.jobid}") - - # Initial update - try: - job.update() - except Exception as update_err: - logger.warning(f"Initial job update failed, continuing: {update_err}") - - # Wait briefly before polling + job_id = self._submit_job(job_script) + logger.info(f"Submitted job ID: {job_id}") time.sleep(60) - logger.info(f"Job {job.jobid} current state: {job.state}") - - # Wait for completion - job.complete() + success = self._wait_for_job(job_id) logger.info("Segmentation job completed successfully.") - # Fetch timing data from output file - timing = self._fetch_seg_timing_from_output(perlmutter, pscratch_path, job.jobid, job_name) + timing = self._fetch_seg_timing_from_output(pscratch_path, job_id, job_name) if timing: logger.info("=" * 60) @@ -1055,43 +1071,21 @@ def segmentation_sam3( logger.info("=" * 60) return { - "success": True, - "job_id": job.jobid, + "success": success, + "job_id": job_id, "timing": timing, - "output_dir": output_dir + "output_dir": output_dir, } except Exception as e: logger.error(f"Error during segmentation job: {e}") import traceback logger.error(traceback.format_exc()) - - # Attempt recovery - match = re.search(r"Job not found:\s*(\d+)", str(e)) - if match: - jobid = match.group(1) - logger.info(f"Attempting to recover job {jobid}.") - try: - job = self.client.compute(Machine.perlmutter).job(jobid=jobid) - time.sleep(30) - job.complete() - logger.info("Segmentation job completed after recovery.") - - timing = self._fetch_seg_timing_from_output(perlmutter, pscratch_path, jobid, job_name) - return { - "success": True, - "job_id": jobid, - "timing": timing, - "output_dir": output_dir - } - except Exception as recovery_err: - logger.error(f"Failed to recover job {jobid}: {recovery_err}") - return { "success": False, "job_id": None, "timing": None, - "output_dir": None + "output_dir": None, } def segmentation_dinov3( @@ -1109,8 +1103,8 @@ def segmentation_dinov3( """ logger.info("Starting NERSC DINOv3 segmentation process.") - user = self.client.user() - pscratch_path = f"/pscratch/sd/{user.name[0]}/{user.name}" + username = self._get_nersc_username() + pscratch_path = f"/pscratch/sd/{username[0]}/{username}" # Load from config spec = self._get_segmentation_spec("dinov3", project) @@ -1245,39 +1239,15 @@ def segmentation_dinov3( """ try: logger.info("Submitting DINOv3 segmentation job to Perlmutter.") - perlmutter = self.client.compute(Machine.perlmutter) - job = perlmutter.submit_job(job_script) - logger.info(f"Submitted job ID: {job.jobid}") - - try: - job.update() - except Exception as update_err: - logger.warning(f"Initial job update failed, continuing: {update_err}") - + job_id = self._submit_job(job_script) + logger.info(f"Submitted job ID: {job_id}") time.sleep(60) - logger.info(f"Job {job.jobid} current state: {job.state}") - - job.complete() - logger.info("DINOv3 segmentation job completed successfully.") - return True - + success = self._wait_for_job(job_id) + logger.info(f"DINOv3 segmentation job {'completed successfully' if success else 'failed'}.") + return success except Exception as e: logger.error(f"Error during DINOv3 segmentation job submission or completion: {e}") - match = re.search(r"Job not found:\s*(\d+)", str(e)) - if match: - jobid = match.group(1) - logger.info(f"Attempting to recover job {jobid}.") - try: - job = self.client.compute(Machine.perlmutter).job(jobid=jobid) - time.sleep(30) - job.complete() - logger.info("DINOv3 segmentation job completed successfully after recovery.") - return True - except Exception as recovery_err: - logger.error(f"Failed to recover job {jobid}: {recovery_err}") - return False - else: - return False + return False def combine_segmentations( self, @@ -1285,7 +1255,7 @@ def combine_segmentations( ) -> bool: """ Run CPU-based combination of SAM3+DINOv3 segmentation results - at NERSC Perlmutter via SFAPI Slurm job. + at NERSC Perlmutter via Slurm job. :param recon_folder_path: Relative path to the reconstructed data folder, e.g. 'folder_name/recYYYYMMDD_hhmmss_scanname/' @@ -1293,8 +1263,8 @@ def combine_segmentations( """ logger.info("Starting NERSC segmentation combination process.") - user = self.client.user() - pscratch_path = f"/pscratch/sd/{user.name[0]}/{user.name}" + username = self._get_nersc_username() + pscratch_path = f"/pscratch/sd/{username[0]}/{username}" opts = _load_job_options( "nersc-combine-seg-options", self.config.nersc_combine_segmentation_settings @@ -1393,45 +1363,20 @@ def combine_segmentations( """ try: logger.info("Submitting segmentation combination job to Perlmutter.") - perlmutter = self.client.compute(Machine.perlmutter) - job = perlmutter.submit_job(job_script) - logger.info(f"Submitted job ID: {job.jobid}") - - try: - job.update() - except Exception as update_err: - logger.warning(f"Initial job update failed, continuing: {update_err}") - + job_id = self._submit_job(job_script) + logger.info(f"Submitted job ID: {job_id}") time.sleep(60) - logger.info(f"Job {job.jobid} current state: {job.state}") - - job.complete() - logger.info("Segmentation combination job completed successfully.") - return True - + success = self._wait_for_job(job_id) + logger.info(f"Segmentation combination job {'completed successfully' if success else 'failed'}.") + return success except Exception as e: logger.error(f"Error during segmentation combination job submission or completion: {e}") - match = re.search(r"Job not found:\s*(\d+)", str(e)) - if match: - jobid = match.group(1) - logger.info(f"Attempting to recover job {jobid}.") - try: - job = self.client.compute(Machine.perlmutter).job(jobid=jobid) - time.sleep(30) - job.complete() - logger.info("Segmentation combination job completed successfully after recovery.") - return True - except Exception as recovery_err: - logger.error(f"Failed to recover job {jobid}: {recovery_err}") - return False - else: - return False + return False - def _fetch_seg_timing_from_output(self, perlmutter, pscratch_path: str, job_id: str, job_name: str) -> dict: + def _fetch_seg_timing_from_output(self, pscratch_path: str, job_id: str, job_name: str) -> dict: """ Fetch and parse timing data from the SLURM output file. - :param perlmutter: SFAPI compute object for Perlmutter :param pscratch_path: Path to the user's pscratch directory :param job_id: SLURM job ID :param job_name: Job name for finding output file @@ -1440,18 +1385,7 @@ def _fetch_seg_timing_from_output(self, perlmutter, pscratch_path: str, job_id: output_file = f"{pscratch_path}/tomo_seg_logs/{job_name}_{job_id}.out" try: - # Use SFAPI to read the output file - result = perlmutter.run(f"cat {output_file}") - - # Handle different result types - if isinstance(result, str): - output = result - elif hasattr(result, 'output'): - output = result.output - elif hasattr(result, 'stdout'): - output = result.stdout - else: - output = str(result) + output = self._read_remote_file(output_file) logger.info("Job output file contents (last 50 lines):") lines = output.strip().split('\n') @@ -1528,8 +1462,8 @@ def pull_shifter_image( """ logger.info("Starting Shifter image pull.") - user = self.client.user() - pscratch_path = f"/pscratch/sd/{user.name[0]}/{user.name}" + username = self._get_nersc_username() + pscratch_path = f"/pscratch/sd/{username[0]}/{username}" if image is None: image = self.config.ghcr_images832["recon_image"] @@ -1576,24 +1510,16 @@ def pull_shifter_image( try: logger.info("Submitting Shifter image pull job to Perlmutter.") - perlmutter = self.client.compute(Machine.perlmutter) - job = perlmutter.submit_job(job_script) - logger.info(f"Submitted job ID: {job.jobid}") + job_id = self._submit_job(job_script) + logger.info(f"Submitted job ID: {job_id}") if wait: - try: - job.update() - except Exception as update_err: - logger.warning(f"Initial job update failed, continuing: {update_err}") - time.sleep(30) - logger.info(f"Job {job.jobid} current state: {job.state}") - - job.complete() - logger.info("Shifter image pull completed successfully.") - return True + success = self._wait_for_job(job_id) + logger.info(f"Shifter image pull {'completed successfully' if success else 'failed'}.") + return success else: - logger.info(f"Job submitted. Check status with job ID: {job.jobid}") + logger.info(f"Job submitted. Check status with job ID: {job_id}") return True except Exception as e: @@ -1616,17 +1542,31 @@ def check_shifter_image( image = self.config.ghcr_images832["recon_image"] try: - perlmutter = self.client.compute(Machine.perlmutter) - # Run shifterimg images command - result = perlmutter.run(f"shifterimg images | grep -E \"$(echo {image} | sed 's/:/.*/g')\"") + if self.login_method is NERSCLoginMethod.SFAPI: + # synchronous via utilities/command + perlmutter = self.client.compute(Machine.perlmutter) + result = perlmutter.run(f"shifterimg images | grep -E \"$(echo {image} | sed 's/:/.*/g')\"") + output = result if isinstance(result, str) else getattr(result, 'output', str(result)) - if isinstance(result, str): - output = result - elif hasattr(result, 'output'): - output = result.output - else: - output = str(result) + elif self.login_method is NERSCLoginMethod.IRIAPI: + # async: submit job → wait → read stdout file + username = self._get_nersc_username() + pscratch_path = f"/pscratch/sd/{username[0]}/{username}" + output_file = f"{pscratch_path}/tomo_recon_logs/shifter_check.txt" + check_script = f"""#!/bin/bash + #SBATCH -q debug + #SBATCH -A als + #SBATCH -C cpu + #SBATCH -N 1 + #SBATCH --ntasks=1 + #SBATCH --cpus-per-task=1 + #SBATCH --time=0:05:00 + shifterimg images | grep -E "$(echo {image} | sed 's/:/.*/g')" > {output_file} 2>&1 || true + """ + job_id = self._submit_job(check_script) + self._wait_for_job(job_id) + output = self._read_remote_file(output_file) if output.strip(): logger.info(f"Image found in Shifter cache: {output.strip()}") From 9da5e6e30a2d9fe5fc3a816be4d4e92b0e95480f Mon Sep 17 00:00:00 2001 From: David Abramov Date: Mon, 13 Apr 2026 13:59:39 -0700 Subject: [PATCH 15/42] commenting out petiole segmentation prune block for now, while testing --- orchestration/flows/bl832/nersc.py | 122 ++++++++++++++--------------- 1 file changed, 61 insertions(+), 61 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index c2d5f031..13345343 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -2047,67 +2047,67 @@ def nersc_petiole_segment_flow( ) # ── STEP 6: Pruning ─────────────────────────────────────────────────────── - logger.info("Scheduling file pruning tasks.") - prune_controller = get_prune_controller(prune_type=PruneMethod.GLOBUS, config=config) - - try: - prune_controller.prune( - file_path=f"{folder_name}/{path.name}", - source_endpoint=config.nersc832_alsdev_pscratch_raw, - check_endpoint=None, - days_from_now=1.0 - ) - except Exception as e: - logger.warning(f"Failed to schedule raw data pruning: {e}") - - if nersc_reconstruction_success: - try: - prune_controller.prune( - file_path=scratch_path_tiff, - source_endpoint=config.nersc832_alsdev_pscratch_scratch, - check_endpoint=config.data832_scratch if data832_tiff_transfer_success else None, - days_from_now=1.0 - ) - except Exception as e: - logger.warning(f"Failed to schedule reconstruction data pruning: {e}") - - if any_seg_success: - try: - prune_controller.prune( - file_path=scratch_path_segment, - source_endpoint=config.nersc832_alsdev_pscratch_scratch, - check_endpoint=config.data832_scratch if any([ - data832_sam3_transfer_success, - data832_dinov3_transfer_success, - ]) else None, - days_from_now=1.0 - ) - except Exception as e: - logger.warning(f"Failed to schedule segmentation data pruning: {e}") - - if data832_tiff_transfer_success: - try: - prune_controller.prune( - file_path=scratch_path_tiff, - source_endpoint=config.data832_scratch, - check_endpoint=None, - days_from_now=30.0 - ) - except Exception as e: - logger.warning(f"Failed to schedule data832 tiff pruning: {e}") - - if any([data832_sam3_transfer_success, - data832_dinov3_transfer_success, - data832_combined_transfer_success]): - try: - prune_controller.prune( - file_path=scratch_path_segment, - source_endpoint=config.data832_scratch, - check_endpoint=None, - days_from_now=30.0 - ) - except Exception as e: - logger.warning(f"Failed to schedule data832 segment pruning: {e}") + # logger.info("Scheduling file pruning tasks.") + # prune_controller = get_prune_controller(prune_type=PruneMethod.GLOBUS, config=config) + + # try: + # prune_controller.prune( + # file_path=f"{folder_name}/{path.name}", + # source_endpoint=config.nersc832_alsdev_pscratch_raw, + # check_endpoint=None, + # days_from_now=1.0 + # ) + # except Exception as e: + # logger.warning(f"Failed to schedule raw data pruning: {e}") + + # if nersc_reconstruction_success: + # try: + # prune_controller.prune( + # file_path=scratch_path_tiff, + # source_endpoint=config.nersc832_alsdev_pscratch_scratch, + # check_endpoint=config.data832_scratch if data832_tiff_transfer_success else None, + # days_from_now=1.0 + # ) + # except Exception as e: + # logger.warning(f"Failed to schedule reconstruction data pruning: {e}") + + # if any_seg_success: + # try: + # prune_controller.prune( + # file_path=scratch_path_segment, + # source_endpoint=config.nersc832_alsdev_pscratch_scratch, + # check_endpoint=config.data832_scratch if any([ + # data832_sam3_transfer_success, + # data832_dinov3_transfer_success, + # ]) else None, + # days_from_now=1.0 + # ) + # except Exception as e: + # logger.warning(f"Failed to schedule segmentation data pruning: {e}") + + # if data832_tiff_transfer_success: + # try: + # prune_controller.prune( + # file_path=scratch_path_tiff, + # source_endpoint=config.data832_scratch, + # check_endpoint=None, + # days_from_now=30.0 + # ) + # except Exception as e: + # logger.warning(f"Failed to schedule data832 tiff pruning: {e}") + + # if any([data832_sam3_transfer_success, + # data832_dinov3_transfer_success, + # data832_combined_transfer_success]): + # try: + # prune_controller.prune( + # file_path=scratch_path_segment, + # source_endpoint=config.data832_scratch, + # check_endpoint=None, + # days_from_now=30.0 + # ) + # except Exception as e: + # logger.warning(f"Failed to schedule data832 segment pruning: {e}") if nersc_reconstruction_success and any_seg_success: logger.info("NERSC reconstruction + multi-segmentation flow completed successfully.") From ef227af6598ae1042fd8119e2f462f1e6519ecfc Mon Sep 17 00:00:00 2001 From: David Abramov Date: Mon, 13 Apr 2026 14:43:44 -0700 Subject: [PATCH 16/42] Making reconstruction run as a task --- orchestration/flows/bl832/nersc.py | 30 +++++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 13345343..85e85bed 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -1873,7 +1873,6 @@ def nersc_petiole_segment_flow( logger.info(f"Reconstructed TIFFs will be at: {scratch_path_tiff}") logger.info(f"Segmented output will be at: {scratch_path_segment}") - controller = get_controller(hpc_type=HPC.NERSC, config=config) logger.info("NERSC controller initialized") if num_nodes is None: @@ -1894,9 +1893,10 @@ def nersc_petiole_segment_flow( # ── STEP 1: Multinode Reconstruction ───────────────────────────────────── logger.info(f"Using multi-node reconstruction with {num_nodes} nodes") - recon_result = controller.reconstruct( + recon_result = nersc_reconstruction_task( file_path=file_path, - num_nodes=num_nodes + num_nodes=num_nodes, + config=config, ) if isinstance(recon_result, dict): @@ -2427,6 +2427,30 @@ def pull_shifter_image_flow( return success +@task(name="nersc_reconstruction_task") +def nersc_reconstruction_task( + file_path: str, + num_nodes: int = 4, + config: Optional[Config832] = None, +) -> dict: + """ + Run tomography reconstruction at NERSC Perlmutter. + + :param file_path: Path to the raw HDF5 file to reconstruct. + :param num_nodes: Number of nodes to use for reconstruction. + :param config: Configuration object for the flow. + :return: Dict with keys 'success', 'job_id', 'timing'. + """ + logger = get_run_logger() + if config is None: + config = Config832() + + logger.info("Initializing NERSC Tomography HPC Controller.") + controller = get_controller(hpc_type=HPC.NERSC, config=config) + logger.info(f"Starting NERSC reconstruction task for {file_path=}") + return controller.reconstruct(file_path=file_path, num_nodes=num_nodes) + + @task(name="nersc_multiresolution_task") def nersc_multiresolution_task( file_path: str, From b4558bef7389ca6eaab99524961932e6ac2e882c Mon Sep 17 00:00:00 2001 From: David Abramov Date: Mon, 13 Apr 2026 14:50:03 -0700 Subject: [PATCH 17/42] Making IRIAPI the default login method for now --- orchestration/flows/bl832/nersc.py | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 85e85bed..29062308 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -1847,6 +1847,7 @@ def nersc_petiole_segment_flow( file_path: str, config: Optional[Config832] = None, num_nodes: Optional[int] = None, + login_method: Optional[NERSCLoginMethod] = NERSCLoginMethod.IRIAPI ) -> bool: """ Transfer raw data to NERSC, run reconstruction, then run SAM3 and DINOv3 @@ -1897,6 +1898,7 @@ def nersc_petiole_segment_flow( file_path=file_path, num_nodes=num_nodes, config=config, + login_method=login_method ) if isinstance(recon_result, dict): @@ -1950,10 +1952,10 @@ def nersc_petiole_segment_flow( logger.info("Submitting SAM3 and DINOv3 segmentation tasks concurrently.") sam3_future = nersc_segmentation_sam3_task.submit( - recon_folder_path=scratch_path_tiff, config=config + recon_folder_path=scratch_path_tiff, config=config, login_method=login_method ) dinov3_future = nersc_segmentation_dinov3_task.submit( - recon_folder_path=scratch_path_tiff, config=config, project="petiole" + recon_folder_path=scratch_path_tiff, config=config, project="petiole", login_method=login_method ) # ── STEP 4: Transfer each model's output as it completes ───────────────── @@ -1999,7 +2001,7 @@ def nersc_petiole_segment_flow( logger.info("Running segmentation combination.") combine_future = nersc_combine_segmentations_task.submit( - recon_folder_path=scratch_path_tiff, config=config + recon_folder_path=scratch_path_tiff, config=config, login_method=login_method ) combine_success = combine_future.result() @@ -2432,6 +2434,7 @@ def nersc_reconstruction_task( file_path: str, num_nodes: int = 4, config: Optional[Config832] = None, + login_method: Optional[NERSCLoginMethod] = NERSCLoginMethod.IRIAPI ) -> dict: """ Run tomography reconstruction at NERSC Perlmutter. @@ -2446,7 +2449,7 @@ def nersc_reconstruction_task( config = Config832() logger.info("Initializing NERSC Tomography HPC Controller.") - controller = get_controller(hpc_type=HPC.NERSC, config=config) + controller = get_controller(hpc_type=HPC.NERSC, config=config, login_method=login_method) logger.info(f"Starting NERSC reconstruction task for {file_path=}") return controller.reconstruct(file_path=file_path, num_nodes=num_nodes) @@ -2455,6 +2458,7 @@ def nersc_reconstruction_task( def nersc_multiresolution_task( file_path: str, config: Optional[Config832] = None, + login_method: Optional[NERSCLoginMethod] = NERSCLoginMethod.IRIAPI ) -> bool: """ Run multiresolution task at NERSC. @@ -2472,7 +2476,8 @@ def nersc_multiresolution_task( logger.info("Initializing NERSC Tomography HPC Controller.") tomography_controller = get_controller( hpc_type=HPC.NERSC, - config=config + config=config, + login_method=login_method ) logger.info(f"Starting NERSC multiresolution task for {file_path=}") nersc_multiresolution_success = tomography_controller.build_multi_resolution( @@ -2507,6 +2512,7 @@ def nersc_multiresolution_integration_test() -> bool: def nersc_segmentation_sam3_task( recon_folder_path: str, config: Optional[Config832] = None, + login_method: Optional[NERSCLoginMethod] = NERSCLoginMethod.IRIAPI ) -> bool: """ Run segmentation task at NERSC. @@ -2524,7 +2530,8 @@ def nersc_segmentation_sam3_task( logger.info("Initializing NERSC Tomography HPC Controller.") tomography_controller = get_controller( hpc_type=HPC.NERSC, - config=config + config=config, + login_method=login_method ) logger.info(f"Starting NERSC segmentation task for {recon_folder_path=}") nersc_segmentation_success = tomography_controller.segmentation_sam3( @@ -2545,12 +2552,13 @@ def nersc_segmentation_dinov3_task( recon_folder_path: str, config: Optional[Config832] = None, project: Optional[str] = "petiole", + login_method: Optional[NERSCLoginMethod] = NERSCLoginMethod.IRIAPI ) -> bool: logger = get_run_logger() if config is None: logger.info("No config provided, using default Config832.") config = Config832() - tomography_controller = get_controller(hpc_type=HPC.NERSC, config=config) + tomography_controller = get_controller(hpc_type=HPC.NERSC, config=config, login_method=login_method) logger.info(f"Starting NERSC DINOv3 segmentation task for {recon_folder_path=}, {project=}") success = tomography_controller.segmentation_dinov3(recon_folder_path=recon_folder_path, project=project) if not success: @@ -2564,12 +2572,13 @@ def nersc_segmentation_dinov3_task( def nersc_combine_segmentations_task( recon_folder_path: str, config: Optional[Config832] = None, + login_method: Optional[NERSCLoginMethod] = NERSCLoginMethod.IRIAPI ) -> bool: logger = get_run_logger() if config is None: logger.info("No config provided, using default Config832.") config = Config832() - tomography_controller = get_controller(hpc_type=HPC.NERSC, config=config) + tomography_controller = get_controller(hpc_type=HPC.NERSC, config=config, login_method=login_method) logger.info(f"Starting NERSC combine segmentations task for {recon_folder_path=}") success = tomography_controller.combine_segmentations(recon_folder_path=recon_folder_path) if not success: @@ -2591,7 +2600,8 @@ def nersc_segmentation_sam3_integration_test() -> bool: recon_folder_path = 'synaps-i/rec20211222_125057_petiole4' # 'test' # flow_success = nersc_segmentation_sam3_task( recon_folder_path=recon_folder_path, - config=Config832() + config=Config832(), + login_method=NERSCLoginMethod.IRIAPI ) logger.info(f"Flow success: {flow_success}") return flow_success From 241c889f91ffdb0743f0fecedfc8ac7ca413423d Mon Sep 17 00:00:00 2001 From: David Abramov Date: Mon, 13 Apr 2026 19:09:44 -0700 Subject: [PATCH 18/42] adjusting queue name and account --- orchestration/flows/bl832/nersc.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 29062308..32e370ca 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -365,7 +365,7 @@ def _get_nersc_username(self) -> str: ) return username - def _submit_job(self, job_script: str) -> str: + def _submit_job(self, job_script: str, num_nodes: int = 1) -> str: """Submit a Slurm job script and return the job ID. Dispatches to the appropriate submission mechanism based on @@ -373,6 +373,7 @@ def _submit_job(self, job_script: str) -> str: Args: job_script: The full Slurm batch script to submit. + num_nodes: The number of nodes to request for the job. Returns: The submitted job ID as a string. @@ -400,15 +401,15 @@ def _submit_job(self, job_script: str) -> str: "stdout_path": f"{pscratch_path}/tomo_recon_logs/iri_job.out", "stderr_path": f"{pscratch_path}/tomo_recon_logs/iri_job.err", "resources": { - "node_count": 1, + "node_count": num_nodes, "processes_per_node": 1, "cpu_cores_per_process": 64, "exclusive_node_use": True, }, "attributes": { "duration": 1800, - "queue_name": "realtime", - "account": "als", + "queue_name": "regular", # change to dynamic + "account": "dabramov", # change to dynamic "custom_attributes": {"constraint": "cpu"}, }, } From c9e7b14c330b2506caebd42faff03b00c70f053e Mon Sep 17 00:00:00 2001 From: David Abramov Date: Mon, 13 Apr 2026 19:21:45 -0700 Subject: [PATCH 19/42] Making the IRI job submission read sbatch settings --- orchestration/flows/bl832/nersc.py | 34 +++++++++++++++++++++--------- 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 32e370ca..01f66dd4 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -387,8 +387,23 @@ def _submit_job(self, job_script: str, num_nodes: int = 1) -> str: return str(job.jobid) elif self.login_method is NERSCLoginMethod.IRIAPI: - username = self._get_nersc_username() - pscratch_path = f"/pscratch/sd/{username[0]}/{username}" + # Parse SBATCH directives before stripping them + sbatch_values = {} + for line in job_script.splitlines(): + if line.startswith("#SBATCH"): + if "-q " in line: + sbatch_values["queue_name"] = line.split("-q ")[-1].strip() + elif "-A " in line: + sbatch_values["account"] = line.split("-A ")[-1].strip() + elif "--time=" in line: + t = line.split("--time=")[-1].strip() + # convert HH:MM:SS to seconds + parts = t.split(":") + sbatch_values["duration"] = int(parts[0])*3600 + int(parts[1])*60 + int(parts[2]) + elif "-N " in line: + sbatch_values["node_count"] = int(line.split("-N ")[-1].strip()) + elif "-C " in line: + sbatch_values["constraint"] = line.split("-C ")[-1].strip() script_body = "\n".join( line for line in job_script.splitlines() @@ -398,22 +413,21 @@ def _submit_job(self, job_script: str, num_nodes: int = 1) -> str: job_spec = { "executable": "/bin/bash", "arguments": ["-c", script_body], - "stdout_path": f"{pscratch_path}/tomo_recon_logs/iri_job.out", - "stderr_path": f"{pscratch_path}/tomo_recon_logs/iri_job.err", "resources": { - "node_count": num_nodes, + "node_count": sbatch_values.get("node_count", 1), "processes_per_node": 1, "cpu_cores_per_process": 64, "exclusive_node_use": True, }, "attributes": { - "duration": 1800, - "queue_name": "regular", # change to dynamic - "account": "dabramov", # change to dynamic - "custom_attributes": {"constraint": "cpu"}, + "duration": sbatch_values.get("duration", 1800), + "queue_name": sbatch_values.get("queue_name", "realtime"), + "account": sbatch_values.get("account", "als"), + "custom_attributes": { + "constraint": sbatch_values.get("constraint", "cpu") + }, }, } - response = self.client.post( f"/api/v1/compute/job/{_IRI_COMPUTE_RESOURCE}", json=job_spec, From 698d243cd74bd7754e6e2e23d993ce48c490281a Mon Sep 17 00:00:00 2001 From: David Abramov Date: Tue, 14 Apr 2026 14:39:58 -0700 Subject: [PATCH 20/42] Switching to debug queue/2 nodes for the IRI demo --- config.yml | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/config.yml b/config.yml index ea576b54..2d6d2ab3 100644 --- a/config.yml +++ b/config.yml @@ -178,15 +178,15 @@ hpc_submission_settings832: # ── RECON + MULTIRES SETTINGS ─────────────────────────────────────────────── nersc_reconstruction: # ── SLURM resource allocation ───────────────────────────────────────────── - qos: realtime + qos: debug account: als - reservation: "_CAP_TOMO_MOON_CPU" - num_nodes: 16 + reservation: "" + num_nodes: 2 cpus-per-task: 128 walltime: "0:30:00" nersc_multiresolution: # ── SLURM resource allocation ───────────────────────────────────────────── - qos: realtime + qos: debug account: als reservation: "" cpus-per-task: 128 @@ -195,15 +195,15 @@ hpc_submission_settings832: # ── PETIOLE SEGMENTATION SETTINGS ─────────────────────────────────────────── nersc_segmentation_sam3: # ── SLURM resource allocation ───────────────────────────────────────────── - qos: regular + qos: debug account: als constraint: gpu reservation: "" - num_nodes: 4 + num_nodes: 2 ntasks-per-node: 1 gpus-per-node: 4 cpus-per-task: 128 - walltime: "00:59:00" + walltime: "00:30:00" # ── Inference parameters ────────────────────────────────────────────────── script_name: "src/inference_v6.py" batch_size: 1 @@ -226,16 +226,16 @@ hpc_submission_settings832: finetuned_checkpoint_path: /global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/sam3_finetune/sam3/checkpoint_v6.pt nersc_segmentation_dinov3: # ── SLURM resource allocation ───────────────────────────────────────────── - qos: regular + qos: debug account: als constraint: gpu reservation: "" - num_nodes: 4 + num_nodes: 2 ntasks-per-node: 1 nproc_per_node: 4 gpus-per-node: 4 cpus-per-task: 128 - walltime: "00:59:00" + walltime: "00:30:00" # ── Inference parameters ────────────────────────────────────────────────── script_name: "src.inference_dino_v1" batch_size: 4 @@ -246,14 +246,14 @@ hpc_submission_settings832: dino_checkpoint_path: /global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/dino/best.ckpt nersc_combine_segmentations: # ── SLURM resource allocation ───────────────────────────────────────────── - qos: regular + qos: debug account: als constraint: cpu reservation: "" - num_nodes: 4 + num_nodes: 2 ntasks: 128 cpus-per-task: 1 - walltime: "01:00:00" + walltime: "00:30:00" # ── Combination parameters ──────────────────────────────────────────────── script_name: "src.combine_sam_dino_v3" dilate_px: 5 From 6e01f8fcb71b3aab2de26019d390ae221d688feb Mon Sep 17 00:00:00 2001 From: David Abramov Date: Tue, 14 Apr 2026 14:40:59 -0700 Subject: [PATCH 21/42] check globus token expiration before minting a new one. avoids race condition when submitting concurrent jobs --- orchestration/globus/get_globus_token.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/orchestration/globus/get_globus_token.py b/orchestration/globus/get_globus_token.py index c47057e8..f740a034 100644 --- a/orchestration/globus/get_globus_token.py +++ b/orchestration/globus/get_globus_token.py @@ -291,6 +291,19 @@ def get_iri_access_token( or if the resulting tokens do not include a valid IRI access token. """ client = globus_sdk.NativeAppAuthClient(CLIENT_ID) + + # Fast path: if token exists and is not expired, return it directly without refreshing or saving + if not force_login: + stored = load_tokens(token_file) + if stored: + try: + iri_token = get_iri_token(stored) + expires_at = iri_token.get("expires_at_seconds", 0) + if expires_at and time.time() < expires_at - 60: # 60s buffer + return iri_token["access_token"] + except RuntimeError: + pass # fall through to refresh + auth_data = None used_refresh = False if not force_login: From f4388e8581aeb93dd2f8d7170672e0db060f42b5 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Tue, 14 Apr 2026 14:41:33 -0700 Subject: [PATCH 22/42] Fixing IRIAPI bugs, also commenting out Globus transfers for now --- orchestration/flows/bl832/nersc.py | 219 ++++++++++++++++++++--------- 1 file changed, 152 insertions(+), 67 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 01f66dd4..5ea74f9b 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -45,6 +45,24 @@ NERSCLoginMethod.IRIAPI: "https://api.iri.nersc.gov", } +# NERSC resource IDs (from status/resources endpoint) +RESOURCE_IDS = { + # Perlmutter compute + "perlmutter_compute": "94351904-6dba-4c16-b5cd-fbd280d8615b", + "perlmutter_login": "e525a224-61c1-419f-9642-91168c792e39", + "perlmutter_realtime": "3776417d-747c-4753-895a-6323c17b9c98", + "perlmutter_job_submit": "3cf3c048-855e-4dd8-a189-065a483954bb", + # Storage + "scratch": "43d8f6c0-f900-48ce-b267-73714103f4ac", + "homes": "65b28619-c3b6-4942-8da1-044a3b3a2a9e", + "common": "7e07a611-f927-4a39-a44d-b1d6e307accd", + "cfs": "59e80c79-4dfd-4c53-9c07-7405685fcd37", + "archive": "f4916c65-9001-49c2-b0bf-6fe4276b564c", + # Services + "globus": "0a207df3-4bec-45b8-9060-13505d269da9", + "dtns": "a762cbdc-af7a-4b2b-9463-67f0189dd2ae", +} + @dataclass class SegmentationModelSpec: @@ -270,7 +288,7 @@ def _create_iriapi_client() -> Client: return httpx.Client( base_url=_API_BASE_URLS[NERSCLoginMethod.IRIAPI], headers={"Authorization": f"Bearer {access_token}"}, - timeout=httpx.Timeout(connect=10.0, read=60.0, write=10.0, pool=10.0), + timeout=httpx.Timeout(connect=10.0, read=120.0, write=30.0, pool=10.0), ) @staticmethod @@ -387,7 +405,6 @@ def _submit_job(self, job_script: str, num_nodes: int = 1) -> str: return str(job.jobid) elif self.login_method is NERSCLoginMethod.IRIAPI: - # Parse SBATCH directives before stripping them sbatch_values = {} for line in job_script.splitlines(): if line.startswith("#SBATCH"): @@ -397,44 +414,79 @@ def _submit_job(self, job_script: str, num_nodes: int = 1) -> str: sbatch_values["account"] = line.split("-A ")[-1].strip() elif "--time=" in line: t = line.split("--time=")[-1].strip() - # convert HH:MM:SS to seconds parts = t.split(":") sbatch_values["duration"] = int(parts[0])*3600 + int(parts[1])*60 + int(parts[2]) elif "-N " in line: sbatch_values["node_count"] = int(line.split("-N ")[-1].strip()) elif "-C " in line: sbatch_values["constraint"] = line.split("-C ")[-1].strip() + elif "--output=" in line: + sbatch_values["stdout_path"] = line.split("--output=")[-1].strip() + elif "--error=" in line: + sbatch_values["stderr_path"] = line.split("--error=")[-1].strip() + # Strip shebang and SBATCH headers, keep the script body script_body = "\n".join( line for line in job_script.splitlines() if not line.startswith("#SBATCH") and not line.startswith("#!/") ).strip() + constraint = sbatch_values.get("constraint", "cpu") + is_gpu = "gpu" in constraint.lower() + + resources = { + "node_count": sbatch_values.get("node_count", 1), + "processes_per_node": 1, + "exclusive_node_use": True, + } + if is_gpu: + resources["gpu_cores_per_process"] = 4 + else: + resources["cpu_cores_per_process"] = 128 + job_spec = { "executable": "/bin/bash", - "arguments": ["-c", script_body], - "resources": { - "node_count": sbatch_values.get("node_count", 1), - "processes_per_node": 1, - "cpu_cores_per_process": 64, - "exclusive_node_use": True, - }, + "arguments": ["-s"], # read script from stdin isn't supported, so... + "pre_launch": script_body, # run the body here before the executable + "resources": resources, + # { + # "node_count": sbatch_values.get("node_count", 1), + # "processes_per_node": 1, + # "cpu_cores_per_process": 64, + # "exclusive_node_use": True, + # }, "attributes": { "duration": sbatch_values.get("duration", 1800), - "queue_name": sbatch_values.get("queue_name", "realtime"), + "queue_name": sbatch_values.get("queue_name", "regular"), "account": sbatch_values.get("account", "als"), "custom_attributes": { - "constraint": sbatch_values.get("constraint", "cpu") + "constraint": constraint # sbatch_values.get("constraint", "cpu") }, }, } + + if "stdout_path" in sbatch_values: + job_spec["stdout_path"] = sbatch_values["stdout_path"] + if "stderr_path" in sbatch_values: + job_spec["stderr_path"] = sbatch_values["stderr_path"] + response = self.client.post( - f"/api/v1/compute/job/{_IRI_COMPUTE_RESOURCE}", + "/api/v1/compute/job/3cf3c048-855e-4dd8-a189-065a483954bb", json=job_spec, ) + if not response.is_success: + logger.error(f"Job submission failed: {response.status_code} {response.text}") + logger.error(f"Job spec was: {json.dumps(job_spec, indent=2)}") response.raise_for_status() return str(response.json()["id"]) + # response = self.client.post( + # "/api/v1/compute/job/3cf3c048-855e-4dd8-a189-065a483954bb", + # json=job_spec, + # ) + # response.raise_for_status() + # return str(response.json()["id"]) + else: raise ValueError(f"Unhandled NERSCLoginMethod: {self.login_method}") @@ -485,7 +537,7 @@ def _mkdir_remote(self, path: str) -> None: perlmutter.run(f"mkdir -p {path}") elif self.login_method is NERSCLoginMethod.IRIAPI: response = self.client.post( - "/api/v1/filesystem/mkdir/perlmutter", + f"/api/v1/filesystem/mkdir/{RESOURCE_IDS["perlmutter_login"]}", json={"path": path, "parents": True}, ) response.raise_for_status() @@ -514,11 +566,32 @@ def _read_remote_file(self, path: str) -> str: elif self.login_method is NERSCLoginMethod.IRIAPI: response = self.client.get( - "/api/v1/filesystem/file/perlmutter", + f"/api/v1/filesystem/view/{RESOURCE_IDS['perlmutter_login']}", params={"path": path}, ) response.raise_for_status() - return response.text + task_id = response.json().get("task_id") + if not task_id: + return response.text + + for _ in range(40): + task_response = self.client.get(f"/api/v1/task/{task_id}") + task_response.raise_for_status() + task = task_response.json() + status = task.get("status") + if status == "completed": + result = task.get("result", "") + if isinstance(result, dict): + output = result.get("output", result) + if isinstance(output, dict): + return output.get("content", str(output)) + return str(output) + return str(result) + elif status == "failed": + raise RuntimeError(f"File read task {task_id} failed: {task.get('result')}") + time.sleep(3) + + raise TimeoutError(f"File read task {task_id} did not complete") else: raise ValueError(f"Unhandled NERSCLoginMethod: {self.login_method}") @@ -567,6 +640,8 @@ def reconstruct( opts = _load_job_options("nersc-reconstruction-options", self.config.nersc_recon_settings) + logger.info(f"Resolved options: {opts}") + num_nodes = opts.get("num_nodes", num_nodes) cpus_per_task = opts["cpus-per-task"] qos = opts["qos"] @@ -631,6 +706,7 @@ def reconstruct( echo "METADATA_START=$(date +%s)" >> $TIMING_FILE NUM_SLICES=$(shifter \ + --image={recon_image} \ --volume={pscratch_path}/8.3.2:/alsdata \ python -c " import h5py @@ -675,6 +751,7 @@ def reconstruct( fi srun --nodes=1 --ntasks=1 --exclusive shifter \ + --image={recon_image} \ --env=NUMEXPR_MAX_THREADS=128 \ --env=NUMEXPR_NUM_THREADS=128 \ --env=OMP_NUM_THREADS=128 \ @@ -933,7 +1010,7 @@ def segmentation_sam3( #SBATCH -A {account} {reservation_line} #SBATCH -N {num_nodes} -#SBATCH -C {constraint} # gpu +#SBATCH -C {constraint} #SBATCH --job-name={job_name} #SBATCH --time={walltime} #SBATCH --ntasks-per-node={ntasks_per_node} @@ -1913,7 +1990,6 @@ def nersc_petiole_segment_flow( file_path=file_path, num_nodes=num_nodes, config=config, - login_method=login_method ) if isinstance(recon_result, dict): @@ -1950,24 +2026,24 @@ def nersc_petiole_segment_flow( logger.info("Reconstruction Successful.") # ── STEP 2: Transfer TIFFs to data832 ──────────────────────────────────── - logger.info("Transferring reconstructed TIFFs from NERSC pscratch to data832") - try: - data832_tiff_future = globus_transfer_task.submit( - file_path=scratch_path_tiff, - source=config.nersc832_alsdev_pscratch_scratch, - destination=config.data832_scratch, - config=config, - ) - logger.info("TIFF transfer to data832 submitted.") - except Exception as e: - logger.error(f"Failed to transfer TIFFs to data832: {e}") - data832_tiff_transfer_success = False + # logger.info("Transferring reconstructed TIFFs from NERSC pscratch to data832") + # try: + # data832_tiff_future = globus_transfer_task.submit( + # file_path=scratch_path_tiff, + # source=config.nersc832_alsdev_pscratch_scratch, + # destination=config.data832_scratch, + # config=config, + # ) + # logger.info("TIFF transfer to data832 submitted.") + # except Exception as e: + # logger.error(f"Failed to transfer TIFFs to data832: {e}") + # data832_tiff_transfer_success = False # ── STEP 3: SAM3 / DINOv3 ────────────────────────── logger.info("Submitting SAM3 and DINOv3 segmentation tasks concurrently.") sam3_future = nersc_segmentation_sam3_task.submit( - recon_folder_path=scratch_path_tiff, config=config, login_method=login_method + recon_folder_path=scratch_path_tiff, config=config ) dinov3_future = nersc_segmentation_dinov3_task.submit( recon_folder_path=scratch_path_tiff, config=config, project="petiole", login_method=login_method @@ -1979,15 +2055,17 @@ def nersc_petiole_segment_flow( logger.info(f"SAM3 segmentation result: {sam3_success}") if sam3_success: logger.info("Transferring SAM3 segmentation outputs to data832") - sam3_segment_path = f"{folder_name}/seg{file_name}/sam3" + # sam3_segment_path = f"{folder_name}/seg{file_name}/sam3" try: - data832_sam3_future = globus_transfer_task.submit( - file_path=sam3_segment_path, - source=config.nersc832_alsdev_pscratch_scratch, - destination=config.data832_scratch, - config=config, - ) - logger.info("SAM3 transfer to data832 submitted") + # data832_sam3_future = globus_transfer_task.submit( + # file_path=sam3_segment_path, + # source=config.nersc832_alsdev_pscratch_scratch, + # destination=config.data832_scratch, + # config=config, + # ) + # logger.info("SAM3 transfer to data832 submitted") + data832_sam3_transfer_success = True + logger.info(f"SAM3 transfer to data832 success: {data832_sam3_transfer_success}") except Exception as e: logger.error(f"Failed to transfer SAM3 outputs to data832: {e}") @@ -1995,15 +2073,17 @@ def nersc_petiole_segment_flow( logger.info(f"DINOv3 segmentation result: {dinov3_success}") if dinov3_success: logger.info("Transferring DINOv3 segmentation outputs to data832") - dinov3_segment_path = f"{folder_name}/seg{file_name}/dino" + # dinov3_segment_path = f"{folder_name}/seg{file_name}/dino" try: - data832_dinov3_future = globus_transfer_task.submit( - file_path=dinov3_segment_path, - source=config.nersc832_alsdev_pscratch_scratch, - destination=config.data832_scratch, - config=config, - ) - logger.info("DINOv3 transfer to data832 submitted") + # data832_dinov3_future = globus_transfer_task.submit( + # file_path=dinov3_segment_path, + # source=config.nersc832_alsdev_pscratch_scratch, + # destination=config.data832_scratch, + # config=config, + # ) + # logger.info("DINOv3 transfer to data832 submitted") + data832_dinov3_transfer_success = True + logger.info(f"DINOv3 transfer to data832 success: {data832_dinov3_transfer_success}") except Exception as e: logger.error(f"Failed to transfer DINOv3 outputs to data832: {e}") @@ -2016,22 +2096,24 @@ def nersc_petiole_segment_flow( logger.info("Running segmentation combination.") combine_future = nersc_combine_segmentations_task.submit( - recon_folder_path=scratch_path_tiff, config=config, login_method=login_method + recon_folder_path=scratch_path_tiff, config=config ) combine_success = combine_future.result() logger.info(f"Combination result: {combine_success}") if combine_success: logger.info("Transferring combined segmentation outputs to data832") - combined_segment_path = f"{folder_name}/seg{file_name}/combined/sam_dino" + # combined_segment_path = f"{folder_name}/seg{file_name}/combined/sam_dino" try: - data832_combined_future = globus_transfer_task.submit( - file_path=combined_segment_path, - source=config.nersc832_alsdev_pscratch_scratch, - destination=config.data832_scratch, - config=config, - ) - logger.info("Combined transfer to data832 submitted") + # data832_combined_future = globus_transfer_task.submit( + # file_path=combined_segment_path, + # source=config.nersc832_alsdev_pscratch_scratch, + # destination=config.data832_scratch, + # config=config, + # ) + # logger.info("Combined transfer to data832 submitted") + data832_combined_transfer_success = True + logger.info(f"Combined transfer to data832 success: {data832_combined_transfer_success}") except Exception as e: logger.error(f"Failed to transfer combined outputs to data832: {e}") @@ -2527,7 +2609,6 @@ def nersc_multiresolution_integration_test() -> bool: def nersc_segmentation_sam3_task( recon_folder_path: str, config: Optional[Config832] = None, - login_method: Optional[NERSCLoginMethod] = NERSCLoginMethod.IRIAPI ) -> bool: """ Run segmentation task at NERSC. @@ -2546,7 +2627,7 @@ def nersc_segmentation_sam3_task( tomography_controller = get_controller( hpc_type=HPC.NERSC, config=config, - login_method=login_method + login_method=NERSCLoginMethod.IRIAPI ) logger.info(f"Starting NERSC segmentation task for {recon_folder_path=}") nersc_segmentation_success = tomography_controller.segmentation_sam3( @@ -2573,7 +2654,7 @@ def nersc_segmentation_dinov3_task( if config is None: logger.info("No config provided, using default Config832.") config = Config832() - tomography_controller = get_controller(hpc_type=HPC.NERSC, config=config, login_method=login_method) + tomography_controller = get_controller(hpc_type=HPC.NERSC, config=config, login_method=NERSCLoginMethod.IRIAPI) logger.info(f"Starting NERSC DINOv3 segmentation task for {recon_folder_path=}, {project=}") success = tomography_controller.segmentation_dinov3(recon_folder_path=recon_folder_path, project=project) if not success: @@ -2587,13 +2668,12 @@ def nersc_segmentation_dinov3_task( def nersc_combine_segmentations_task( recon_folder_path: str, config: Optional[Config832] = None, - login_method: Optional[NERSCLoginMethod] = NERSCLoginMethod.IRIAPI ) -> bool: logger = get_run_logger() if config is None: logger.info("No config provided, using default Config832.") config = Config832() - tomography_controller = get_controller(hpc_type=HPC.NERSC, config=config, login_method=login_method) + tomography_controller = get_controller(hpc_type=HPC.NERSC, config=config, login_method=NERSCLoginMethod.IRIAPI) logger.info(f"Starting NERSC combine segmentations task for {recon_folder_path=}") success = tomography_controller.combine_segmentations(recon_folder_path=recon_folder_path) if not success: @@ -2622,9 +2702,14 @@ def nersc_segmentation_sam3_integration_test() -> bool: return flow_success -if __name__ == "__main__": - nersc_segmentation_dinov3_task( - recon_folder_path='dabramov/recmoon/', - config=Config832(), - project="moon" - ) +# if __name__ == "__main__": + # nersc_segmentation_dinov3_task( + # recon_folder_path='dabramov/recmoon/', + # config=Config832(), + # project="moon" + # ) + # nersc_petiole_segment_flow( + # file_path='dabramov/20260221_143000_petiole28', + # num_nodes=2, + # login_method=NERSCLoginMethod.IRIAPI + # ) From a490bfee39ac6f8344b33851a8d393f21dd3df1f Mon Sep 17 00:00:00 2001 From: David Abramov Date: Wed, 15 Apr 2026 16:03:38 -0700 Subject: [PATCH 23/42] removing IRIAPI client ID from nersc.py, since it is only used in globus/get_globus_token.py --- orchestration/flows/bl832/nersc.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 5ea74f9b..808a2a4c 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -269,13 +269,6 @@ def _create_iriapi_client() -> Client: ValueError: If ``GLOBUS_CLIENT_ID`` or ``GLOBUS_CLIENT_SECRET`` are unset. RuntimeError: If the acquired token is missing required scopes. """ - client_id = "fae5c579-490a-4d76-b6eb-d78f65caeb63" # os.getenv(_IRIAPI_GLOBUS_CLIENT_ID_ENV) - - if not client_id: - raise ValueError( - f"Globus client ID is unset. Set {_IRIAPI_GLOBUS_CLIENT_ID_ENV}." - ) - token_file_env = os.getenv(_IRIAPI_TOKEN_FILE_ENV) token_file = Path(token_file_env) if token_file_env else DEFAULT_TOKEN_FILE From 041f33688536035b6858a4d62959f20499400099 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Thu, 23 Apr 2026 11:04:53 -0700 Subject: [PATCH 24/42] Updating logger comments --- orchestration/flows/bl832/nersc.py | 42 ++++++++++-------------------- 1 file changed, 14 insertions(+), 28 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 808a2a4c..39bd4a02 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -209,7 +209,7 @@ def __init__( @staticmethod def create_nersc_client( login_method: NERSCLoginMethod = NERSCLoginMethod.SFAPI, - ) -> Client: + ) -> Client | httpx.Client: """Create and return a NERSC client for the requested login method. Two fundamentally different auth strategies are supported: @@ -255,7 +255,7 @@ def create_nersc_client( return client @staticmethod - def _create_iriapi_client() -> Client: + def _create_iriapi_client() -> httpx.Client: """Create a NERSC client for the IRI API using a Globus bearer token. Requires ``GLOBUS_CLIENT_ID`` and ``GLOBUS_CLIENT_SECRET`` in the @@ -263,7 +263,7 @@ def _create_iriapi_client() -> Client: via the client credentials grant. No browser or user interaction. Returns: - An authenticated :class:`sfapi_client.Client` targeting the IRI API. + An authenticated :class:`httpx.Client` targeting the IRI API. Raises: ValueError: If ``GLOBUS_CLIENT_ID`` or ``GLOBUS_CLIENT_SECRET`` are unset. @@ -442,18 +442,12 @@ def _submit_job(self, job_script: str, num_nodes: int = 1) -> str: "arguments": ["-s"], # read script from stdin isn't supported, so... "pre_launch": script_body, # run the body here before the executable "resources": resources, - # { - # "node_count": sbatch_values.get("node_count", 1), - # "processes_per_node": 1, - # "cpu_cores_per_process": 64, - # "exclusive_node_use": True, - # }, "attributes": { "duration": sbatch_values.get("duration", 1800), "queue_name": sbatch_values.get("queue_name", "regular"), "account": sbatch_values.get("account", "als"), "custom_attributes": { - "constraint": constraint # sbatch_values.get("constraint", "cpu") + "constraint": constraint }, }, } @@ -464,7 +458,7 @@ def _submit_job(self, job_script: str, num_nodes: int = 1) -> str: job_spec["stderr_path"] = sbatch_values["stderr_path"] response = self.client.post( - "/api/v1/compute/job/3cf3c048-855e-4dd8-a189-065a483954bb", + f"/api/v1/compute/job/{RESOURCE_IDS['perlmutter_job_submit']}", json=job_spec, ) if not response.is_success: @@ -473,13 +467,6 @@ def _submit_job(self, job_script: str, num_nodes: int = 1) -> str: response.raise_for_status() return str(response.json()["id"]) - # response = self.client.post( - # "/api/v1/compute/job/3cf3c048-855e-4dd8-a189-065a483954bb", - # json=job_spec, - # ) - # response.raise_for_status() - # return str(response.json()["id"]) - else: raise ValueError(f"Unhandled NERSCLoginMethod: {self.login_method}") @@ -1640,15 +1627,15 @@ def check_shifter_image( pscratch_path = f"/pscratch/sd/{username[0]}/{username}" output_file = f"{pscratch_path}/tomo_recon_logs/shifter_check.txt" check_script = f"""#!/bin/bash - #SBATCH -q debug - #SBATCH -A als - #SBATCH -C cpu - #SBATCH -N 1 - #SBATCH --ntasks=1 - #SBATCH --cpus-per-task=1 - #SBATCH --time=0:05:00 - shifterimg images | grep -E "$(echo {image} | sed 's/:/.*/g')" > {output_file} 2>&1 || true - """ +#SBATCH -q debug +#SBATCH -A als +#SBATCH -C cpu +#SBATCH -N 1 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=1 +#SBATCH --time=0:05:00 +shifterimg images | grep -E "$(echo {image} | sed 's/:/.*/g')" > {output_file} 2>&1 || true +""" job_id = self._submit_job(check_script) self._wait_for_job(job_id) output = self._read_remote_file(output_file) @@ -2689,7 +2676,6 @@ def nersc_segmentation_sam3_integration_test() -> bool: flow_success = nersc_segmentation_sam3_task( recon_folder_path=recon_folder_path, config=Config832(), - login_method=NERSCLoginMethod.IRIAPI ) logger.info(f"Flow success: {flow_success}") return flow_success From 863b24e94dd2cab4e7ef5b92fdd78faa53fe261f Mon Sep 17 00:00:00 2001 From: David Abramov Date: Fri, 24 Apr 2026 12:37:31 -0700 Subject: [PATCH 25/42] connecting to AmSC MLflow service --- .env.example | 7 +- config.yml | 4 + orchestration/flows/bl832/config.py | 2 +- orchestration/flows/bl832/register_mlflow.py | 188 ++++++++++++++++--- orchestration/mlflow.py | 86 ++++++++- 5 files changed, 262 insertions(+), 25 deletions(-) diff --git a/.env.example b/.env.example index e3728e89..b7e54812 100644 --- a/.env.example +++ b/.env.example @@ -1,7 +1,12 @@ +BEAMLINE=8.3.2 GLOBUS_CLIENT_ID= GLOBUS_CLIENT_SECRET= PREFECT_API_URL= PREFECT_API_KEY= PUSHGATEWAY_URL= JOB_NAME= -INSTANCE_LABEL= \ No newline at end of file +INSTANCE_LABEL= +PATH_NERSC_CLIENT_ID= +PATH_NERSC_PRI_KEY= +NERSC_USERNAME= +AMSC_API_KEY= # found here: https://profile.american-science-cloud.org/ \ No newline at end of file diff --git a/config.yml b/config.yml index 2d6d2ab3..6ee1cbb7 100644 --- a/config.yml +++ b/config.yml @@ -173,6 +173,10 @@ mlflow: staging: tracking_uri: https://mlflow-staging.computing.als.lbl.gov registry_uri: https://mlflow-staging.computing.als.lbl.gov + amsc: + tracking_uri: https://mlflow.american-science-cloud.org/ + registry_uri: https://mlflow.american-science-cloud.org/ + experiment_name: als-bl832-models hpc_submission_settings832: # ── RECON + MULTIRES SETTINGS ─────────────────────────────────────────────── diff --git a/orchestration/flows/bl832/config.py b/orchestration/flows/bl832/config.py index 8bbbf78c..281ba167 100644 --- a/orchestration/flows/bl832/config.py +++ b/orchestration/flows/bl832/config.py @@ -30,7 +30,7 @@ def _beam_specific_config(self) -> None: # SciCat self.scicat = self.config["scicat"] # MLflow - self.mlflow = self.config["mlflow"]["local"] + self.mlflow = self.config["mlflow"]["amsc"] # NERSC HPC submission settings self.ghcr_images832 = self.config["ghcr_images832"] self.nersc_recon_settings = self.config["hpc_submission_settings832"]["nersc_reconstruction"] diff --git a/orchestration/flows/bl832/register_mlflow.py b/orchestration/flows/bl832/register_mlflow.py index 31fa3760..d540ea67 100644 --- a/orchestration/flows/bl832/register_mlflow.py +++ b/orchestration/flows/bl832/register_mlflow.py @@ -6,6 +6,7 @@ logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) +logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)-7s | %(name)s - %(message)s") def register_mlflow_checkpoints(): @@ -16,14 +17,18 @@ def register_mlflow_checkpoints(): register_checkpoint( model_name="sam3-petiole", nersc_path=f"{scripts_dir}sam3_finetune/sam3/checkpoint_v6.pt", - alcf_path="/eagle/IRIBeta/als/seg_models/sam3/checkpoint_v6.pt", + alcf_path="/eagle/SYNAPS-I/segmentation/sam3_finetune/sam3/checkpoint_v6.pt", config=config, alias="production", description="SAM3 v6 fine-tuned on petiole micro-CT data.", inference_params={ - # ── paths ────────────────────────────────────────────────────────── - "original_checkpoint_path": - f"{scripts_dir}sam3_finetune/sam3/sam3.pt", + # ── site-specific HF caches ───────────────────────────────────────── + "nersc_hf_home": "/global/cfs/cdirs/als/data_mover/8.3.2/.cache/huggingface", + "nersc_hf_hub_cache": "/global/cfs/cdirs/als/data_mover/8.3.2/.cache/huggingface/hub", + "alcf_hf_home": "/eagle/SYNAPS-I/segmentation/.cache/huggingface", + "alcf_hf_hub_cache": "/eagle/SYNAPS-I/segmentation/.cache/huggingface", + # ── paths ─────────────────────────────────────────────────────────── + "original_checkpoint_path": f"{scripts_dir}sam3_finetune/sam3/sam3.pt", "bpe_path": f"{scripts_dir}sam3_finetune/sam3/bpe_simple_vocab_16e6.txt.gz", "conda_env_path": "/global/cfs/cdirs/als/data_mover/8.3.2/envs/sam3-py311", "seg_scripts_dir": f"{scripts_dir}inference_latest/forge_feb_seg_model_demo/", @@ -32,9 +37,9 @@ def register_mlflow_checkpoints(): "script_name": "src/inference_v6.py", "batch_size": 1, "patch_size": 400, - "confidence": [0.5], # list → JSON-encoded automatically + "confidence": [0.5], "overlap": 0.25, - "prompts": [ # list → JSON-encoded automatically + "prompts": [ "Phloem Fibers", "Hydrated Xylem vessels", "Air-based Pith cells", @@ -46,12 +51,17 @@ def register_mlflow_checkpoints(): register_checkpoint( model_name="dinov3-petiole", nersc_path="/global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/dino/best.ckpt", - alcf_path="/eagle/IRIBeta/als/seg_models/dino/best.ckpt", + alcf_path="/eagle/SYNAPS-I/segmentation/dino/best.ckpt", config=config, alias="production", description="DINOv3 fine-tuned on petiole micro-CT data.", inference_params={ - # ── paths ────────────────────────────────────────────────────────── + # ── site-specific HF caches ───────────────────────────────────────── + "nersc_hf_home": "/global/cfs/cdirs/als/data_mover/8.3.2/.cache/huggingface", + "nersc_hf_hub_cache": "/global/cfs/cdirs/als/data_mover/8.3.2/.cache/huggingface/hub", + "alcf_hf_home": "/eagle/SYNAPS-I/segmentation/.cache/huggingface", + "alcf_hf_hub_cache": "/eagle/SYNAPS-I/segmentation/.cache/huggingface", + # ── paths ─────────────────────────────────────────────────────────── "conda_env_path": "/global/cfs/cdirs/als/data_mover/8.3.2/envs/dino_demo", "seg_scripts_dir": f"{scripts_dir}inference_v5_multiseg/forge_feb_seg_model_demo/", # ── inference hyperparameters ─────────────────────────────────────── @@ -64,19 +74,102 @@ def register_mlflow_checkpoints(): register_checkpoint( model_name="dinov3-moon", nersc_path="/global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/dino/best_moon.ckpt", - alcf_path="/eagle/IRIBeta/als/seg_models/dino/best_moon.ckpt", + alcf_path="/eagle/SYNAPS-I/segmentation/seg_models/dino/best_moon.ckpt", config=config, alias="production", description="DINOv3 fine-tuned on lunar regolith micro-CT data (ice, particles, pores).", inference_params={ + # ── site-specific HF caches ───────────────────────────────────────── + "nersc_hf_home": "/global/cfs/cdirs/als/data_mover/8.3.2/.cache/huggingface", + "nersc_hf_hub_cache": "/global/cfs/cdirs/als/data_mover/8.3.2/.cache/huggingface/hub", + "alcf_hf_home": "/eagle/SYNAPS-I/segmentation/.cache/huggingface", + "alcf_hf_hub_cache": "/eagle/SYNAPS-I/segmentation/.cache/huggingface", + # ── paths ─────────────────────────────────────────────────────────── "conda_env_path": "/global/cfs/cdirs/als/data_mover/8.3.2/envs/dino_demo", "seg_scripts_dir": f"{scripts_dir}moon_seg/forge_feb_seg_model_demo/", + # ── inference hyperparameters ─────────────────────────────────────── "script_name": "src.inference_dino_v2", "batch_size": 4, "nproc_per_node": 4, }, ) + # register_checkpoint( + # model_name="sam3-petiole", + # nersc_hf_home="/global/cfs/cdirs/als/data_mover/8.3.2/.cache/huggingface", + # nersc_hf_hub_cache="/global/cfs/cdirs/als/data_mover/8.3.2/.cache/huggingface/hub", + # nersc_path=f"{scripts_dir}sam3_finetune/sam3/checkpoint_v6.pt", + # alcf_hf_home="/eagle/SYNAPS-I/segmentation/.cache/huggingface", + # alcf_hf_hub_cache="/eagle/SYNAPS-I/segmentation/.cache/huggingface", + # alcf_path="/eagle/SYNAPS-I/segmentation/sam3_finetune/sam3/checkpoint_v6.pt", + # config=config, + # alias="production", + # description="SAM3 v6 fine-tuned on petiole micro-CT data.", + # inference_params={ + # # ── paths ────────────────────────────────────────────────────────── + # "original_checkpoint_path": + # f"{scripts_dir}sam3_finetune/sam3/sam3.pt", + # "bpe_path": f"{scripts_dir}sam3_finetune/sam3/bpe_simple_vocab_16e6.txt.gz", + # "conda_env_path": "/global/cfs/cdirs/als/data_mover/8.3.2/envs/sam3-py311", + # "seg_scripts_dir": f"{scripts_dir}inference_latest/forge_feb_seg_model_demo/", + # "checkpoints_dir": f"{scripts_dir}sam3_finetune/sam3/", + # # ── inference hyperparameters ─────────────────────────────────────── + # "script_name": "src/inference_v6.py", + # "batch_size": 1, + # "patch_size": 400, + # "confidence": [0.5], # list → JSON-encoded automatically + # "overlap": 0.25, + # "prompts": [ # list → JSON-encoded automatically + # "Phloem Fibers", + # "Hydrated Xylem vessels", + # "Air-based Pith cells", + # "Dehydrated Xylem vessels", + # ], + # }, + # ) + + # register_checkpoint( + # model_name="dinov3-petiole", + # nersc_hf_home="/global/cfs/cdirs/als/data_mover/8.3.2/.cache/huggingface", + # nersc_hf_hub_cache="/global/cfs/cdirs/als/data_mover/8.3.2/.cache/huggingface/hub", + # nersc_checkpoint_path="/global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/dino/best.ckpt", + # alcf_hf_home="/eagle/SYNAPS-I/segmentation/.cache/huggingface", + # alcf_hf_hub_cache="/eagle/SYNAPS-I/segmentation/.cache/huggingface", + # alcf_path="/eagle/SYNAPS-I/segmentation/dino/best.ckpt", + # config=config, + # alias="production", + # description="DINOv3 fine-tuned on petiole micro-CT data.", + # inference_params={ + # # ── paths ────────────────────────────────────────────────────────── + # "conda_env_path": "/global/cfs/cdirs/als/data_mover/8.3.2/envs/dino_demo", + # "seg_scripts_dir": f"{scripts_dir}inference_v5_multiseg/forge_feb_seg_model_demo/", + # # ── inference hyperparameters ─────────────────────────────────────── + # "script_name": "src.inference_dino_v1", + # "batch_size": 4, + # "nproc_per_node": 4, + # }, + # ) + + # register_checkpoint( + # model_name="dinov3-moon", + # nersc_hf_home="/global/cfs/cdirs/als/data_mover/8.3.2/.cache/huggingface", + # nersc_hf_hub_cache="/global/cfs/cdirs/als/data_mover/8.3.2/.cache/huggingface/hub", + # nersc_path="/global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/dino/best_moon.ckpt", + # alcf_hf_home="/eagle/SYNAPS-I/segmentation/.cache/huggingface", + # alcf_hf_hub_cache="/eagle/SYNAPS-I/segmentation/.cache/huggingface", + # alcf_path="/eagle/SYNAPS-I/segmentation/seg_models/dino/best_moon.ckpt", + # config=config, + # alias="production", + # description="DINOv3 fine-tuned on lunar regolith micro-CT data (ice, particles, pores).", + # inference_params={ + # "conda_env_path": "/global/cfs/cdirs/als/data_mover/8.3.2/envs/dino_demo", + # "seg_scripts_dir": f"{scripts_dir}moon_seg/forge_feb_seg_model_demo/", + # "script_name": "src.inference_dino_v2", + # "batch_size": 4, + # "nproc_per_node": 4, + # }, + # ) + def retrieve_mlflow_params_test() -> bool: """Test that _load_job_options correctly pulls inference params from the MLflow registry. @@ -106,28 +199,56 @@ def retrieve_mlflow_params_test() -> bool: ) sam3_checks = { - # MLflow should have overridden these + # ── MLflow should have overridden these ────────────────────────────── "finetuned_checkpoint_path": ( lambda v: "checkpoint" in v, "finetuned_checkpoint_path should contain 'checkpoint'" ), + "original_checkpoint_path": ( + lambda v: v.endswith(".pt") and "sam3" in v.lower(), + "original_checkpoint_path should point at a sam3 .pt file" + ), + "bpe_path": ( + lambda v: v.endswith(".txt.gz"), + "bpe_path should point at a .txt.gz vocab file" + ), "conda_env_path": ( lambda v: "sam3" in v, "conda_env_path should reference sam3 env" ), + "seg_scripts_dir": ( + lambda v: isinstance(v, str) and len(v) > 0, + "seg_scripts_dir should be a non-empty path" + ), + "checkpoints_dir": ( + lambda v: isinstance(v, str) and len(v) > 0, + "checkpoints_dir should be a non-empty path" + ), + "script_name": ( + lambda v: "inference" in v.lower(), + "script_name should reference an inference script" + ), "prompts": ( lambda v: isinstance(v, list) and len(v) > 0, "prompts should be a non-empty list (JSON-deserialized)" ), "confidence": ( - lambda v: isinstance(v, list), - "confidence should be a list (JSON-deserialized)" + lambda v: isinstance(v, list) and len(v) > 0, + "confidence should be a non-empty list (JSON-deserialized)" ), "batch_size": ( - lambda v: isinstance(v, int), - "batch_size should be an int" + lambda v: isinstance(v, int) and v > 0, + "batch_size should be a positive int" ), - # SLURM params should still come from config + "patch_size": ( + lambda v: isinstance(v, int) and v > 0, + "patch_size should be a positive int" + ), + "overlap": ( + lambda v: isinstance(v, float) and 0.0 <= v < 1.0, + "overlap should be a float in [0.0, 1.0)" + ), + # ── SLURM params should still come from config ─────────────────────── "qos": ( lambda v: v == config.nersc_segment_sam3_settings["qos"], "qos should be unchanged from config" @@ -160,23 +281,32 @@ def retrieve_mlflow_params_test() -> bool: ) dino_checks = { + # ── MLflow-overridden ──────────────────────────────────────────────── "dino_checkpoint_path": ( lambda v: v.endswith(".ckpt"), "dino_checkpoint_path should end with .ckpt" ), "conda_env_path": ( - lambda v: len(v) > 0, + lambda v: isinstance(v, str) and len(v) > 0, "conda_env_path should be non-empty" ), - "batch_size": ( - lambda v: isinstance(v, int) and v > 0, - "batch_size should be a positive int" + "seg_scripts_dir": ( + lambda v: isinstance(v, str) and len(v) > 0, + "seg_scripts_dir should be a non-empty path" ), "script_name": ( lambda v: "dino" in v.lower(), "script_name should reference dino" ), - # SLURM params unchanged + "batch_size": ( + lambda v: isinstance(v, int) and v > 0, + "batch_size should be a positive int" + ), + "nproc_per_node": ( + lambda v: isinstance(v, int) and v > 0, + "nproc_per_node should be a positive int" + ), + # ── SLURM params unchanged ─────────────────────────────────────────── "qos": ( lambda v: v == config.nersc_segment_dinov3_settings["qos"], "qos should be unchanged from config" @@ -209,9 +339,18 @@ def retrieve_mlflow_params_test() -> bool: ) moon_checks = { + # ── MLflow-overridden ──────────────────────────────────────────────── "dino_checkpoint_path": ( - lambda v: v.endswith(".ckpt"), - "dino_checkpoint_path should end with .ckpt" + lambda v: v.endswith(".ckpt") and "moon" in v.lower(), + "dino_checkpoint_path should end with .ckpt and reference moon" + ), + "conda_env_path": ( + lambda v: isinstance(v, str) and len(v) > 0, + "conda_env_path should be non-empty" + ), + "seg_scripts_dir": ( + lambda v: isinstance(v, str) and "moon" in v.lower(), + "seg_scripts_dir should reference moon_seg" ), "script_name": ( lambda v: "v2" in v.lower(), @@ -221,6 +360,11 @@ def retrieve_mlflow_params_test() -> bool: lambda v: isinstance(v, int) and v > 0, "batch_size should be a positive int" ), + "nproc_per_node": ( + lambda v: isinstance(v, int) and v > 0, + "nproc_per_node should be a positive int" + ), + # ── SLURM params unchanged ─────────────────────────────────────────── "qos": ( lambda v: v == config.nersc_segment_dinov3_moon_settings["qos"], "qos should be unchanged from config" diff --git a/orchestration/mlflow.py b/orchestration/mlflow.py index cff8487c..c337a2ba 100644 --- a/orchestration/mlflow.py +++ b/orchestration/mlflow.py @@ -1,16 +1,23 @@ import logging from dataclasses import dataclass, field +from dotenv import load_dotenv import json +import os import requests from typing import Any import mlflow from mlflow.tracking import MlflowClient +import mlflow.utils.rest_utils as rest_utils + from orchestration.config import BeamlineConfig logger = logging.getLogger(__name__) +_AMSC_PATCH_FLAG: str = "_amsc_x_api_key_patched" +load_dotenv() + @dataclass class ModelCheckpointInfo: @@ -48,13 +55,63 @@ def _is_mlflow_reachable(tracking_uri: str, timeout: float = 2.0) -> bool: Returns: True if the server responds with HTTP 200, False otherwise. """ + headers = {} + api_key = os.environ.get("AMSC_API_KEY") + if api_key: + headers["X-Api-Key"] = api_key try: - response = requests.get(f"{tracking_uri}/health", timeout=timeout) + response = requests.get( + f"{tracking_uri}/health", headers=headers, timeout=timeout + ) return response.status_code == 200 except Exception: return False +def _enable_amsc_x_api_key() -> bool: + """Patch mlflow.utils.rest_utils.http_request to inject X-Api-Key. + + Required by the American Science Cloud MLflow server, which enforces + API-key auth on all REST calls. Standard MLflow does not send custom + headers, so we wrap ``http_request`` at import time. + + Idempotent: repeat calls are no-ops thanks to a sentinel attribute on + the wrapper. Silently skips patching if ``AMSC_API_KEY`` is unset, + which lets the same codebase target non-AMSC MLflow servers. + + Returns: + True if the patch is (or was already) active, False if the API + key env var is unset. + """ + + api_key = os.environ.get("AMSC_API_KEY") + if not api_key: + return False + + if getattr(rest_utils.http_request, _AMSC_PATCH_FLAG, False): + return True + + original = rest_utils.http_request + + def patched(host_creds, endpoint, method, *args, **kwargs): + # MLflow internals call http_request with either `headers` or + # `extra_headers` depending on the code path — handle both. + if "headers" in kwargs and kwargs["headers"] is not None: + h = dict(kwargs["headers"]) + h["X-Api-Key"] = api_key + kwargs["headers"] = h + else: + h = dict(kwargs.get("extra_headers") or {}) + h["X-Api-Key"] = api_key + kwargs["extra_headers"] = h + return original(host_creds, endpoint, method, *args, **kwargs) + + setattr(patched, _AMSC_PATCH_FLAG, True) + rest_utils.http_request = patched + logger.info("AMSC X-Api-Key injection enabled for MLflow REST calls.") + return True + + def get_mlflow_client(config: BeamlineConfig) -> MlflowClient: """Construct an MlflowClient pointed at the configured tracking server. @@ -65,6 +122,7 @@ def get_mlflow_client(config: BeamlineConfig) -> MlflowClient: An authenticated MlflowClient instance. """ tracking_uri = config.mlflow["tracking_uri"] + _enable_amsc_x_api_key() # Idempotent patch for AMSC API key injection mlflow.set_tracking_uri(tracking_uri) return MlflowClient(tracking_uri=tracking_uri) @@ -183,7 +241,19 @@ def register_checkpoint( client.create_registered_model(model_name) mlflow.set_tracking_uri(config.mlflow["tracking_uri"]) + + # Use a dedicated experiment so the creator (this user) gets MANAGE + # permission automatically — avoids 403 on the default experiment. + experiment_name = config.mlflow.get("experiment_name", "als-model-registration") + experiment = mlflow.get_experiment_by_name(experiment_name) + if experiment is None: + experiment_id = mlflow.create_experiment(experiment_name) + logger.info(f"Created MLflow experiment '{experiment_name}' (id={experiment_id}).") + else: + experiment_id = experiment.experiment_id + with mlflow.start_run( + experiment_id=experiment_id, run_name=f"register_{model_name}", tags={"mlflow.note.content": description}, ) as run: @@ -247,7 +317,21 @@ def log_segmentation_metrics( run_tags: dict[str, str] = {"model": model_name, "slurm_job_id": job_id} + tracking_uri = config.mlflow["tracking_uri"] + mlflow.set_tracking_uri(tracking_uri) + _enable_amsc_x_api_key() # ensure AMSC auth patch is active for this entrypoint too + + experiment_name = config.mlflow.get("experiment_name", "als-model-registration") + experiment = mlflow.get_experiment_by_name(experiment_name) + if experiment is None: + experiment_id = mlflow.create_experiment(experiment_name) + else: + experiment_id = experiment.experiment_id + + run_tags: dict[str, str] = {"model": model_name, "slurm_job_id": job_id} + with mlflow.start_run( + experiment_id=experiment_id, run_name=run_name, nested=parent_run_id is not None, parent_run_id=parent_run_id, From 0144f525f55cb4feaa4a1fa24b8163ed8ec0e805 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Fri, 24 Apr 2026 13:20:29 -0700 Subject: [PATCH 26/42] removing old commented code --- orchestration/flows/bl832/register_mlflow.py | 76 -------------------- 1 file changed, 76 deletions(-) diff --git a/orchestration/flows/bl832/register_mlflow.py b/orchestration/flows/bl832/register_mlflow.py index d540ea67..93603295 100644 --- a/orchestration/flows/bl832/register_mlflow.py +++ b/orchestration/flows/bl832/register_mlflow.py @@ -94,82 +94,6 @@ def register_mlflow_checkpoints(): }, ) - # register_checkpoint( - # model_name="sam3-petiole", - # nersc_hf_home="/global/cfs/cdirs/als/data_mover/8.3.2/.cache/huggingface", - # nersc_hf_hub_cache="/global/cfs/cdirs/als/data_mover/8.3.2/.cache/huggingface/hub", - # nersc_path=f"{scripts_dir}sam3_finetune/sam3/checkpoint_v6.pt", - # alcf_hf_home="/eagle/SYNAPS-I/segmentation/.cache/huggingface", - # alcf_hf_hub_cache="/eagle/SYNAPS-I/segmentation/.cache/huggingface", - # alcf_path="/eagle/SYNAPS-I/segmentation/sam3_finetune/sam3/checkpoint_v6.pt", - # config=config, - # alias="production", - # description="SAM3 v6 fine-tuned on petiole micro-CT data.", - # inference_params={ - # # ── paths ────────────────────────────────────────────────────────── - # "original_checkpoint_path": - # f"{scripts_dir}sam3_finetune/sam3/sam3.pt", - # "bpe_path": f"{scripts_dir}sam3_finetune/sam3/bpe_simple_vocab_16e6.txt.gz", - # "conda_env_path": "/global/cfs/cdirs/als/data_mover/8.3.2/envs/sam3-py311", - # "seg_scripts_dir": f"{scripts_dir}inference_latest/forge_feb_seg_model_demo/", - # "checkpoints_dir": f"{scripts_dir}sam3_finetune/sam3/", - # # ── inference hyperparameters ─────────────────────────────────────── - # "script_name": "src/inference_v6.py", - # "batch_size": 1, - # "patch_size": 400, - # "confidence": [0.5], # list → JSON-encoded automatically - # "overlap": 0.25, - # "prompts": [ # list → JSON-encoded automatically - # "Phloem Fibers", - # "Hydrated Xylem vessels", - # "Air-based Pith cells", - # "Dehydrated Xylem vessels", - # ], - # }, - # ) - - # register_checkpoint( - # model_name="dinov3-petiole", - # nersc_hf_home="/global/cfs/cdirs/als/data_mover/8.3.2/.cache/huggingface", - # nersc_hf_hub_cache="/global/cfs/cdirs/als/data_mover/8.3.2/.cache/huggingface/hub", - # nersc_checkpoint_path="/global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/dino/best.ckpt", - # alcf_hf_home="/eagle/SYNAPS-I/segmentation/.cache/huggingface", - # alcf_hf_hub_cache="/eagle/SYNAPS-I/segmentation/.cache/huggingface", - # alcf_path="/eagle/SYNAPS-I/segmentation/dino/best.ckpt", - # config=config, - # alias="production", - # description="DINOv3 fine-tuned on petiole micro-CT data.", - # inference_params={ - # # ── paths ────────────────────────────────────────────────────────── - # "conda_env_path": "/global/cfs/cdirs/als/data_mover/8.3.2/envs/dino_demo", - # "seg_scripts_dir": f"{scripts_dir}inference_v5_multiseg/forge_feb_seg_model_demo/", - # # ── inference hyperparameters ─────────────────────────────────────── - # "script_name": "src.inference_dino_v1", - # "batch_size": 4, - # "nproc_per_node": 4, - # }, - # ) - - # register_checkpoint( - # model_name="dinov3-moon", - # nersc_hf_home="/global/cfs/cdirs/als/data_mover/8.3.2/.cache/huggingface", - # nersc_hf_hub_cache="/global/cfs/cdirs/als/data_mover/8.3.2/.cache/huggingface/hub", - # nersc_path="/global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/dino/best_moon.ckpt", - # alcf_hf_home="/eagle/SYNAPS-I/segmentation/.cache/huggingface", - # alcf_hf_hub_cache="/eagle/SYNAPS-I/segmentation/.cache/huggingface", - # alcf_path="/eagle/SYNAPS-I/segmentation/seg_models/dino/best_moon.ckpt", - # config=config, - # alias="production", - # description="DINOv3 fine-tuned on lunar regolith micro-CT data (ice, particles, pores).", - # inference_params={ - # "conda_env_path": "/global/cfs/cdirs/als/data_mover/8.3.2/envs/dino_demo", - # "seg_scripts_dir": f"{scripts_dir}moon_seg/forge_feb_seg_model_demo/", - # "script_name": "src.inference_dino_v2", - # "batch_size": 4, - # "nproc_per_node": 4, - # }, - # ) - def retrieve_mlflow_params_test() -> bool: """Test that _load_job_options correctly pulls inference params from the MLflow registry. From 0ad03aca2729c628c9c9b479e4db3f9eddd87803 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Fri, 24 Apr 2026 13:20:38 -0700 Subject: [PATCH 27/42] updating pytest --- orchestration/_tests/test_bl832/test_nersc.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/orchestration/_tests/test_bl832/test_nersc.py b/orchestration/_tests/test_bl832/test_nersc.py index 3ba9742f..9952335d 100644 --- a/orchestration/_tests/test_bl832/test_nersc.py +++ b/orchestration/_tests/test_bl832/test_nersc.py @@ -5,6 +5,8 @@ from prefect.blocks.system import Secret from prefect.testing.utilities import prefect_test_harness +from orchestration.flows.bl832.nersc import RESOURCE_IDS, _IRI_COMPUTE_RESOURCE + # ────────────────────────────────────────────────────────────────────────────── # Session fixture @@ -445,18 +447,21 @@ def test_reconstruct_iriapi_success(mocker, mock_iriapi_client, mock_config832, assert result["success"] is True assert result["job_id"] == "99999" mock_iriapi_client.post.assert_called_once() - assert mock_iriapi_client.post.call_args.args[0] == "/api/v1/compute/job/compute" + assert ( + mock_iriapi_client.post.call_args.args[0] + == f"/api/v1/compute/job/{RESOURCE_IDS['perlmutter_job_submit']}" + ) posted_json = mock_iriapi_client.post.call_args.kwargs["json"] assert posted_json["executable"] == "/bin/bash" - assert posted_json["arguments"][0] == "-c" - assert isinstance(posted_json["arguments"][1], str) # the script body - assert "tomo_recon" in posted_json["arguments"][1] # sanity check it's the right script - assert mock_iriapi_client.get.call_count == 2 + assert posted_json["arguments"] == ["-s"] # matches nersc.py + assert "pre_launch" in posted_json # script body lives here + assert "tomo_recon" in posted_json["pre_launch"] # sanity check it's the right script + mock_iriapi_client.get.assert_any_call( - "/api/v1/compute/status/compute/99999" + f"/api/v1/compute/status/{_IRI_COMPUTE_RESOURCE}/99999" ) mock_iriapi_client.get.assert_any_call( - "/api/v1/filesystem/file/perlmutter", + f"/api/v1/filesystem/view/{RESOURCE_IDS['perlmutter_login']}", params={"path": mocker.ANY}, ) From 9d8e2c113470070f754cbdfeb19e170d70623a28 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Fri, 24 Apr 2026 13:25:48 -0700 Subject: [PATCH 28/42] linting --- orchestration/flows/bl832/nersc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 39bd4a02..6e4292d0 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -517,7 +517,7 @@ def _mkdir_remote(self, path: str) -> None: perlmutter.run(f"mkdir -p {path}") elif self.login_method is NERSCLoginMethod.IRIAPI: response = self.client.post( - f"/api/v1/filesystem/mkdir/{RESOURCE_IDS["perlmutter_login"]}", + f"/api/v1/filesystem/mkdir/{RESOURCE_IDS['perlmutter_login']}", json={"path": path, "parents": True}, ) response.raise_for_status() From e4f4e08d72656bc4ce55714be0eec4be131d66ff Mon Sep 17 00:00:00 2001 From: David Abramov Date: Fri, 24 Apr 2026 13:31:46 -0700 Subject: [PATCH 29/42] adjusting import in pytest to avoid error on github that did not occur locally --- orchestration/_tests/test_bl832/test_nersc.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/orchestration/_tests/test_bl832/test_nersc.py b/orchestration/_tests/test_bl832/test_nersc.py index 9952335d..1ec12264 100644 --- a/orchestration/_tests/test_bl832/test_nersc.py +++ b/orchestration/_tests/test_bl832/test_nersc.py @@ -5,8 +5,6 @@ from prefect.blocks.system import Secret from prefect.testing.utilities import prefect_test_harness -from orchestration.flows.bl832.nersc import RESOURCE_IDS, _IRI_COMPUTE_RESOURCE - # ────────────────────────────────────────────────────────────────────────────── # Session fixture @@ -432,6 +430,7 @@ def test_reconstruct_sfapi_submission_failure(mocker, mock_sfapi_client, mock_co def test_reconstruct_iriapi_success(mocker, mock_iriapi_client, mock_config832, monkeypatch): """IRIAPI reconstruct POSTs a job and polls for COMPLETED state.""" from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod + from orchestration.flows.bl832.nersc import RESOURCE_IDS, _IRI_COMPUTE_RESOURCE monkeypatch.setenv("NERSC_USERNAME", "alsdev") mocker.patch("orchestration.flows.bl832.nersc.time.sleep") From 98c70642a73b3861ff884d6e432e8379ba9fdce0 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Thu, 7 May 2026 09:39:23 -0700 Subject: [PATCH 30/42] Getting NERSC reservations working with IRI API --- config.yml | 32 ++++++++-------- orchestration/flows/bl832/nersc.py | 60 +++++++++++++++--------------- 2 files changed, 46 insertions(+), 46 deletions(-) diff --git a/config.yml b/config.yml index 6ee1cbb7..aa1fe9a1 100644 --- a/config.yml +++ b/config.yml @@ -182,10 +182,10 @@ hpc_submission_settings832: # ── RECON + MULTIRES SETTINGS ─────────────────────────────────────────────── nersc_reconstruction: # ── SLURM resource allocation ───────────────────────────────────────────── - qos: debug - account: als - reservation: "" - num_nodes: 2 + qos: regular + account: amsc006 + reservation: "_CAP_SYNAPS_CPU4" + num_nodes: 4 cpus-per-task: 128 walltime: "0:30:00" nersc_multiresolution: @@ -199,11 +199,11 @@ hpc_submission_settings832: # ── PETIOLE SEGMENTATION SETTINGS ─────────────────────────────────────────── nersc_segmentation_sam3: # ── SLURM resource allocation ───────────────────────────────────────────── - qos: debug - account: als + qos: regular + account: amsc006 constraint: gpu - reservation: "" - num_nodes: 2 + reservation: "_CAP_SYNAPS_GPU4" + num_nodes: 32 ntasks-per-node: 1 gpus-per-node: 4 cpus-per-task: 128 @@ -230,11 +230,11 @@ hpc_submission_settings832: finetuned_checkpoint_path: /global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/sam3_finetune/sam3/checkpoint_v6.pt nersc_segmentation_dinov3: # ── SLURM resource allocation ───────────────────────────────────────────── - qos: debug - account: als + qos: regular + account: amsc006 constraint: gpu - reservation: "" - num_nodes: 2 + reservation: "_CAP_SYNAPS_GPU4" + num_nodes: 8 ntasks-per-node: 1 nproc_per_node: 4 gpus-per-node: 4 @@ -250,11 +250,11 @@ hpc_submission_settings832: dino_checkpoint_path: /global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/dino/best.ckpt nersc_combine_segmentations: # ── SLURM resource allocation ───────────────────────────────────────────── - qos: debug - account: als + qos: regular + account: amsc006 constraint: cpu - reservation: "" - num_nodes: 2 + reservation: "_CAP_SYNAPS_CPU4" + num_nodes: 4 ntasks: 128 cpus-per-task: 1 walltime: "00:30:00" diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 6e4292d0..b067156a 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -1,7 +1,6 @@ from dataclasses import dataclass, field import datetime from dotenv import load_dotenv -import enum import httpx import json import logging @@ -176,19 +175,6 @@ def _load_job_options( return {**opts, **overrides} -class NERSCLoginMethod(enum.Enum): - """Selects which NERSC API login method to use when creating a NERSC client. - - Each method corresponds to a different set of credentials and API base URL. - """ - - SFAPI = "sfapi" - """Standard Superfacility API via Iris-registered OAuth2 credentials.""" - - IRIAPI = "iriapi" - """Integrated Research Infrastructure API via IRI-registered OAuth2 credentials.""" - - class NERSCTomographyHPCController(TomographyHPCController, NerscStreamingMixin): """ Implementation for a NERSC-based tomography HPC controller. @@ -200,7 +186,7 @@ def __init__( self, config: Config832, client: Client | httpx.Client | None = None, - login_method: NERSCLoginMethod = NERSCLoginMethod.SFAPI, + login_method: NERSCLoginMethod = NERSCLoginMethod.IRIAPI, ) -> None: TomographyHPCController.__init__(self, config) self.client = client @@ -208,7 +194,7 @@ def __init__( @staticmethod def create_nersc_client( - login_method: NERSCLoginMethod = NERSCLoginMethod.SFAPI, + login_method: NERSCLoginMethod = NERSCLoginMethod.IRIAPI, ) -> Client | httpx.Client: """Create and return a NERSC client for the requested login method. @@ -417,6 +403,8 @@ def _submit_job(self, job_script: str, num_nodes: int = 1) -> str: sbatch_values["stdout_path"] = line.split("--output=")[-1].strip() elif "--error=" in line: sbatch_values["stderr_path"] = line.split("--error=")[-1].strip() + elif "--reservation=" in line: + sbatch_values["reservation"] = line.split("--reservation=")[-1].strip() # Strip shebang and SBATCH headers, keep the script body script_body = "\n".join( @@ -437,19 +425,31 @@ def _submit_job(self, job_script: str, num_nodes: int = 1) -> str: else: resources["cpu_cores_per_process"] = 128 + custom_attributes = {"constraint": constraint} + + attributes = { + "duration": sbatch_values.get("duration", 1800), + "queue_name": sbatch_values.get("queue_name", "regular"), + "account": sbatch_values.get("account", "als"), + "custom_attributes": custom_attributes, + } + if "reservation" in sbatch_values: + attributes["reservation_id"] = sbatch_values["reservation"] + + job_spec = { + "executable": "/bin/bash", + "arguments": ["-s"], + "pre_launch": script_body, + "resources": resources, + "attributes": attributes, +} + job_spec = { "executable": "/bin/bash", "arguments": ["-s"], # read script from stdin isn't supported, so... "pre_launch": script_body, # run the body here before the executable "resources": resources, - "attributes": { - "duration": sbatch_values.get("duration", 1800), - "queue_name": sbatch_values.get("queue_name", "regular"), - "account": sbatch_values.get("account", "als"), - "custom_attributes": { - "constraint": constraint - }, - }, + "attributes": attributes } if "stdout_path" in sbatch_values: @@ -2681,14 +2681,14 @@ def nersc_segmentation_sam3_integration_test() -> bool: return flow_success -# if __name__ == "__main__": +if __name__ == "__main__": # nersc_segmentation_dinov3_task( # recon_folder_path='dabramov/recmoon/', # config=Config832(), # project="moon" # ) - # nersc_petiole_segment_flow( - # file_path='dabramov/20260221_143000_petiole28', - # num_nodes=2, - # login_method=NERSCLoginMethod.IRIAPI - # ) + nersc_petiole_segment_flow( + file_path='dabramov/20260221_143000_petiole28', + num_nodes=4, + login_method=NERSCLoginMethod.IRIAPI + ) From f9200fcc451328c84e13e024d95ce9f40fe2ad4c Mon Sep 17 00:00:00 2001 From: David Abramov Date: Thu, 7 May 2026 11:38:34 -0700 Subject: [PATCH 31/42] Updating pytests --- .../_tests/test_bl832/test_mlflow.py | 59 +++++++------ orchestration/_tests/test_bl832/test_nersc.py | 88 ++++++++++++++----- 2 files changed, 100 insertions(+), 47 deletions(-) diff --git a/orchestration/_tests/test_bl832/test_mlflow.py b/orchestration/_tests/test_bl832/test_mlflow.py index 4ca12720..49806332 100644 --- a/orchestration/_tests/test_bl832/test_mlflow.py +++ b/orchestration/_tests/test_bl832/test_mlflow.py @@ -423,12 +423,12 @@ class TestSegmentationSam3MLflowCheckpoint: segmentation_sam3 uses it in the submitted SLURM job script. """ - def test_mlflow_checkpoint_appears_in_job_script(self, mocker, mock_sfapi_client, mock_config832): + def test_mlflow_checkpoint_appears_in_job_script(self, mocker, mock_config832): """ When _load_job_options returns an MLflow-sourced finetuned_checkpoint_path, that path must appear in the SLURM script submitted to Perlmutter. """ - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod mocker.patch("orchestration.flows.bl832.nersc.time.sleep") @@ -441,60 +441,69 @@ def test_mlflow_checkpoint_appears_in_job_script(self, mocker, mock_sfapi_client return_value=resolved_settings, ) + # Capture the script passed to _submit_job — bypasses the SFAPI/IRIAPI + # dispatch entirely and avoids the real client.post() / job polling. captured = [] - original_job = mock_sfapi_client.compute.return_value.submit_job.return_value - def capture_script(script): + def capture_script(script, *args, **kwargs): captured.append(script) - return original_job + return "12345" # fake job id - mock_sfapi_client.compute.return_value.submit_job.side_effect = capture_script - - controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) + # Use a real client mock just so __init__ doesn't fail; we'll never call it. + controller = NERSCTomographyHPCController( + client=mocker.MagicMock(), + config=mock_config832, + login_method=NERSCLoginMethod.IRIAPI, + ) + mocker.patch.object(controller, "_submit_job", side_effect=capture_script) + mocker.patch.object(controller, "_wait_for_job", return_value=True) + mocker.patch.object(controller, "_mkdir_remote", return_value=None) mocker.patch.object(controller, "_fetch_seg_timing_from_output", return_value=None) + # _get_nersc_username reads NERSC_USERNAME for IRIAPI; stub it + mocker.patch.object(controller, "_get_nersc_username", return_value="testuser") + result = controller.segmentation_sam3(recon_folder_path="folder/recfile") - assert captured, "submit_job was never called" + assert captured, "_submit_job was never called" assert mlflow_checkpoint in captured[0], ( "The MLflow checkpoint path must appear in the SLURM job script" ) assert result["success"] is True - def test_config_default_checkpoint_used_when_mlflow_unavailable( - self, mocker, mock_sfapi_client, mock_config832 - ): - """ - When _load_job_options returns the unmodified config default (MLflow absent), - the default checkpoint path should appear in the job script. - """ - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController + def test_config_default_checkpoint_used_when_mlflow_unavailable(self, mocker, mock_config832): + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod mocker.patch("orchestration.flows.bl832.nersc.time.sleep") mocker.patch( "orchestration.flows.bl832.nersc.Variable.get", return_value={"defaults": True}, ) - # MLflow is unreachable; _load_job_options falls back to config mocker.patch( "orchestration.flows.bl832.nersc.get_checkpoint_info", return_value=None, ) captured = [] - original_job = mock_sfapi_client.compute.return_value.submit_job.return_value - def capture_script(script): + def capture_script(script, *args, **kwargs): captured.append(script) - return original_job - - mock_sfapi_client.compute.return_value.submit_job.side_effect = capture_script + return "12345" - controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) + controller = NERSCTomographyHPCController( + client=mocker.MagicMock(), + config=mock_config832, + login_method=NERSCLoginMethod.IRIAPI, + ) + mocker.patch.object(controller, "_submit_job", side_effect=capture_script) + mocker.patch.object(controller, "_wait_for_job", return_value=True) + mocker.patch.object(controller, "_mkdir_remote", return_value=None) mocker.patch.object(controller, "_fetch_seg_timing_from_output", return_value=None) + mocker.patch.object(controller, "_get_nersc_username", return_value="testuser") + controller.segmentation_sam3(recon_folder_path="folder/recfile") config_default = mock_config832.nersc_segment_sam3_settings["finetuned_checkpoint_path"] - assert captured, "submit_job was never called" + assert captured, "_submit_job was never called" assert config_default in captured[0], ( "Config default checkpoint path must be used when MLflow is unavailable" ) diff --git a/orchestration/_tests/test_bl832/test_nersc.py b/orchestration/_tests/test_bl832/test_nersc.py index 1ec12264..c5b9163f 100644 --- a/orchestration/_tests/test_bl832/test_nersc.py +++ b/orchestration/_tests/test_bl832/test_nersc.py @@ -239,11 +239,15 @@ def test_create_sfapi_client_missing_files(mocker): # ────────────────────────────────────────────────────────────────────────────── def test_build_multi_resolution_success(mocker, mock_sfapi_client, mock_config832): - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod from sfapi_client.compute import Machine mocker.patch("orchestration.flows.bl832.nersc.time.sleep") - controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) + controller = NERSCTomographyHPCController( + client=mock_sfapi_client, + config=mock_config832, + login_method=NERSCLoginMethod.SFAPI, + ) result = controller.build_multi_resolution(file_path="folder/file.h5") @@ -255,11 +259,15 @@ def test_build_multi_resolution_success(mocker, mock_sfapi_client, mock_config83 def test_build_multi_resolution_submission_failure(mocker, mock_sfapi_client, mock_config832): - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod mocker.patch("orchestration.flows.bl832.nersc.time.sleep") mock_sfapi_client.compute.return_value.submit_job.side_effect = Exception("Submission failed") - controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) + controller = NERSCTomographyHPCController( + client=mock_sfapi_client, + config=mock_config832, + login_method=NERSCLoginMethod.SFAPI, + ) result = controller.build_multi_resolution(file_path="folder/file.h5") @@ -271,12 +279,16 @@ def test_build_multi_resolution_submission_failure(mocker, mock_sfapi_client, mo # ────────────────────────────────────────────────────────────────────────────── def test_segmentation_sam3_success(mocker, mock_sfapi_client, mock_config832): - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod from sfapi_client.compute import Machine mocker.patch("orchestration.flows.bl832.nersc.time.sleep") mocker.patch("orchestration.flows.bl832.nersc.Variable.get", return_value={"defaults": True}) - controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) + controller = NERSCTomographyHPCController( + client=mock_sfapi_client, + config=mock_config832, + login_method=NERSCLoginMethod.SFAPI, + ) mocker.patch.object(controller, "_fetch_seg_timing_from_output", return_value=None) result = controller.segmentation_sam3(recon_folder_path="folder/recfile") @@ -290,12 +302,16 @@ def test_segmentation_sam3_success(mocker, mock_sfapi_client, mock_config832): def test_segmentation_sam3_submission_failure(mocker, mock_sfapi_client, mock_config832): - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod mocker.patch("orchestration.flows.bl832.nersc.time.sleep") mocker.patch("orchestration.flows.bl832.nersc.Variable.get", return_value={"defaults": True}) mock_sfapi_client.compute.return_value.submit_job.side_effect = Exception("GPU queue full") - controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) + controller = NERSCTomographyHPCController( + client=mock_sfapi_client, + config=mock_config832, + login_method=NERSCLoginMethod.SFAPI, + ) result = controller.segmentation_sam3(recon_folder_path="folder/recfile") @@ -306,7 +322,7 @@ def test_segmentation_sam3_submission_failure(mocker, mock_sfapi_client, mock_co def test_segmentation_sam3_uses_variable_options(mocker, mock_sfapi_client, mock_config832): """Custom Prefect variable options should be forwarded into the job script.""" - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod mocker.patch("orchestration.flows.bl832.nersc.time.sleep") mocker.patch("orchestration.flows.bl832.nersc.Variable.get", return_value={ @@ -321,7 +337,11 @@ def test_segmentation_sam3_uses_variable_options(mocker, mock_sfapi_client, mock "checkpoint": "checkpoint_v7.pt", }) - controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) + controller = NERSCTomographyHPCController( + client=mock_sfapi_client, + config=mock_config832, + login_method=NERSCLoginMethod.SFAPI, + ) mocker.patch.object(controller, "_fetch_seg_timing_from_output", return_value=None) captured_scripts = [] @@ -350,12 +370,16 @@ def capture_script(script): # ────────────────────────────────────────────────────────────────────────────── def test_segmentation_dinov3_success(mocker, mock_sfapi_client, mock_config832): - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod from sfapi_client.compute import Machine mocker.patch("orchestration.flows.bl832.nersc.time.sleep") mocker.patch("orchestration.flows.bl832.nersc.Variable.get", return_value={"defaults": True}) - controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) + controller = NERSCTomographyHPCController( + client=mock_sfapi_client, + config=mock_config832, + login_method=NERSCLoginMethod.SFAPI, + ) result = controller.segmentation_dinov3(recon_folder_path="folder/recfile") @@ -366,12 +390,16 @@ def test_segmentation_dinov3_success(mocker, mock_sfapi_client, mock_config832): def test_segmentation_dinov3_submission_failure(mocker, mock_sfapi_client, mock_config832): - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod mocker.patch("orchestration.flows.bl832.nersc.time.sleep") mocker.patch("orchestration.flows.bl832.nersc.Variable.get", return_value={"defaults": True}) mock_sfapi_client.compute.return_value.submit_job.side_effect = Exception("No GPU nodes") - controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) + controller = NERSCTomographyHPCController( + client=mock_sfapi_client, + config=mock_config832, + login_method=NERSCLoginMethod.SFAPI, + ) result = controller.segmentation_dinov3(recon_folder_path="folder/recfile") assert result is False @@ -551,7 +579,7 @@ def test_segmentation_dinov3_output_paths(mocker, mock_sfapi_client, mock_config output_dir = ".../scratch/folder/segfile/dino" So the script contains "segfile" and "/dino", not a literal "/seg/" segment. """ - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod mocker.patch("orchestration.flows.bl832.nersc.time.sleep") mocker.patch("orchestration.flows.bl832.nersc.Variable.get", return_value={"defaults": True}) @@ -564,7 +592,11 @@ def capture(script): return original_return mock_sfapi_client.compute.return_value.submit_job.side_effect = capture - controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) + controller = NERSCTomographyHPCController( + client=mock_sfapi_client, + config=mock_config832, + login_method=NERSCLoginMethod.SFAPI, + ) controller.segmentation_dinov3(recon_folder_path="folder/recfile") script = captured_scripts[0] @@ -577,12 +609,16 @@ def capture(script): # ────────────────────────────────────────────────────────────────────────────── def test_combine_segmentations_success(mocker, mock_sfapi_client, mock_config832): - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod from sfapi_client.compute import Machine mocker.patch("orchestration.flows.bl832.nersc.time.sleep") mocker.patch("orchestration.flows.bl832.nersc.Variable.get", return_value={"defaults": True}) - controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) + controller = NERSCTomographyHPCController( + client=mock_sfapi_client, + config=mock_config832, + login_method=NERSCLoginMethod.SFAPI, + ) result = controller.combine_segmentations(recon_folder_path="folder/recfile") @@ -593,12 +629,16 @@ def test_combine_segmentations_success(mocker, mock_sfapi_client, mock_config832 def test_combine_segmentations_submission_failure(mocker, mock_sfapi_client, mock_config832): - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod mocker.patch("orchestration.flows.bl832.nersc.time.sleep") mocker.patch("orchestration.flows.bl832.nersc.Variable.get", return_value={"defaults": True}) mock_sfapi_client.compute.return_value.submit_job.side_effect = Exception("Cluster down") - controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) + controller = NERSCTomographyHPCController( + client=mock_sfapi_client, + config=mock_config832, + login_method=NERSCLoginMethod.SFAPI, + ) result = controller.combine_segmentations(recon_folder_path="folder/recfile") @@ -607,7 +647,7 @@ def test_combine_segmentations_submission_failure(mocker, mock_sfapi_client, moc def test_combine_segmentations_script_references_sam3_and_dino(mocker, mock_sfapi_client, mock_config832): """The combination job script should reference both model output directories.""" - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod mocker.patch("orchestration.flows.bl832.nersc.time.sleep") mocker.patch("orchestration.flows.bl832.nersc.Variable.get", return_value={"defaults": True}) @@ -620,7 +660,11 @@ def capture(script): return original_return mock_sfapi_client.compute.return_value.submit_job.side_effect = capture - controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) + controller = NERSCTomographyHPCController( + client=mock_sfapi_client, + config=mock_config832, + login_method=NERSCLoginMethod.SFAPI, + ) controller.combine_segmentations(recon_folder_path="folder/recfile") script = captured_scripts[0] From 4f4a3d847146d7ab351b3810101376b6943bcdf0 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Thu, 7 May 2026 11:40:04 -0700 Subject: [PATCH 32/42] launch jobs with IRI API and a reservation --- orchestration/flows/bl832/nersc.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index b067156a..d3995e49 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -436,20 +436,12 @@ def _submit_job(self, job_script: str, num_nodes: int = 1) -> str: if "reservation" in sbatch_values: attributes["reservation_id"] = sbatch_values["reservation"] - job_spec = { - "executable": "/bin/bash", - "arguments": ["-s"], - "pre_launch": script_body, - "resources": resources, - "attributes": attributes, -} - job_spec = { "executable": "/bin/bash", "arguments": ["-s"], # read script from stdin isn't supported, so... "pre_launch": script_body, # run the body here before the executable "resources": resources, - "attributes": attributes + "attributes": attributes, } if "stdout_path" in sbatch_values: From 2ff38764ea5d8c49d66447017758a0f408de8f5e Mon Sep 17 00:00:00 2001 From: David Abramov Date: Thu, 7 May 2026 15:48:29 -0700 Subject: [PATCH 33/42] fixing dino extra_flags bug --- orchestration/flows/bl832/nersc.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index d3995e49..cdd90dea 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -1180,9 +1180,15 @@ def segmentation_dinov3( mlflow_checkpoint_key=spec.mlflow_checkpoint_key, ) - extra_flags = "\n".join( - f" {flag} {value} \\" for flag, value in spec.extra_cli_flags.items() - ) + # extra_flags = "\n".join( + # f" {flag} {value} \\" for flag, value in spec.extra_cli_flags.items() + # ) + + tail_args: list[str] = [] + for flag, value in spec.extra_cli_flags.items(): + tail_args.append(f"{flag} {value}") + tail_args.append("--save-overlay") + extra_flags = " \\\n ".join(tail_args) cfs_path = opts["cfs_path"] conda_env_path = opts["conda_env_path"] @@ -1279,7 +1285,6 @@ def segmentation_dinov3( --batch-size {batch_size} \\ --finetuned-checkpoint "{dino_checkpoint}" \\ {extra_flags} - --save-overlay SEG_STATUS=$? From 405b197f1431fdb92fbddfbbd4ac9c6dbdb2bdd9 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Thu, 7 May 2026 15:48:54 -0700 Subject: [PATCH 34/42] fixing globus token race condition when jobs are launch simultaneously --- orchestration/globus/get_globus_token.py | 40 ++++++++++++++++++------ 1 file changed, 31 insertions(+), 9 deletions(-) diff --git a/orchestration/globus/get_globus_token.py b/orchestration/globus/get_globus_token.py index f740a034..d987947e 100644 --- a/orchestration/globus/get_globus_token.py +++ b/orchestration/globus/get_globus_token.py @@ -94,17 +94,39 @@ def load_tokens(token_file: Path) -> dict | None: return json.load(f) +# def save_tokens(token_file: Path, tokens: dict) -> None: +# ensure_private_parent_dir(token_file) +# tmp = token_file.with_suffix(".tmp") +# with os.fdopen( +# os.open(tmp, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600), +# "w", +# encoding="utf-8", +# ) as f: +# json.dump(tokens, f, indent=2) +# os.replace(tmp, token_file) +# os.chmod(token_file, stat.S_IRUSR | stat.S_IWUSR) + + def save_tokens(token_file: Path, tokens: dict) -> None: ensure_private_parent_dir(token_file) - tmp = token_file.with_suffix(".tmp") - with os.fdopen( - os.open(tmp, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600), - "w", - encoding="utf-8", - ) as f: - json.dump(tokens, f, indent=2) - os.replace(tmp, token_file) - os.chmod(token_file, stat.S_IRUSR | stat.S_IWUSR) + # Per-process unique tmp name to avoid races between concurrent writers + tmp = token_file.with_suffix(f".tmp.{os.getpid()}.{os.urandom(4).hex()}") + try: + with os.fdopen( + os.open(tmp, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600), + "w", + encoding="utf-8", + ) as f: + json.dump(tokens, f, indent=2) + os.replace(tmp, token_file) + os.chmod(token_file, stat.S_IRUSR | stat.S_IWUSR) + except Exception: + # Clean up tmp if anything between open and replace failed + try: + tmp.unlink(missing_ok=True) + except OSError: + pass + raise def get_refresh_token(stored_tokens: dict) -> str | None: From 33c3d952cbb510e3c5dc4bafc704d712b1cefb73 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Sun, 10 May 2026 11:04:09 -0700 Subject: [PATCH 35/42] Adding frontend/prefect_runner.html --- config.yml | 8 +- frontend/prefect_runner.html | 1547 ++++++++++++++++++++++++++++++++++ 2 files changed, 1551 insertions(+), 4 deletions(-) create mode 100644 frontend/prefect_runner.html diff --git a/config.yml b/config.yml index aa1fe9a1..f0bd7dab 100644 --- a/config.yml +++ b/config.yml @@ -184,7 +184,7 @@ hpc_submission_settings832: # ── SLURM resource allocation ───────────────────────────────────────────── qos: regular account: amsc006 - reservation: "_CAP_SYNAPS_CPU4" + reservation: "_CAP_SYNAPS_CPU5" num_nodes: 4 cpus-per-task: 128 walltime: "0:30:00" @@ -202,7 +202,7 @@ hpc_submission_settings832: qos: regular account: amsc006 constraint: gpu - reservation: "_CAP_SYNAPS_GPU4" + reservation: "_CAP_SYNAPS_GPU5" num_nodes: 32 ntasks-per-node: 1 gpus-per-node: 4 @@ -233,7 +233,7 @@ hpc_submission_settings832: qos: regular account: amsc006 constraint: gpu - reservation: "_CAP_SYNAPS_GPU4" + reservation: "_CAP_SYNAPS_GPU5" num_nodes: 8 ntasks-per-node: 1 nproc_per_node: 4 @@ -253,7 +253,7 @@ hpc_submission_settings832: qos: regular account: amsc006 constraint: cpu - reservation: "_CAP_SYNAPS_CPU4" + reservation: "_CAP_SYNAPS_CPU5" num_nodes: 4 ntasks: 128 cpus-per-task: 1 diff --git a/frontend/prefect_runner.html b/frontend/prefect_runner.html new file mode 100644 index 00000000..319875bd --- /dev/null +++ b/frontend/prefect_runner.html @@ -0,0 +1,1547 @@ + + + + + +AmSC Tomography Pipeline + + + + + + + +
+
AmSC. Tomography Pipeline
+
nersc_petiole_segment_flow
+
+
disconnected
+
+ +
+ + +
+
+ Step 1 +

Connect to Prefect

+
+
+ + + +
+ + +
+ + + +
+ + +
+ +
+
+
+ + + + + + + +
+ +
+ Built for the Photon Science Computing group · Lawrence Berkeley National Laboratory +
+ + + + + \ No newline at end of file From 81bf47d4f48eea6ca0ca5d837fdb76d162f30fb2 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Tue, 12 May 2026 07:33:21 -0700 Subject: [PATCH 36/42] updating html page with a timer and collapsible logs --- frontend/prefect_runner.html | 223 ++++++++++++++++++++++++++++------- 1 file changed, 181 insertions(+), 42 deletions(-) diff --git a/frontend/prefect_runner.html b/frontend/prefect_runner.html index 319875bd..ade52803 100644 --- a/frontend/prefect_runner.html +++ b/frontend/prefect_runner.html @@ -429,6 +429,21 @@ .step.completed .step-detail, .step.failed .step-detail, .step.crashed .step-detail { fill: rgba(255,255,255,0.85); } + /* Timer text. Slightly bolder and bigger than detail lines because it's + the most "live" thing on the diagram — the ticking number that says + "this is actually working." Tabular-nums keeps digit width stable so the + line doesn't twitch as the seconds count up. */ + .step-timer { + font-family: "DM Mono", monospace; + font-size: 13px; + font-weight: 500; + font-variant-numeric: tabular-nums; + fill: var(--slate-dark); + } + .step.running .step-timer, + .step.completed .step-timer, + .step.failed .step-timer, + .step.crashed .step-timer { fill: #fff; } .substep-detail { font-family: "DM Mono", monospace; font-size: 9px; @@ -438,6 +453,17 @@ .substep.completed .substep-detail, .substep.failed .substep-detail, .substep.crashed .substep-detail { fill: rgba(255,255,255,0.85); } + .substep-timer { + font-family: "DM Mono", monospace; + font-size: 10px; + font-weight: 500; + font-variant-numeric: tabular-nums; + fill: var(--slate-dark); + } + .substep.running .substep-timer, + .substep.completed .substep-timer, + .substep.failed .substep-timer, + .substep.crashed .substep-timer { fill: #fff; } .step-arrow { fill: none; stroke: var(--finch-neutral-line); @@ -678,10 +704,11 @@

Monitor flow run

- Reconstruction - IRI API - - + Reconstruction + IRI API + + + @@ -694,29 +721,32 @@

Monitor flow run

Segmentation IRI API + - + - - SAM3 - + + SAM3 + + - - DINOv3 - + + DINOv3 + + - - AmSC MLflow + + AmSC MLflow - + @@ -727,10 +757,11 @@

Monitor flow run

- Combine - IRI API - - + Combine + IRI API + + + @@ -744,6 +775,7 @@

Monitor flow run

+
Waiting for logs… @@ -1233,12 +1265,20 @@

Monitor flow run

------------------------------------------------------------- */ const STEP_STATE_CLASSES = ["pending", "running", "completed", "failed", "crashed", "cancelled", "scheduled"]; -// task name → NERSC job ID, accumulated across polls. Logs eventually scroll -// off the 200-line poll window, so we cache here once we see the job ID. -// Reset by resetPipeline() between runs. -const jobIdCache = {}; +// task name → { jobId, submittedAt } accumulated across polls. Both pieces +// come from the same "Submitted job ID: N" log line — we capture its +// timestamp as the per-step timer's start point (option B: "time since +// submission to NERSC"). Logs land in our poll independently of task_run +// row updates, so this also dodges a Prefect-side race where a task can be +// RUNNING in state but still have null start_time in the task_runs table, +// which was leaving the second-to-start substep's timer blank. +const submissionCache = {}; // task_run_id → task name, also cached. Used to attribute log lines to a task. const taskIdToName = {}; +// task name → {start: ISO string|null, end: ISO string|null}. +// Populated from task_runs each poll. We use this for end_time only; +// start_time comes from submissionCache via the log line. +const timingCache = {}; function setStepState(elementId, stateType) { const el = document.getElementById(elementId); @@ -1252,11 +1292,75 @@

Monitor flow run

if (el) el.textContent = text || ""; } +// Format an elapsed-seconds count as MM:SS or H:MM:SS. Choices: pad seconds to +// 2 digits so the display doesn't jitter; only show hours when over an hour +// so short tasks read cleanly. +function formatDuration(seconds) { + if (seconds < 0 || !isFinite(seconds)) return ""; + const s = Math.floor(seconds); + const h = Math.floor(s / 3600); + const m = Math.floor((s % 3600) / 60); + const r = s % 60; + const mm = String(m).padStart(2, "0"); + const rr = String(r).padStart(2, "0"); + return h > 0 ? `${h}:${mm}:${rr}` : `${mm}:${rr}`; +} + +// Render timer text for one task. Start time comes from the "Submitted job ID" +// log line (i.e. when the NERSC job was actually queued); end time comes from +// the Prefect task's end_time. Tasks that haven't submitted yet stay blank. +function updateTimerForTask(elementId, taskName) { + const sub = submissionCache[taskName]; + if (!sub || !sub.submittedAt) { setDetailText(elementId, ""); return; } + const start = Date.parse(sub.submittedAt); + const timing = timingCache[taskName]; + const end = (timing && timing.end) ? Date.parse(timing.end) : Date.now(); + if (Number.isNaN(start)) { setDetailText(elementId, ""); return; } + setDetailText(elementId, formatDuration((end - start) / 1000)); +} + +// Render the segmentation parent timer. Wall-clock from the earliest substep +// submission to the latest substep end (or now if either is still running). +function updateSegmentationTimer() { + const sam = submissionCache["nersc_segmentation_sam3_task"]; + const dino = submissionCache["nersc_segmentation_dinov3_task"]; + const starts = [sam && sam.submittedAt, dino && dino.submittedAt] + .filter(Boolean).map(Date.parse).filter((n) => !Number.isNaN(n)); + if (starts.length === 0) { setDetailText("detail-seg-timer", ""); return; } + const earliestStart = Math.min(...starts); + // For end: if a substep submitted but hasn't ended, treat its end as "now"; + // if a substep never submitted, ignore it from the max. + const samEnd = (sam && sam.submittedAt) + ? ((timingCache["nersc_segmentation_sam3_task"] || {}).end + ? Date.parse(timingCache["nersc_segmentation_sam3_task"].end) + : Date.now()) + : -Infinity; + const dinoEnd = (dino && dino.submittedAt) + ? ((timingCache["nersc_segmentation_dinov3_task"] || {}).end + ? Date.parse(timingCache["nersc_segmentation_dinov3_task"].end) + : Date.now()) + : -Infinity; + const latestEnd = Math.max(samEnd, dinoEnd); + setDetailText("detail-seg-timer", formatDuration((latestEnd - earliestStart) / 1000)); +} + +// Refresh every visible timer from the current timingCache. Called every 1s +// while a run is being monitored AND once per poll (so end_times get applied +// immediately when a task completes between ticks). +function renderTimers() { + updateTimerForTask("detail-recon-timer", "nersc_reconstruction_task"); + updateTimerForTask("detail-sam3-timer", "nersc_segmentation_sam3_task"); + updateTimerForTask("detail-dino-timer", "nersc_segmentation_dinov3_task"); + updateTimerForTask("detail-combine-timer", "nersc_combine_segmentations_task"); + updateSegmentationTimer(); +} + // Given a list of log records and the current task-id-to-name map, parse out -// any "Submitted job ID: N" lines and update the jobIdCache. Idempotent — -// safe to call on every poll. We attribute by task_run_id so that a single -// `Submitted job ID:` log line is correctly tied to its submitting task, -// even if multiple tasks run concurrently. +// any "Submitted job ID: N" lines and update submissionCache. We capture both +// the job ID and the log line's timestamp (which is when the job was actually +// submitted from the worker — more accurate than Date.now() since a log line +// may arrive in a later poll than when it was emitted). +// Idempotent: safe to call on every poll; we only write the first sighting. function harvestJobIdsFromLogs(logs) { const jobIdRe = /Submitted job ID:\s*(\S+)/; for (const log of logs) { @@ -1264,8 +1368,11 @@

Monitor flow run

if (!m) continue; const taskId = log.task_run_id; const taskName = taskId && taskIdToName[taskId]; - if (taskName && !jobIdCache[taskName]) { - jobIdCache[taskName] = m[1]; + if (taskName && !submissionCache[taskName]) { + submissionCache[taskName] = { + jobId: m[1], + submittedAt: log.timestamp || new Date().toISOString(), + }; } } } @@ -1303,6 +1410,12 @@

Monitor flow run

byNameMessage[name] = tr.state.message; } if (tr.id) taskIdToName[tr.id] = name; + // Cache end_time only — start_time comes from the "Submitted job ID" log + // line, captured by harvestJobIdsFromLogs(). Logs aren't subject to the + // brief delay between a task transitioning to RUNNING and Prefect + // populating start_time on the task_run row, which was causing the + // second-to-start parallel substep's timer to stay blank. + timingCache[name] = { end: tr.end_time || null }; } // Now that taskIdToName is up to date, parse fresh job IDs from any logs. @@ -1312,8 +1425,8 @@

Monitor flow run

const reconState = byName["nersc_reconstruction_task"]; setStepState("step-recon", reconState); setStepState("arrow-recon-seg", reconState === "COMPLETED" ? "completed" : reconState); - setDetailText("detail-recon-job", jobIdCache["nersc_reconstruction_task"] - ? "Job " + jobIdCache["nersc_reconstruction_task"] : ""); + setDetailText("detail-recon-job", submissionCache["nersc_reconstruction_task"] + ? "Job " + submissionCache["nersc_reconstruction_task"].jobId : ""); setDetailText("detail-recon-msg", reconState === "RUNNING" ? (byNameMessage["nersc_reconstruction_task"] || "") : ""); @@ -1322,10 +1435,10 @@

Monitor flow run

const dinoState = byName["nersc_segmentation_dinov3_task"]; setStepState("substep-sam3", samState); setStepState("substep-dino", dinoState); - setDetailText("detail-sam3-job", jobIdCache["nersc_segmentation_sam3_task"] - ? "Job " + jobIdCache["nersc_segmentation_sam3_task"] : ""); - setDetailText("detail-dino-job", jobIdCache["nersc_segmentation_dinov3_task"] - ? "Job " + jobIdCache["nersc_segmentation_dinov3_task"] : ""); + setDetailText("detail-sam3-job", submissionCache["nersc_segmentation_sam3_task"] + ? "Job " + submissionCache["nersc_segmentation_sam3_task"].jobId : ""); + setDetailText("detail-dino-job", submissionCache["nersc_segmentation_dinov3_task"] + ? "Job " + submissionCache["nersc_segmentation_dinov3_task"].jobId : ""); const segState = deriveSegmentationState(samState, dinoState); setStepState("step-seg", segState); @@ -1342,10 +1455,14 @@

Monitor flow run

// Step 3: combine const combineState = byName["nersc_combine_segmentations_task"]; setStepState("step-combine", combineState); - setDetailText("detail-combine-job", jobIdCache["nersc_combine_segmentations_task"] - ? "Job " + jobIdCache["nersc_combine_segmentations_task"] : ""); + setDetailText("detail-combine-job", submissionCache["nersc_combine_segmentations_task"] + ? "Job " + submissionCache["nersc_combine_segmentations_task"].jobId : ""); setDetailText("detail-combine-msg", combineState === "RUNNING" ? (byNameMessage["nersc_combine_segmentations_task"] || "") : ""); + + // Render timers immediately so newly-finalized end_times take effect on the + // same poll, not on the next 1s tick. + renderTimers(); } function resetPipeline() { @@ -1354,14 +1471,16 @@

Monitor flow run

"arrow-recon-seg", "arrow-seg-combine"]) { setStepState(id, null); } - for (const id of ["detail-recon-job", "detail-recon-msg", - "detail-sam3-job", "detail-dino-job", - "detail-seg-msg", - "detail-combine-job", "detail-combine-msg"]) { + for (const id of ["detail-recon-job", "detail-recon-msg", "detail-recon-timer", + "detail-sam3-job", "detail-sam3-timer", + "detail-dino-job", "detail-dino-timer", + "detail-seg-msg", "detail-seg-timer", + "detail-combine-job", "detail-combine-msg", "detail-combine-timer"]) { setDetailText(id, ""); } - for (const k of Object.keys(jobIdCache)) delete jobIdCache[k]; + for (const k of Object.keys(submissionCache)) delete submissionCache[k]; for (const k of Object.keys(taskIdToName)) delete taskIdToName[k]; + for (const k of Object.keys(timingCache)) delete timingCache[k]; const mlflow = document.getElementById("mlflow-badge"); if (mlflow) mlflow.classList.remove("active"); } @@ -1370,6 +1489,7 @@

Monitor flow run

Monitor: poll flow run state + logs + task runs ------------------------------------------------------------- */ let pollHandle = null; +let timerHandle = null; // 1s tick for live elapsed-time rendering let pollPaused = false; let seenLogIds = new Set(); let lastState = null; @@ -1496,9 +1616,15 @@

Monitor flow run

$("btn-pause-poll").textContent = "Pause"; pollOnce(flowRunId); pollHandle = setInterval(() => pollOnce(flowRunId), POLL_INTERVAL_MS); + // Independent 1s tick for the live timers. Cheap — just reads timingCache + // and updates text nodes. Stopped together with pollHandle on terminal state. + timerHandle = setInterval(renderTimers, 1000); } function stopPolling() { - if (pollHandle) { clearInterval(pollHandle); pollHandle = null; } + if (pollHandle) { clearInterval(pollHandle); pollHandle = null; } + if (timerHandle) { clearInterval(timerHandle); timerHandle = null; } + // Final render so end_times from the last poll are shown stably. + renderTimers(); } function startMonitoring(flowRunId) { @@ -1522,6 +1648,19 @@

Monitor flow run

$("btn-clear-logs").addEventListener("click", () => { resetLogPane(); }); +$("btn-toggle-logs").addEventListener("click", (e) => { + const pane = $("log-pane"); + const collapsed = pane.style.display === "none"; + if (collapsed) { + pane.style.display = ""; + e.target.textContent = "Collapse"; + // Re-pin scroll to bottom on expand so newly-arrived logs are in view. + pane.scrollTop = pane.scrollHeight; + } else { + pane.style.display = "none"; + e.target.textContent = "Expand"; + } +}); $("btn-cancel").addEventListener("click", async () => { if (!currentRunId) return; if (!confirm("Cancel this flow run?")) return; From 273593f25ff4809229567beb609bb6f8391029de Mon Sep 17 00:00:00 2001 From: David Abramov Date: Tue, 12 May 2026 07:33:45 -0700 Subject: [PATCH 37/42] updating config with confab reservation --- config.yml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/config.yml b/config.yml index f0bd7dab..7632d213 100644 --- a/config.yml +++ b/config.yml @@ -184,15 +184,15 @@ hpc_submission_settings832: # ── SLURM resource allocation ───────────────────────────────────────────── qos: regular account: amsc006 - reservation: "_CAP_SYNAPS_CPU5" - num_nodes: 4 + reservation: "_CAP_SYNAPS_LIVEDEMO_CPU1" + num_nodes: 16 cpus-per-task: 128 walltime: "0:30:00" nersc_multiresolution: # ── SLURM resource allocation ───────────────────────────────────────────── qos: debug account: als - reservation: "" + reservation: "_CAP_SYNAPS_LIVEDEMO_CPU1" cpus-per-task: 128 walltime: "0:15:00" @@ -202,7 +202,7 @@ hpc_submission_settings832: qos: regular account: amsc006 constraint: gpu - reservation: "_CAP_SYNAPS_GPU5" + reservation: "_CAP_SYNAPS_LIVEDEMO_GPU1" num_nodes: 32 ntasks-per-node: 1 gpus-per-node: 4 @@ -233,7 +233,7 @@ hpc_submission_settings832: qos: regular account: amsc006 constraint: gpu - reservation: "_CAP_SYNAPS_GPU5" + reservation: "_CAP_SYNAPS_LIVEDEMO_GPU1" num_nodes: 8 ntasks-per-node: 1 nproc_per_node: 4 @@ -253,7 +253,7 @@ hpc_submission_settings832: qos: regular account: amsc006 constraint: cpu - reservation: "_CAP_SYNAPS_CPU5" + reservation: "_CAP_SYNAPS_LIVEDEMO_CPU1" num_nodes: 4 ntasks: 128 cpus-per-task: 1 From b7f57e1252f0a5aab5c25a4bdc1bf68eb453ccf1 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Wed, 20 May 2026 13:07:09 -0700 Subject: [PATCH 38/42] Separated out general MLflow tests (non-specific to beamlines) --- orchestration/_tests/test_mlflow.py | 542 ++++++++++++++++++++++++++++ 1 file changed, 542 insertions(+) create mode 100644 orchestration/_tests/test_mlflow.py diff --git a/orchestration/_tests/test_mlflow.py b/orchestration/_tests/test_mlflow.py new file mode 100644 index 00000000..e9f0e3de --- /dev/null +++ b/orchestration/_tests/test_mlflow.py @@ -0,0 +1,542 @@ +# orchestration/_tests/test_mlflow.py +# +# Tests for orchestration/mlflow.py — the beamline-agnostic helper that +# wraps the MLflow Model Registry for checkpoint metadata lookup. +# +# Beamline-specific tests (e.g. _load_job_options, segmentation_sam3) +# live in _tests/test_bl832/test_mlflow.py. + +import json +import pytest +from uuid import uuid4 + +import mlflow.utils.rest_utils as rest_utils + +from prefect.blocks.system import Secret +from prefect.testing.utilities import prefect_test_harness + + +# ────────────────────────────────────────────────────────────────────────────── +# Session fixture +# ────────────────────────────────────────────────────────────────────────────── + +@pytest.fixture(autouse=True, scope="session") +def prefect_test_fixture(): + with prefect_test_harness(): + Secret(value=str(uuid4())).save(name="globus-client-id", overwrite=True) + Secret(value=str(uuid4())).save(name="globus-client-secret", overwrite=True) + yield + + +# ────────────────────────────────────────────────────────────────────────────── +# Shared fixtures +# ────────────────────────────────────────────────────────────────────────────── + +@pytest.fixture +def mock_beamline_config(mocker): + """Minimal BeamlineConfig mock with mlflow tracking_uri.""" + config = mocker.MagicMock() + config.mlflow = {"tracking_uri": "http://mock-mlflow:5000"} + return config + + +def _make_model_version(mocker, *, version="1", tags=None): + """Helper: build a mock MlflowClient model version object.""" + mv = mocker.MagicMock() + mv.version = version + mv.tags = tags or {} + return mv + + +# ────────────────────────────────────────────────────────────────────────────── +# get_checkpoint_info +# ────────────────────────────────────────────────────────────────────────────── + +class TestGetCheckpointInfo: + + def test_returns_checkpoint_info_when_mlflow_reachable(self, mocker, mock_beamline_config): + """Happy path: reachable server + valid production alias → ModelCheckpointInfo.""" + from orchestration.mlflow import get_checkpoint_info + + mocker.patch("orchestration.mlflow._is_mlflow_reachable", return_value=True) + + mv = _make_model_version(mocker, version="3", tags={ + "nersc_path": "/cfs/checkpoints/sam3_v3.pt", + "alcf_path": "/eagle/checkpoints/sam3_v3.pt", + "batch_size": "2", + "prompts": json.dumps(["cell wall", "lumen"]), + }) + mock_client = mocker.MagicMock() + mock_client.get_model_version_by_alias.return_value = mv + mocker.patch("orchestration.mlflow.get_mlflow_client", return_value=mock_client) + + info = get_checkpoint_info("sam3-petiole", mock_beamline_config, alias="production") + + assert info is not None + assert info.model_name == "sam3-petiole" + assert info.version == "3" + assert info.alias == "production" + assert info.nersc_path == "/cfs/checkpoints/sam3_v3.pt" + assert info.alcf_path == "/eagle/checkpoints/sam3_v3.pt" + + def test_deserializes_json_inference_params(self, mocker, mock_beamline_config): + """JSON-encoded tag values (lists, dicts) are decoded into Python objects.""" + from orchestration.mlflow import get_checkpoint_info + + mocker.patch("orchestration.mlflow._is_mlflow_reachable", return_value=True) + mv = _make_model_version(mocker, tags={ + "nersc_path": "/cfs/sam3.pt", + "prompts": json.dumps(["cell wall", "lumen"]), + "confidence": json.dumps([0.6, 0.7]), + "batch_size": "4", + }) + mock_client = mocker.MagicMock() + mock_client.get_model_version_by_alias.return_value = mv + mocker.patch("orchestration.mlflow.get_mlflow_client", return_value=mock_client) + + info = get_checkpoint_info("sam3-petiole", mock_beamline_config) + + assert info.inference_params["prompts"] == ["cell wall", "lumen"] + assert info.inference_params["confidence"] == [0.6, 0.7] + assert info.inference_params["batch_size"] == 4 # "4" is valid JSON → int + + def test_returns_none_when_mlflow_unreachable(self, mocker, mock_beamline_config): + """Unreachable tracking server → None (caller falls back to config defaults).""" + from orchestration.mlflow import get_checkpoint_info + + mocker.patch("orchestration.mlflow._is_mlflow_reachable", return_value=False) + + info = get_checkpoint_info("sam3-petiole", mock_beamline_config) + + assert info is None + + def test_returns_none_when_alias_not_found(self, mocker, mock_beamline_config): + """Missing production alias → MlflowException → None.""" + from orchestration.mlflow import get_checkpoint_info + import mlflow.exceptions + + mocker.patch("orchestration.mlflow._is_mlflow_reachable", return_value=True) + mock_client = mocker.MagicMock() + mock_client.get_model_version_by_alias.side_effect = ( + mlflow.exceptions.MlflowException("Alias not found") + ) + mocker.patch("orchestration.mlflow.get_mlflow_client", return_value=mock_client) + + info = get_checkpoint_info("sam3-petiole", mock_beamline_config) + + assert info is None + + def test_returns_none_when_nersc_path_tag_missing(self, mocker, mock_beamline_config): + """A model version without 'nersc_path' tag → None.""" + from orchestration.mlflow import get_checkpoint_info + + mocker.patch("orchestration.mlflow._is_mlflow_reachable", return_value=True) + mv = _make_model_version(mocker, tags={"alcf_path": "/eagle/sam3.pt"}) + mock_client = mocker.MagicMock() + mock_client.get_model_version_by_alias.return_value = mv + mocker.patch("orchestration.mlflow.get_mlflow_client", return_value=mock_client) + + info = get_checkpoint_info("sam3-petiole", mock_beamline_config) + + assert info is None + + def test_nersc_and_alcf_paths_excluded_from_inference_params(self, mocker, mock_beamline_config): + """nersc_path and alcf_path must NOT appear in inference_params.""" + from orchestration.mlflow import get_checkpoint_info + + mocker.patch("orchestration.mlflow._is_mlflow_reachable", return_value=True) + mv = _make_model_version(mocker, tags={ + "nersc_path": "/cfs/sam3.pt", + "alcf_path": "/eagle/sam3.pt", + "batch_size": "2", + }) + mock_client = mocker.MagicMock() + mock_client.get_model_version_by_alias.return_value = mv + mocker.patch("orchestration.mlflow.get_mlflow_client", return_value=mock_client) + + info = get_checkpoint_info("sam3-petiole", mock_beamline_config) + + assert "nersc_path" not in info.inference_params + assert "alcf_path" not in info.inference_params + assert "batch_size" in info.inference_params + + +# ────────────────────────────────────────────────────────────────────────────── +# _is_mlflow_reachable +# ────────────────────────────────────────────────────────────────────────────── + +class TestIsMlflowReachable: + def test_returns_true_when_health_200(self, mocker): + from orchestration.mlflow import _is_mlflow_reachable + mock_resp = mocker.MagicMock() + mock_resp.status_code = 200 + mocker.patch("requests.get", return_value=mock_resp) + assert _is_mlflow_reachable("http://mlflow:5000") is True + + def test_returns_false_when_non_200(self, mocker): + from orchestration.mlflow import _is_mlflow_reachable + mock_resp = mocker.MagicMock() + mock_resp.status_code = 503 + mocker.patch("requests.get", return_value=mock_resp) + assert _is_mlflow_reachable("http://mlflow:5000") is False + + def test_returns_false_on_exception(self, mocker): + from orchestration.mlflow import _is_mlflow_reachable + mocker.patch("requests.get", side_effect=Exception("timeout")) + assert _is_mlflow_reachable("http://mlflow:5000") is False + + def test_sends_api_key_header_when_env_set(self, mocker, monkeypatch): + from orchestration.mlflow import _is_mlflow_reachable + monkeypatch.setenv("AMSC_API_KEY", "test-key-123") + mock_resp = mocker.MagicMock() + mock_resp.status_code = 200 + mock_get = mocker.patch("requests.get", return_value=mock_resp) + _is_mlflow_reachable("http://mlflow:5000") + assert mock_get.call_args.kwargs["headers"].get("X-Api-Key") == "test-key-123" + + def test_omits_api_key_header_when_env_unset(self, mocker, monkeypatch): + from orchestration.mlflow import _is_mlflow_reachable + monkeypatch.delenv("AMSC_API_KEY", raising=False) + mock_resp = mocker.MagicMock() + mock_resp.status_code = 200 + mock_get = mocker.patch("requests.get", return_value=mock_resp) + _is_mlflow_reachable("http://mlflow:5000") + assert "X-Api-Key" not in mock_get.call_args.kwargs["headers"] + + +# ────────────────────────────────────────────────────────────────────────────── +# _enable_amsc_x_api_key +# ────────────────────────────────────────────────────────────────────────────── + +@pytest.fixture +def reset_amsc_patch(): + """Save and restore rest_utils.http_request to isolate AMSC patch state.""" + original = rest_utils.http_request + yield + rest_utils.http_request = original + + +class TestEnableAmscXApiKey: + def test_returns_false_when_key_unset(self, monkeypatch, reset_amsc_patch): + from orchestration.mlflow import _enable_amsc_x_api_key, _AMSC_PATCH_FLAG + monkeypatch.delenv("AMSC_API_KEY", raising=False) + assert _enable_amsc_x_api_key() is False + assert not getattr(rest_utils.http_request, _AMSC_PATCH_FLAG, False) + + def test_returns_true_and_patches_when_key_set(self, monkeypatch, reset_amsc_patch): + from orchestration.mlflow import _enable_amsc_x_api_key, _AMSC_PATCH_FLAG + monkeypatch.setenv("AMSC_API_KEY", "test-key") + assert _enable_amsc_x_api_key() is True + assert getattr(rest_utils.http_request, _AMSC_PATCH_FLAG, False) is True + + def test_idempotent_second_call_does_not_rewrap(self, monkeypatch, reset_amsc_patch): + from orchestration.mlflow import _enable_amsc_x_api_key + monkeypatch.setenv("AMSC_API_KEY", "test-key") + _enable_amsc_x_api_key() + patched_once = rest_utils.http_request + _enable_amsc_x_api_key() + assert rest_utils.http_request is patched_once + + def test_injects_key_via_extra_headers_when_no_headers_kwarg(self, mocker, monkeypatch, reset_amsc_patch): + from orchestration.mlflow import _enable_amsc_x_api_key + monkeypatch.setenv("AMSC_API_KEY", "my-key") + spy = mocker.MagicMock() + spy._amsc_x_api_key_patched = False # prevent MagicMock auto-attr from being truthy + rest_utils.http_request = spy # captured as 'original' by the patch closure + _enable_amsc_x_api_key() + + rest_utils.http_request(mocker.MagicMock(), "/api", "GET") + + assert spy.call_args.kwargs.get("extra_headers", {}).get("X-Api-Key") == "my-key" + + def test_injects_key_into_existing_headers_kwarg(self, mocker, monkeypatch, reset_amsc_patch): + from orchestration.mlflow import _enable_amsc_x_api_key + monkeypatch.setenv("AMSC_API_KEY", "my-key") + spy = mocker.MagicMock() + spy._amsc_x_api_key_patched = False # prevent MagicMock auto-attr from being truthy + rest_utils.http_request = spy + _enable_amsc_x_api_key() + + rest_utils.http_request( + mocker.MagicMock(), "/api", "GET", headers={"Content-Type": "application/json"} + ) + + assert spy.call_args.kwargs["headers"]["X-Api-Key"] == "my-key" + assert spy.call_args.kwargs["headers"]["Content-Type"] == "application/json" + + +# ────────────────────────────────────────────────────────────────────────────── +# get_mlflow_client +# ────────────────────────────────────────────────────────────────────────────── + +class TestGetMlflowClient: + def test_returns_client_with_tracking_uri(self, mocker, mock_beamline_config): + from orchestration.mlflow import get_mlflow_client + mock_enable = mocker.patch("orchestration.mlflow._enable_amsc_x_api_key") + mock_set_uri = mocker.patch("mlflow.set_tracking_uri") + mock_client_cls = mocker.patch("orchestration.mlflow.MlflowClient") + + result = get_mlflow_client(mock_beamline_config) + + mock_enable.assert_called_once() + mock_set_uri.assert_called_once_with("http://mock-mlflow:5000") + mock_client_cls.assert_called_once_with(tracking_uri="http://mock-mlflow:5000") + assert result is mock_client_cls.return_value + + +# ────────────────────────────────────────────────────────────────────────────── +# register_checkpoint +# ────────────────────────────────────────────────────────────────────────────── + +def _setup_register_mocks(mocker, mock_beamline_config, *, version="1", existing_model=True, existing_experiment=True): + """Wire up standard mocks for register_checkpoint; returns mock_client.""" + import mlflow.exceptions + mock_client = mocker.MagicMock() + if not existing_model: + mock_client.get_registered_model.side_effect = mlflow.exceptions.MlflowException("not found") + mocker.patch("orchestration.mlflow.get_mlflow_client", return_value=mock_client) + mocker.patch("mlflow.set_tracking_uri") + + if existing_experiment: + mock_exp = mocker.MagicMock() + mock_exp.experiment_id = "exp-1" + mocker.patch("mlflow.get_experiment_by_name", return_value=mock_exp) + else: + mocker.patch("mlflow.get_experiment_by_name", return_value=None) + mocker.patch("mlflow.create_experiment", return_value="exp-new") + + mock_run = mocker.MagicMock() + mock_run.info.run_id = "run-abc-123" + mock_start = mocker.patch("mlflow.start_run") + mock_start.return_value.__enter__.return_value = mock_run + mock_start.return_value.__exit__.return_value = False + mocker.patch("mlflow.log_param") + mocker.patch("mlflow.log_params") + + mock_mv = mocker.MagicMock() + mock_mv.version = version + mocker.patch("mlflow.register_model", return_value=mock_mv) + return mock_client + + +class TestRegisterCheckpoint: + def test_happy_path_returns_version(self, mocker, mock_beamline_config): + from orchestration.mlflow import register_checkpoint + mock_client = _setup_register_mocks(mocker, mock_beamline_config) + + version = register_checkpoint("sam3-petiole", "/cfs/sam3.pt", mock_beamline_config) + + assert version == "1" + mock_client.set_registered_model_alias.assert_called_once_with( + "sam3-petiole", "production", "1" + ) + mock_client.set_model_version_tag.assert_any_call( + "sam3-petiole", "1", "nersc_path", "/cfs/sam3.pt" + ) + + def test_creates_registered_model_when_not_found(self, mocker, mock_beamline_config): + from orchestration.mlflow import register_checkpoint + mock_client = _setup_register_mocks(mocker, mock_beamline_config, existing_model=False) + + register_checkpoint("sam3-petiole", "/cfs/sam3.pt", mock_beamline_config) + + mock_client.create_registered_model.assert_called_once_with("sam3-petiole") + + def test_skips_create_model_when_already_exists(self, mocker, mock_beamline_config): + from orchestration.mlflow import register_checkpoint + mock_client = _setup_register_mocks(mocker, mock_beamline_config, existing_model=True) + + register_checkpoint("sam3-petiole", "/cfs/sam3.pt", mock_beamline_config) + + mock_client.create_registered_model.assert_not_called() + + def test_creates_experiment_when_not_found(self, mocker, mock_beamline_config): + from orchestration.mlflow import register_checkpoint + _setup_register_mocks(mocker, mock_beamline_config, existing_experiment=False) + # Re-patch to get a reference for assertion (second patch wins) + mock_create_exp = mocker.patch("mlflow.create_experiment", return_value="exp-new") + + register_checkpoint("sam3-petiole", "/cfs/sam3.pt", mock_beamline_config) + + mock_create_exp.assert_called_once() + + def test_alcf_path_tag_set_when_provided(self, mocker, mock_beamline_config): + from orchestration.mlflow import register_checkpoint + mock_client = _setup_register_mocks(mocker, mock_beamline_config) + + register_checkpoint( + "sam3-petiole", "/cfs/sam3.pt", mock_beamline_config, alcf_path="/eagle/sam3.pt" + ) + + mock_client.set_model_version_tag.assert_any_call( + "sam3-petiole", "1", "alcf_path", "/eagle/sam3.pt" + ) + + def test_alcf_path_tag_omitted_when_empty(self, mocker, mock_beamline_config): + from orchestration.mlflow import register_checkpoint + mock_client = _setup_register_mocks(mocker, mock_beamline_config) + + register_checkpoint("sam3-petiole", "/cfs/sam3.pt", mock_beamline_config, alcf_path="") + + tag_names = [c.args[2] for c in mock_client.set_model_version_tag.call_args_list] + assert "alcf_path" not in tag_names + + def test_inference_params_list_json_encoded(self, mocker, mock_beamline_config): + from orchestration.mlflow import register_checkpoint + mock_client = _setup_register_mocks(mocker, mock_beamline_config) + + register_checkpoint( + "sam3-petiole", + "/cfs/sam3.pt", + mock_beamline_config, + inference_params={"prompts": ["cell wall", "lumen"], "batch_size": 4}, + ) + + tag_calls = {c.args[2]: c.args[3] for c in mock_client.set_model_version_tag.call_args_list} + assert tag_calls["prompts"] == json.dumps(["cell wall", "lumen"]) + assert tag_calls["batch_size"] == "4" + + def test_no_inference_params_skips_tag_loop(self, mocker, mock_beamline_config): + from orchestration.mlflow import register_checkpoint + mock_client = _setup_register_mocks(mocker, mock_beamline_config) + + register_checkpoint("sam3-petiole", "/cfs/sam3.pt", mock_beamline_config) + + tag_names = {c.args[2] for c in mock_client.set_model_version_tag.call_args_list} + assert tag_names == {"nersc_path"} + + +# ────────────────────────────────────────────────────────────────────────────── +# log_segmentation_metrics +# ────────────────────────────────────────────────────────────────────────────── + +def _setup_log_metrics_mocks(mocker, mock_beamline_config): + """Wire up standard mocks for log_segmentation_metrics; returns mock_run.""" + mocker.patch("mlflow.set_tracking_uri") + mocker.patch("orchestration.mlflow._enable_amsc_x_api_key") + mock_exp = mocker.MagicMock() + mock_exp.experiment_id = "exp-1" + mocker.patch("mlflow.get_experiment_by_name", return_value=mock_exp) + + mock_run = mocker.MagicMock() + mock_run.info.run_id = "run-xyz-999" + mock_start = mocker.patch("mlflow.start_run") + mock_start.return_value.__enter__.return_value = mock_run + mock_start.return_value.__exit__.return_value = False + + mocker.patch("mlflow.log_param") + mocker.patch("mlflow.log_params") + mocker.patch("mlflow.log_metrics") + return mock_run + + +class TestLogSegmentationMetrics: + def test_happy_path_returns_run_id(self, mocker, mock_beamline_config): + from orchestration.mlflow import log_segmentation_metrics + _setup_log_metrics_mocks(mocker, mock_beamline_config) + + run_id = log_segmentation_metrics("seg-run-1", "sam3", "job-42", mock_beamline_config) + + assert run_id == "run-xyz-999" + + def test_logs_slurm_job_id_and_model_params(self, mocker, mock_beamline_config): + from orchestration.mlflow import log_segmentation_metrics + _setup_log_metrics_mocks(mocker, mock_beamline_config) + mock_log_param = mocker.patch("mlflow.log_param") + + log_segmentation_metrics("seg-run-1", "sam3", "job-42", mock_beamline_config) + + mock_log_param.assert_any_call("slurm_job_id", "job-42") + mock_log_param.assert_any_call("model", "sam3") + + def test_full_timing_dict_logged_as_metrics(self, mocker, mock_beamline_config): + from orchestration.mlflow import log_segmentation_metrics + _setup_log_metrics_mocks(mocker, mock_beamline_config) + mock_log_metrics = mocker.patch("mlflow.log_metrics") + + timing = { + "total_seconds": 120.5, + "num_images": 50, + "throughput": 25.0, + "time_per_image": "2.41s", + } + log_segmentation_metrics("seg-run-1", "sam3", "job-42", mock_beamline_config, timing=timing) + + logged = mock_log_metrics.call_args.args[0] + assert logged["total_seconds"] == 120.5 + assert logged["num_images"] == 50.0 + assert logged["throughput_images_per_min"] == 25.0 + assert logged["time_per_image_seconds"] == pytest.approx(2.41) + + def test_time_per_image_unit_stripped(self, mocker, mock_beamline_config): + from orchestration.mlflow import log_segmentation_metrics + _setup_log_metrics_mocks(mocker, mock_beamline_config) + mock_log_metrics = mocker.patch("mlflow.log_metrics") + + log_segmentation_metrics( + "seg-run-1", "sam3", "job-42", mock_beamline_config, + timing={"time_per_image": "3.23s"}, + ) + + logged = mock_log_metrics.call_args.args[0] + assert logged["time_per_image_seconds"] == pytest.approx(3.23) + + def test_non_numeric_time_per_image_omitted(self, mocker, mock_beamline_config): + from orchestration.mlflow import log_segmentation_metrics + _setup_log_metrics_mocks(mocker, mock_beamline_config) + mock_log_metrics = mocker.patch("mlflow.log_metrics") + + log_segmentation_metrics( + "seg-run-1", "sam3", "job-42", mock_beamline_config, + timing={"time_per_image": "N/A"}, + ) + + mock_log_metrics.assert_not_called() + + def test_parent_run_id_sets_nested_true(self, mocker, mock_beamline_config): + from orchestration.mlflow import log_segmentation_metrics + _setup_log_metrics_mocks(mocker, mock_beamline_config) + mock_start = mocker.patch("mlflow.start_run") + mock_run = mocker.MagicMock() + mock_run.info.run_id = "child-run" + mock_start.return_value.__enter__.return_value = mock_run + mock_start.return_value.__exit__.return_value = False + + log_segmentation_metrics( + "seg-run-1", "sam3", "job-42", mock_beamline_config, + parent_run_id="parent-run-id-123", + ) + + kwargs = mock_start.call_args.kwargs + assert kwargs["nested"] is True + assert kwargs["parent_run_id"] == "parent-run-id-123" + + def test_extra_params_logged(self, mocker, mock_beamline_config): + from orchestration.mlflow import log_segmentation_metrics + _setup_log_metrics_mocks(mocker, mock_beamline_config) + mock_log_params = mocker.patch("mlflow.log_params") + + log_segmentation_metrics( + "seg-run-1", "sam3", "job-42", mock_beamline_config, + params={"dataset": "beamline_832", "threshold": 0.5}, + ) + + mock_log_params.assert_called_once_with({"dataset": "beamline_832", "threshold": 0.5}) + + def test_amsc_patch_called(self, mocker, mock_beamline_config): + from orchestration.mlflow import log_segmentation_metrics + _setup_log_metrics_mocks(mocker, mock_beamline_config) + mock_enable = mocker.patch("orchestration.mlflow._enable_amsc_x_api_key") + + log_segmentation_metrics("seg-run-1", "sam3", "job-42", mock_beamline_config) + + mock_enable.assert_called() + + def test_no_metrics_logged_when_no_timing(self, mocker, mock_beamline_config): + from orchestration.mlflow import log_segmentation_metrics + _setup_log_metrics_mocks(mocker, mock_beamline_config) + mock_log_metrics = mocker.patch("mlflow.log_metrics") + + log_segmentation_metrics("seg-run-1", "sam3", "job-42", mock_beamline_config) + + mock_log_metrics.assert_not_called() From ff8a45f97556aebb601a926680c85314fb0afe9a Mon Sep 17 00:00:00 2001 From: David Abramov Date: Wed, 20 May 2026 14:34:14 -0700 Subject: [PATCH 39/42] removing quotes around enum --- orchestration/flows/bl832/job_controller.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/orchestration/flows/bl832/job_controller.py b/orchestration/flows/bl832/job_controller.py index 1a23d02a..c526f72a 100644 --- a/orchestration/flows/bl832/job_controller.py +++ b/orchestration/flows/bl832/job_controller.py @@ -79,7 +79,7 @@ class HPC(Enum): def get_controller( hpc_type: HPC, config: Config832, - login_method: "NERSCLoginMethod | None" = None, + login_method: NERSCLoginMethod | None = None, ) -> TomographyHPCController: """ Factory function that returns an HPC controller instance for the given HPC environment. @@ -104,6 +104,7 @@ def get_controller( resolved_login_method = login_method if isinstance(login_method, NERSCLoginMethod) else NERSCLoginMethod.SFAPI return NERSCTomographyHPCController( client=NERSCTomographyHPCController.create_nersc_client( + config=config, login_method=resolved_login_method ), config=config, From d95abcf75379488a99c3328d0f6e0fea2e31487b Mon Sep 17 00:00:00 2001 From: David Abramov Date: Wed, 20 May 2026 14:35:16 -0700 Subject: [PATCH 40/42] moving nersc iri/sf-api resource definitions to config (no longer global variables) --- config.yml | 33 +++++++++++++++++++++++------ orchestration/flows/bl832/config.py | 2 ++ 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/config.yml b/config.yml index 7632d213..255f82d7 100644 --- a/config.yml +++ b/config.yml @@ -178,13 +178,34 @@ mlflow: registry_uri: https://mlflow.american-science-cloud.org/ experiment_name: als-bl832-models +nersc_resources: + iri: + api_base_url: https://api.iri.nersc.gov + compute_resource: "compute" + # Perlmutter compute + perlmutter_compute: "94351904-6dba-4c16-b5cd-fbd280d8615b" + perlmutter_login: "e525a224-61c1-419f-9642-91168c792e39" + perlmutter_realtime: "3776417d-747c-4753-895a-6323c17b9c98" + perlmutter_job_submit: "3cf3c048-855e-4dd8-a189-065a483954bb" + # Storage + scratch: "43d8f6c0-f900-48ce-b267-73714103f4ac" + homes: "65b28619-c3b6-4942-8da1-044a3b3a2a9e" + common: "7e07a611-f927-4a39-a44d-b1d6e307accd" + cfs: "59e80c79-4dfd-4c53-9c07-7405685fcd37" + archive: "f4916c65-9001-49c2-b0bf-6fe4276b564c" + # Services + globus: "0a207df3-4bec-45b8-9060-13505d269da9" + dtns: "a762cbdc-af7a-4b2b-9463-67f0189dd2ae" + sfapi: + api_base_url: https://api.nersc.gov/api/v1.2 + hpc_submission_settings832: # ── RECON + MULTIRES SETTINGS ─────────────────────────────────────────────── nersc_reconstruction: # ── SLURM resource allocation ───────────────────────────────────────────── qos: regular account: amsc006 - reservation: "_CAP_SYNAPS_LIVEDEMO_CPU1" + reservation: "_CAP_SYNAPS_LIVEDEMO_CPU2" num_nodes: 16 cpus-per-task: 128 walltime: "0:30:00" @@ -192,7 +213,7 @@ hpc_submission_settings832: # ── SLURM resource allocation ───────────────────────────────────────────── qos: debug account: als - reservation: "_CAP_SYNAPS_LIVEDEMO_CPU1" + reservation: "_CAP_SYNAPS_LIVEDEMO_CPU2" cpus-per-task: 128 walltime: "0:15:00" @@ -202,7 +223,7 @@ hpc_submission_settings832: qos: regular account: amsc006 constraint: gpu - reservation: "_CAP_SYNAPS_LIVEDEMO_GPU1" + reservation: "_CAP_SYNAPS_LIVEDEMO_GPU2" num_nodes: 32 ntasks-per-node: 1 gpus-per-node: 4 @@ -233,7 +254,7 @@ hpc_submission_settings832: qos: regular account: amsc006 constraint: gpu - reservation: "_CAP_SYNAPS_LIVEDEMO_GPU1" + reservation: "_CAP_SYNAPS_LIVEDEMO_GPU2" num_nodes: 8 ntasks-per-node: 1 nproc_per_node: 4 @@ -253,7 +274,7 @@ hpc_submission_settings832: qos: regular account: amsc006 constraint: cpu - reservation: "_CAP_SYNAPS_LIVEDEMO_CPU1" + reservation: "_CAP_SYNAPS_LIVEDEMO_CPU2" num_nodes: 4 ntasks: 128 cpus-per-task: 1 @@ -272,7 +293,7 @@ hpc_submission_settings832: qos: regular account: als constraint: gpu - reservation: "_CAP_TOMO_MOON_GPU" + reservation: "_CAP_TOMO_MOON_GPU2" num_nodes: 4 ntasks-per-node: 1 nproc_per_node: 4 diff --git a/orchestration/flows/bl832/config.py b/orchestration/flows/bl832/config.py index 281ba167..e338e5a0 100644 --- a/orchestration/flows/bl832/config.py +++ b/orchestration/flows/bl832/config.py @@ -31,6 +31,8 @@ def _beam_specific_config(self) -> None: self.scicat = self.config["scicat"] # MLflow self.mlflow = self.config["mlflow"]["amsc"] + # NERSC resources + self.nersc_resources = self.config["nersc_resources"] # NERSC HPC submission settings self.ghcr_images832 = self.config["ghcr_images832"] self.nersc_recon_settings = self.config["hpc_submission_settings832"]["nersc_reconstruction"] From b333e8ef35862fd0b707630ef12a52b5bc9a70e9 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Wed, 20 May 2026 14:38:30 -0700 Subject: [PATCH 41/42] Updating nersc.py to pull iri/sf-api parameters from the config, rather than a global variable --- orchestration/_tests/test_bl832/test_nersc.py | 37 +++++++++--- orchestration/flows/bl832/nersc.py | 58 +++++++------------ 2 files changed, 51 insertions(+), 44 deletions(-) diff --git a/orchestration/_tests/test_bl832/test_nersc.py b/orchestration/_tests/test_bl832/test_nersc.py index c5b9163f..b8de0129 100644 --- a/orchestration/_tests/test_bl832/test_nersc.py +++ b/orchestration/_tests/test_bl832/test_nersc.py @@ -75,6 +75,27 @@ def mock_config832(mocker): ep.root_path = f"/mock/{attr}" setattr(mock_config, attr, ep) + mock_config.nersc_resources = { + "iri": { + "api_base_url": "https://mock-iri.nersc.gov", + "perlmutter_compute": "mock-perlmutter-compute-uuid", + "perlmutter_login": "mock-perlmutter-login-uuid", + "perlmutter_realtime": "mock-perlmutter-realtime-uuid", + "perlmutter_job_submit": "mock-perlmutter-job-submit-uuid", + "compute_resource": "compute", + "scratch": "mock-scratch-uuid", + "homes": "mock-homes-uuid", + "common": "mock-common-uuid", + "cfs": "mock-cfs-uuid", + "archive": "mock-archive-uuid", + "globus": "mock-globus-uuid", + "dtns": "mock-dtns-uuid", + }, + "sfapi": { + "api_base_url": "https://mock-sfapi.nersc.gov", + }, + } + mock_config.nersc_recon_settings = { "qos": "realtime", "account": "mock_account", @@ -458,7 +479,6 @@ def test_reconstruct_sfapi_submission_failure(mocker, mock_sfapi_client, mock_co def test_reconstruct_iriapi_success(mocker, mock_iriapi_client, mock_config832, monkeypatch): """IRIAPI reconstruct POSTs a job and polls for COMPLETED state.""" from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod - from orchestration.flows.bl832.nersc import RESOURCE_IDS, _IRI_COMPUTE_RESOURCE monkeypatch.setenv("NERSC_USERNAME", "alsdev") mocker.patch("orchestration.flows.bl832.nersc.time.sleep") @@ -471,24 +491,27 @@ def test_reconstruct_iriapi_success(mocker, mock_iriapi_client, mock_config832, result = controller.reconstruct(file_path="folder/scan.h5") + # Resource lookups now live on the config, not module-level constants + iri = mock_config832.nersc_resources["iri"] + assert result["success"] is True assert result["job_id"] == "99999" mock_iriapi_client.post.assert_called_once() assert ( mock_iriapi_client.post.call_args.args[0] - == f"/api/v1/compute/job/{RESOURCE_IDS['perlmutter_job_submit']}" + == f"/api/v1/compute/job/{iri['perlmutter_job_submit']}" ) posted_json = mock_iriapi_client.post.call_args.kwargs["json"] assert posted_json["executable"] == "/bin/bash" - assert posted_json["arguments"] == ["-s"] # matches nersc.py - assert "pre_launch" in posted_json # script body lives here - assert "tomo_recon" in posted_json["pre_launch"] # sanity check it's the right script + assert posted_json["arguments"] == ["-s"] + assert "pre_launch" in posted_json + assert "tomo_recon" in posted_json["pre_launch"] mock_iriapi_client.get.assert_any_call( - f"/api/v1/compute/status/{_IRI_COMPUTE_RESOURCE}/99999" + f"/api/v1/compute/status/{iri['compute_resource']}/99999" ) mock_iriapi_client.get.assert_any_call( - f"/api/v1/filesystem/view/{RESOURCE_IDS['perlmutter_login']}", + f"/api/v1/filesystem/view/{iri['perlmutter_login']}", params={"path": mocker.ANY}, ) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index cdd90dea..036608af 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -35,33 +35,8 @@ load_dotenv() # Applies only to NERSCLoginMethod.IRIAPI -_IRIAPI_GLOBUS_CLIENT_ID_ENV: str = "GLOBUS_CLIENT_ID" -_IRI_COMPUTE_RESOURCE: str = "compute" _IRIAPI_TOKEN_FILE_ENV: str = "PATH_GLOBUS_TOKEN_FILE" -_API_BASE_URLS: dict[NERSCLoginMethod, str] = { - NERSCLoginMethod.SFAPI: "https://api.nersc.gov/api/v1.2", - NERSCLoginMethod.IRIAPI: "https://api.iri.nersc.gov", -} - -# NERSC resource IDs (from status/resources endpoint) -RESOURCE_IDS = { - # Perlmutter compute - "perlmutter_compute": "94351904-6dba-4c16-b5cd-fbd280d8615b", - "perlmutter_login": "e525a224-61c1-419f-9642-91168c792e39", - "perlmutter_realtime": "3776417d-747c-4753-895a-6323c17b9c98", - "perlmutter_job_submit": "3cf3c048-855e-4dd8-a189-065a483954bb", - # Storage - "scratch": "43d8f6c0-f900-48ce-b267-73714103f4ac", - "homes": "65b28619-c3b6-4942-8da1-044a3b3a2a9e", - "common": "7e07a611-f927-4a39-a44d-b1d6e307accd", - "cfs": "59e80c79-4dfd-4c53-9c07-7405685fcd37", - "archive": "f4916c65-9001-49c2-b0bf-6fe4276b564c", - # Services - "globus": "0a207df3-4bec-45b8-9060-13505d269da9", - "dtns": "a762cbdc-af7a-4b2b-9463-67f0189dd2ae", -} - @dataclass class SegmentationModelSpec: @@ -191,9 +166,16 @@ def __init__( TomographyHPCController.__init__(self, config) self.client = client self.login_method = login_method + if login_method is NERSCLoginMethod.IRIAPI: + self.nersc_resources: dict[str, str] = config.nersc_resources["iri"] + elif login_method is NERSCLoginMethod.SFAPI: + self.nersc_resources = config.nersc_resources["sfapi"] + else: + raise ValueError(f"Unsupported NERSCLoginMethod: {login_method}") @staticmethod def create_nersc_client( + config: Config832, login_method: NERSCLoginMethod = NERSCLoginMethod.IRIAPI, ) -> Client | httpx.Client: """Create and return a NERSC client for the requested login method. @@ -209,8 +191,9 @@ def create_nersc_client( file path, or rely on the default (``~/.globus/auth_tokens.json``). Args: + config: Config832 instance for accessing config settings needed during client creation. login_method: Which NERSC API to authenticate against. - Defaults to :attr:`NERSCLoginMethod.SFAPI`. + Defaults to :attr:`NERSCLoginMethod.IRIAPI`. Returns: An authenticated :class:`sfapi_client.Client` instance. @@ -222,32 +205,33 @@ def create_nersc_client( Exception: If the underlying client construction fails. """ logger.info(f"Creating NERSC client using login method: {login_method.value}") - api_url = _API_BASE_URLS[login_method] - logger.info(f"Targeting API base URL: {api_url}") if login_method is NERSCLoginMethod.SFAPI: + api_base_url = config.nersc_resources["sfapi"]["api_base_url"] client = NERSCTomographyHPCController._create_sfapi_client() elif login_method is NERSCLoginMethod.IRIAPI: - client = NERSCTomographyHPCController._create_iriapi_client() - + api_base_url = config.nersc_resources["iri"]["api_base_url"] + client = NERSCTomographyHPCController._create_iriapi_client(api_base_url) else: raise ValueError(f"Unhandled NERSCLoginMethod: {login_method}") logger.info( f"NERSC client created successfully " - f"(method={login_method.value}, api_url={api_url})." + f"(method={login_method.value}, api_url={api_base_url})." ) return client @staticmethod - def _create_iriapi_client() -> httpx.Client: + def _create_iriapi_client(api_base_url: str) -> httpx.Client: """Create a NERSC client for the IRI API using a Globus bearer token. Requires ``GLOBUS_CLIENT_ID`` and ``GLOBUS_CLIENT_SECRET`` in the environment. Reuses a cached token if valid; otherwise mints a new one via the client credentials grant. No browser or user interaction. + Parameters: + api_base_url: The base URL for the NERSC IRI API Returns: An authenticated :class:`httpx.Client` targeting the IRI API. @@ -265,7 +249,7 @@ def _create_iriapi_client() -> httpx.Client: ) return httpx.Client( - base_url=_API_BASE_URLS[NERSCLoginMethod.IRIAPI], + base_url=api_base_url, headers={"Authorization": f"Bearer {access_token}"}, timeout=httpx.Timeout(connect=10.0, read=120.0, write=30.0, pool=10.0), ) @@ -450,7 +434,7 @@ def _submit_job(self, job_script: str, num_nodes: int = 1) -> str: job_spec["stderr_path"] = sbatch_values["stderr_path"] response = self.client.post( - f"/api/v1/compute/job/{RESOURCE_IDS['perlmutter_job_submit']}", + f"/api/v1/compute/job/{self.nersc_resources['perlmutter_job_submit']}", json=job_spec, ) if not response.is_success: @@ -483,7 +467,7 @@ def _wait_for_job(self, job_id: str) -> bool: elif self.login_method is NERSCLoginMethod.IRIAPI: while True: response = self.client.get( - f"/api/v1/compute/status/{_IRI_COMPUTE_RESOURCE}/{job_id}" # ← was "perlmutter" + f"/api/v1/compute/status/{self.nersc_resources['compute_resource']}/{job_id}" ) response.raise_for_status() state = response.json().get("status", {}).get("state") @@ -509,7 +493,7 @@ def _mkdir_remote(self, path: str) -> None: perlmutter.run(f"mkdir -p {path}") elif self.login_method is NERSCLoginMethod.IRIAPI: response = self.client.post( - f"/api/v1/filesystem/mkdir/{RESOURCE_IDS['perlmutter_login']}", + f"/api/v1/filesystem/mkdir/{self.nersc_resources['perlmutter_login']}", json={"path": path, "parents": True}, ) response.raise_for_status() @@ -538,7 +522,7 @@ def _read_remote_file(self, path: str) -> str: elif self.login_method is NERSCLoginMethod.IRIAPI: response = self.client.get( - f"/api/v1/filesystem/view/{RESOURCE_IDS['perlmutter_login']}", + f"/api/v1/filesystem/view/{self.nersc_resources['perlmutter_login']}", params={"path": path}, ) response.raise_for_status() From b7a0fa357a3c5e9be0d78f85dd75c030aefdbeb4 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Wed, 20 May 2026 14:42:33 -0700 Subject: [PATCH 42/42] removing redundant logging setLevel --- orchestration/flows/bl832/register_mlflow.py | 1 - 1 file changed, 1 deletion(-) diff --git a/orchestration/flows/bl832/register_mlflow.py b/orchestration/flows/bl832/register_mlflow.py index 93603295..c224c769 100644 --- a/orchestration/flows/bl832/register_mlflow.py +++ b/orchestration/flows/bl832/register_mlflow.py @@ -5,7 +5,6 @@ from orchestration.mlflow import register_checkpoint logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)-7s | %(name)s - %(message)s")