From cc1eb0afa3ef2f8da9cd14102a148eadda1399d4 Mon Sep 17 00:00:00 2001 From: Rutayan Patro Date: Fri, 15 May 2026 16:57:04 -0400 Subject: [PATCH 1/5] feat(cli): support agents with custom training loops in handle_dse_job - Agents that set HAS_CUSTOM_TRAINING_LOOP = True drive their own training loop; handle_dse_job calls agent.train() and skips the per-step env.step loop. - New _run_custom_training_loop helper logs exceptions, returns a process-style exit code, and always invokes agent.shutdown() (when defined) in a finally block so resources are released on both success and failure paths. - CustomTrainingLoopAgent Protocol documents the opt-in contract for type checkers and IDEs. --- src/cloudai/cli/handlers.py | 40 ++++++++++- tests/test_handlers.py | 130 +++++++++++++++++++++++++++++++++++- 2 files changed, 167 insertions(+), 3 deletions(-) diff --git a/src/cloudai/cli/handlers.py b/src/cloudai/cli/handlers.py index 0284fcd9e..49f750529 100644 --- a/src/cloudai/cli/handlers.py +++ b/src/cloudai/cli/handlers.py @@ -20,7 +20,7 @@ import signal from contextlib import contextmanager from pathlib import Path -from typing import Callable, List, Optional +from typing import Callable, List, Optional, Protocol, runtime_checkable from unittest.mock import Mock import toml @@ -118,6 +118,40 @@ def prepare_installation( return installables, installer +@runtime_checkable +class CustomTrainingLoopAgent(Protocol): + """ + Agent that drives its own training loop and skips the ``handle_dse_job`` step loop. + + Set ``HAS_CUSTOM_TRAINING_LOOP = True`` on the agent class to opt in. Used by + agents (e.g. RLlib-based) whose training loops are not modelled as a sequence + of independent ``select_action`` / ``env.step`` calls. + """ + + HAS_CUSTOM_TRAINING_LOOP: bool + + def train(self) -> None: ... + + +def _has_custom_training_loop(agent: object) -> bool: + return bool(getattr(agent, "HAS_CUSTOM_TRAINING_LOOP", False)) + + +def _run_custom_training_loop(agent: CustomTrainingLoopAgent, agent_type: str) -> int: + """Drive an agent's self-contained training loop and return a process-style exit code.""" + logging.info(f"Agent {agent_type} drives its own training loop; delegating to agent.train().") + try: + agent.train() + return 0 + except Exception: + logging.exception(f"Custom training loop failed for agent {agent_type}.") + return 1 + finally: + shutdown = getattr(agent, "shutdown", None) + if callable(shutdown): + shutdown() + + def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int: registry = Registry() @@ -157,6 +191,10 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int: agent = agent_class(env, agent_config) + if _has_custom_training_loop(agent): + err |= _run_custom_training_loop(agent, agent_type) + continue + for step in range(agent.max_steps): result = agent.select_action() if result is None: diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 5124186c0..19e4b0eae 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -15,15 +15,22 @@ # limitations under the License. import argparse +import logging from pathlib import Path -from typing import Any, ClassVar, Iterator +from typing import Any, ClassVar, Iterator, Optional from unittest.mock import MagicMock import pandas as pd import pytest from pydantic import Field -from cloudai.cli.handlers import handle_dse_job, verify_system_configs, verify_test_configs, verify_test_scenarios +from cloudai.cli.handlers import ( + _run_custom_training_loop, + handle_dse_job, + verify_system_configs, + verify_test_configs, + verify_test_scenarios, +) from cloudai.core import ( BaseAgent, BaseAgentConfig, @@ -254,3 +261,122 @@ def test_verify_test_scenarios_logs_failure_details(tmp_path: Path, caplog: pyte assert str(broken_scenario) in caplog.text assert "duplicate TOML key 'name'" in caplog.text assert "1 out of 1 test scenarios have issues." in caplog.text + + +class CustomLoopStubAgentConfig(BaseAgentConfig): + pass + + +class CustomLoopStubAgent(BaseAgent): + """Stub agent that opts into the custom-training-loop dispatch path.""" + + HAS_CUSTOM_TRAINING_LOOP: ClassVar[bool] = True + + train_calls: ClassVar[int] = 0 + shutdown_calls: ClassVar[int] = 0 + train_raises: ClassVar[Optional[BaseException]] = None + + def __init__(self, env, config: CustomLoopStubAgentConfig): + self.env = env + self.config = config + self.max_steps = 0 + + @staticmethod + def get_config_class() -> type[CustomLoopStubAgentConfig]: + return CustomLoopStubAgentConfig + + def configure(self, config: dict[str, Any]) -> None: + raise NotImplementedError + + def select_action(self) -> tuple[int, dict[str, Any]]: # pragma: no cover - never called + raise AssertionError("select_action must not be called when HAS_CUSTOM_TRAINING_LOOP is True") + + def update_policy(self, _feedback: dict[str, Any]) -> None: + return + + def train(self) -> None: + CustomLoopStubAgent.train_calls += 1 + if CustomLoopStubAgent.train_raises is not None: + raise CustomLoopStubAgent.train_raises + + def shutdown(self) -> None: + CustomLoopStubAgent.shutdown_calls += 1 + + +@pytest.fixture +def custom_loop_agent_name() -> Iterator[str]: + registry = Registry() + agent_name = "test_handlers_custom_loop_agent" + old_agent = registry.agents_map.get(agent_name) + registry.update_agent(agent_name, CustomLoopStubAgent) + CustomLoopStubAgent.train_calls = 0 + CustomLoopStubAgent.shutdown_calls = 0 + CustomLoopStubAgent.train_raises = None + yield agent_name + CustomLoopStubAgent.train_calls = 0 + CustomLoopStubAgent.shutdown_calls = 0 + CustomLoopStubAgent.train_raises = None + if old_agent is None: + del registry.agents_map[agent_name] + else: + registry.update_agent(agent_name, old_agent) + + +def test_run_custom_training_loop_calls_train_and_shutdown() -> None: + agent = MagicMock() + agent.train = MagicMock() + agent.shutdown = MagicMock() + + assert _run_custom_training_loop(agent, "mock_agent") == 0 + agent.train.assert_called_once_with() + agent.shutdown.assert_called_once_with() + + +def test_run_custom_training_loop_returns_error_and_still_shuts_down( + caplog: pytest.LogCaptureFixture, +) -> None: + agent = MagicMock() + agent.train = MagicMock(side_effect=RuntimeError("boom")) + agent.shutdown = MagicMock() + + with caplog.at_level(logging.ERROR): + assert _run_custom_training_loop(agent, "mock_agent") == 1 + + agent.shutdown.assert_called_once_with() + assert "boom" in caplog.text + + +def test_run_custom_training_loop_tolerates_missing_shutdown() -> None: + agent = MagicMock(spec=["train"]) + agent.train = MagicMock() + + assert _run_custom_training_loop(agent, "mock_agent") == 0 + agent.train.assert_called_once_with() + + +def test_handle_dse_job_dispatches_to_custom_training_loop( + slurm_system: SlurmSystem, + dse_tr: TestRun, + custom_loop_agent_name: str, +) -> None: + dse_tr.test.agent = custom_loop_agent_name + test_scenario = TestScenario(name="test_scenario", test_runs=[dse_tr]) + runner = Runner(mode="dry-run", system=slurm_system, test_scenario=test_scenario) + + assert handle_dse_job(runner, argparse.Namespace(mode="dry-run")) == 0 + assert CustomLoopStubAgent.train_calls == 1 + assert CustomLoopStubAgent.shutdown_calls == 1 + + +def test_handle_dse_job_propagates_custom_loop_failure( + slurm_system: SlurmSystem, + dse_tr: TestRun, + custom_loop_agent_name: str, +) -> None: + CustomLoopStubAgent.train_raises = RuntimeError("training blew up") + dse_tr.test.agent = custom_loop_agent_name + test_scenario = TestScenario(name="test_scenario", test_runs=[dse_tr]) + runner = Runner(mode="dry-run", system=slurm_system, test_scenario=test_scenario) + + assert handle_dse_job(runner, argparse.Namespace(mode="dry-run")) == 1 + assert CustomLoopStubAgent.shutdown_calls == 1 From ede2ae5a7af2b9a9d4763cbcaa79f56b7afaa343 Mon Sep 17 00:00:00 2001 From: Rutayan Patro Date: Fri, 15 May 2026 18:30:15 -0400 Subject: [PATCH 2/5] fix(cli): narrow agent type via TypeGuard in custom-loop dispatch Pyright rejected calling _run_custom_training_loop(agent, ...) because the plain bool predicate did not narrow agent's static type from BaseAgent to CustomTrainingLoopAgent. Return TypeGuard[CustomTrainingLoopAgent] from _has_custom_training_loop so the truthy branch in handle_dse_job sees the opted-in shape and the helper can call agent.train() directly. --- src/cloudai/cli/handlers.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/cloudai/cli/handlers.py b/src/cloudai/cli/handlers.py index 49f750529..7d8c33689 100644 --- a/src/cloudai/cli/handlers.py +++ b/src/cloudai/cli/handlers.py @@ -20,7 +20,7 @@ import signal from contextlib import contextmanager from pathlib import Path -from typing import Callable, List, Optional, Protocol, runtime_checkable +from typing import Callable, List, Optional, Protocol, TypeGuard, runtime_checkable from unittest.mock import Mock import toml @@ -133,7 +133,15 @@ class CustomTrainingLoopAgent(Protocol): def train(self) -> None: ... -def _has_custom_training_loop(agent: object) -> bool: +def _has_custom_training_loop(agent: object) -> TypeGuard[CustomTrainingLoopAgent]: + """ + Narrow ``agent`` to :class:`CustomTrainingLoopAgent` when it opts into the dispatch path. + + Returning :class:`TypeGuard` (instead of plain ``bool``) lets the type checker + treat this predicate like ``isinstance``: callers inside the truthy branch see + ``agent`` as a :class:`CustomTrainingLoopAgent`, so ``agent.train()`` type-checks + without ``getattr`` or ``cast``. + """ return bool(getattr(agent, "HAS_CUSTOM_TRAINING_LOOP", False)) From 9552e5a5ff0bf821136c3a96f00ff2cc4c7faef1 Mon Sep 17 00:00:00 2001 From: Rutayan Patro Date: Mon, 18 May 2026 12:21:47 -0400 Subject: [PATCH 3/5] review: isolate shutdown() failures from the exit-code contract If agent.shutdown() raised from the finally block, Python suppressed the earlier return 0/1 from agent.train() and propagated the exception, breaking the outer test-run loop in handle_dse_job (skipped remaining scenarios, failed to accumulate err |= rc). Wrap shutdown() in its own try/except, log via logging.exception, set rc = 1, and return rc after finally so the helper always honours the (int) -> int contract. Adds tests for shutdown-only failure and combined train+shutdown failure. --- src/cloudai/cli/handlers.py | 20 ++++++++++++++++---- tests/test_handlers.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 4 deletions(-) diff --git a/src/cloudai/cli/handlers.py b/src/cloudai/cli/handlers.py index 7d8c33689..5f862f270 100644 --- a/src/cloudai/cli/handlers.py +++ b/src/cloudai/cli/handlers.py @@ -146,18 +146,30 @@ def _has_custom_training_loop(agent: object) -> TypeGuard[CustomTrainingLoopAgen def _run_custom_training_loop(agent: CustomTrainingLoopAgent, agent_type: str) -> int: - """Drive an agent's self-contained training loop and return a process-style exit code.""" + """ + Drive an agent's self-contained training loop and return a process-style exit code. + + ``shutdown()`` runs inside its own ``try/except`` so a faulty teardown cannot + suppress the exit code from ``train()`` nor propagate out of this helper: + ``handle_dse_job`` relies on the returned ``rc`` to accumulate ``err |= rc`` + and continue with the remaining test runs. + """ logging.info(f"Agent {agent_type} drives its own training loop; delegating to agent.train().") + rc = 0 try: agent.train() - return 0 except Exception: logging.exception(f"Custom training loop failed for agent {agent_type}.") - return 1 + rc = 1 finally: shutdown = getattr(agent, "shutdown", None) if callable(shutdown): - shutdown() + try: + shutdown() + except Exception: + logging.exception(f"Shutdown failed for agent {agent_type}.") + rc = 1 + return rc def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int: diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 19e4b0eae..fec9f2eff 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -354,6 +354,38 @@ def test_run_custom_training_loop_tolerates_missing_shutdown() -> None: agent.train.assert_called_once_with() +def test_run_custom_training_loop_reports_shutdown_failure( + caplog: pytest.LogCaptureFixture, +) -> None: + """shutdown() raising must not suppress the exit code or propagate the exception.""" + agent = MagicMock() + agent.train = MagicMock() + agent.shutdown = MagicMock(side_effect=RuntimeError("teardown blew up")) + + with caplog.at_level(logging.ERROR): + assert _run_custom_training_loop(agent, "mock_agent") == 1 + + agent.train.assert_called_once_with() + agent.shutdown.assert_called_once_with() + assert "teardown blew up" in caplog.text + + +def test_run_custom_training_loop_reports_combined_train_and_shutdown_failures( + caplog: pytest.LogCaptureFixture, +) -> None: + """When both train() and shutdown() raise, the helper still returns 1 and logs both.""" + agent = MagicMock() + agent.train = MagicMock(side_effect=RuntimeError("training boom")) + agent.shutdown = MagicMock(side_effect=RuntimeError("teardown boom")) + + with caplog.at_level(logging.ERROR): + assert _run_custom_training_loop(agent, "mock_agent") == 1 + + agent.shutdown.assert_called_once_with() + assert "training boom" in caplog.text + assert "teardown boom" in caplog.text + + def test_handle_dse_job_dispatches_to_custom_training_loop( slurm_system: SlurmSystem, dse_tr: TestRun, From 7cebf90d6932e474cf5726fff6179e5ad211713b Mon Sep 17 00:00:00 2001 From: Rutayan Patro Date: Fri, 22 May 2026 20:12:01 -0400 Subject: [PATCH 4/5] feat(core): TestRun owns trial counter; CloudAIGymEnv.step() is sole mutator MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously, ``test_run.step`` had no clear owner: the dispatcher set it from outside, the adapter rewound it on ``reset()``, and other callers wrote to it ad hoc. In RLlib custom-loop runs this collapsed every trial onto ``step=1``, overwriting ``trajectory.csv`` and ``env.csv`` rows. Centralize the invariant: ``TestRun.increment_step()`` is the single named mutator, and ``CloudAIGymEnv.step()`` is its only caller. One ``env.step()`` call advances the trial counter by exactly one — independent of any episode or dispatcher concept above the gym env. Contract tests in ``TestIncrementStep`` cover the API; ``test_cloudaigym`` asserts ``step`` is advanced *before* ``output_path`` and trajectory rows are computed, so cached and live trials both record the post-increment value. --- src/cloudai/_core/test_scenario.py | 5 +++++ src/cloudai/configurator/cloudai_gym.py | 1 + tests/test_cloudaigym.py | 14 +++++++++---- tests/test_test_scenario.py | 26 +++++++++++++++++++++++++ 4 files changed, 42 insertions(+), 4 deletions(-) diff --git a/src/cloudai/_core/test_scenario.py b/src/cloudai/_core/test_scenario.py index 4c768158d..4556c2168 100644 --- a/src/cloudai/_core/test_scenario.py +++ b/src/cloudai/_core/test_scenario.py @@ -109,6 +109,11 @@ def has_more_iterations(self) -> bool: """ return self.current_iteration + 1 < self.iterations + def increment_step(self) -> int: + """Advance the trial counter and return the new value.""" + self.step += 1 + return self.step + @property def metric_reporter(self) -> Optional[Type[ReportGenerationStrategy]]: if not self.reports: diff --git a/src/cloudai/configurator/cloudai_gym.py b/src/cloudai/configurator/cloudai_gym.py index d1bdba1f1..258a47f3c 100644 --- a/src/cloudai/configurator/cloudai_gym.py +++ b/src/cloudai/configurator/cloudai_gym.py @@ -118,6 +118,7 @@ def step(self, action: Any) -> Tuple[list, float, bool, dict]: - done (bool): Whether the episode is done. - info (dict): Additional info for debugging. """ + self.test_run.increment_step() self.test_run = self.test_run.apply_params_set(action) cached_result = self.get_cached_trajectory_result(action) diff --git a/tests/test_cloudaigym.py b/tests/test_cloudaigym.py index ecb9eb0a5..995eafc0e 100644 --- a/tests/test_cloudaigym.py +++ b/tests/test_cloudaigym.py @@ -152,10 +152,13 @@ def test_tr_output_path(setup_env: tuple[TestRun, BaseRunner]): agent = GridSearchAgent(env, GridSearchAgent.get_config_class()()) _, action = agent.select_action() - env.test_run.step = 42 + env.test_run.step = 41 env.step(action) - assert env.test_run.output_path.name == "42" + assert env.test_run.output_path.name == "42", ( + "CloudAIGymEnv.step() must advance test_run.step before computing output_path; " + "starting at 41, step #42's artifacts must land in dir '42'." + ) @pytest.mark.parametrize( @@ -401,7 +404,7 @@ def test_cached_step_appends_trajectory_row(nemorun: NeMoRunTestDefinition, tmp_ env.test_run.current_iteration = 0 env.trajectory = {0: [TrajectoryEntry(step=1, action=cached_action, reward=0.42, observation=[0.84])]} - env.test_run.step = 5 + env.test_run.step = 4 obs, reward, done, _info = env.step(cached_action) runner.run.assert_not_called() @@ -410,7 +413,10 @@ def test_cached_step_appends_trajectory_row(nemorun: NeMoRunTestDefinition, tmp_ assert done is False rows = env.trajectory[0] assert len(rows) == 2 - assert rows[-1].step == 5 + assert rows[-1].step == 5, ( + "CloudAIGymEnv.step() advances test_run.step before recording the trajectory row; " + "the cached row must be tagged with the advanced trial index, not the pre-step value." + ) assert rows[-1].reward == 0.42 assert rows[-1].action == cached_action diff --git a/tests/test_test_scenario.py b/tests/test_test_scenario.py index a1ba44dce..220da69e7 100644 --- a/tests/test_test_scenario.py +++ b/tests/test_test_scenario.py @@ -283,6 +283,32 @@ def test_total_time_limit_with_empty_hooks(): assert result == "01:00:00" +class TestIncrementStep: + """``TestRun.increment_step`` is the single mutator for the trial counter.""" + + def _make_tr(self, tdef: TestDefinition) -> TestRun: + return TestRun(name="incr_tr", test=tdef, num_nodes=1, nodes=[]) + + def test_starts_at_zero_and_advances_to_one(self, tdef: TestDefinition) -> None: + tr = self._make_tr(tdef) + assert tr.step == 0 + assert tr.increment_step() == 1 + assert tr.step == 1 + + def test_is_monotonic_across_repeated_calls(self, tdef: TestDefinition) -> None: + tr = self._make_tr(tdef) + seen = [tr.increment_step() for _ in range(5)] + assert seen == [1, 2, 3, 4, 5] + assert tr.step == 5 + + def test_resumes_from_pre_existing_value(self, tdef: TestDefinition) -> None: + """Recovery / batch-unroll callers may seed ``step`` to a historical value.""" + tr = self._make_tr(tdef) + tr.step = 42 + assert tr.increment_step() == 43 + assert tr.step == 43 + + class TestInScenario: @pytest.mark.parametrize("missing_arg", ["test_template_name", "name", "description"]) def test_without_base(self, missing_arg: str): From a1a268ab225069c11388f99abf0310ec91dd7495 Mon Sep 17 00:00:00 2001 From: Rutayan Patro Date: Fri, 22 May 2026 20:14:11 -0400 Subject: [PATCH 5/5] refactor(cli): collapse custom-loop dispatch to BaseAgent.run() polymorphism MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Earlier commits in this PR introduced ``HAS_CUSTOM_TRAINING_LOOP`` + a ``CustomTrainingLoopAgent`` Protocol + a TypeGuard helper + an ``if/else`` in ``handle_dse_job`` to switch between the cloudai step loop and an agent-owned ``train()`` loop. That is a type-tagged conditional dispatching on agent identity — the textbook signal to replace conditional with polymorphism (Fowler). Add a default ``BaseAgent.run() -> int`` that holds the step-loop body (``select_action`` / ``env.step`` / ``update_policy`` per trial). Agents that drive their own training (RLlib, etc.) override ``run()`` to delegate to whatever loop they own and return a process-style exit code. ``handle_dse_job`` collapses to ``err |= agent.run()`` — one line, no branching, no Protocol vocabulary. The handler no longer knows that "custom training loops" exist as a category; that's an agent implementation detail. Net: -89 lines on cloudai. Surface area shrinks (no Protocol, no TypeGuard, no flag). ``test_handlers`` replaces the 5 helper unit tests + 2 dispatcher integration tests with 2 polymorphic tests asserting ``handle_dse_job`` delegates to ``agent.run()`` and propagates its return code. --- src/cloudai/cli/handlers.py | 72 +------------ src/cloudai/configurator/base_agent.py | 27 +++++ tests/test_handlers.py | 139 +++++++------------------ 3 files changed, 64 insertions(+), 174 deletions(-) diff --git a/src/cloudai/cli/handlers.py b/src/cloudai/cli/handlers.py index 5f862f270..745a700ce 100644 --- a/src/cloudai/cli/handlers.py +++ b/src/cloudai/cli/handlers.py @@ -20,7 +20,7 @@ import signal from contextlib import contextmanager from pathlib import Path -from typing import Callable, List, Optional, Protocol, TypeGuard, runtime_checkable +from typing import Callable, List, Optional from unittest.mock import Mock import toml @@ -118,60 +118,6 @@ def prepare_installation( return installables, installer -@runtime_checkable -class CustomTrainingLoopAgent(Protocol): - """ - Agent that drives its own training loop and skips the ``handle_dse_job`` step loop. - - Set ``HAS_CUSTOM_TRAINING_LOOP = True`` on the agent class to opt in. Used by - agents (e.g. RLlib-based) whose training loops are not modelled as a sequence - of independent ``select_action`` / ``env.step`` calls. - """ - - HAS_CUSTOM_TRAINING_LOOP: bool - - def train(self) -> None: ... - - -def _has_custom_training_loop(agent: object) -> TypeGuard[CustomTrainingLoopAgent]: - """ - Narrow ``agent`` to :class:`CustomTrainingLoopAgent` when it opts into the dispatch path. - - Returning :class:`TypeGuard` (instead of plain ``bool``) lets the type checker - treat this predicate like ``isinstance``: callers inside the truthy branch see - ``agent`` as a :class:`CustomTrainingLoopAgent`, so ``agent.train()`` type-checks - without ``getattr`` or ``cast``. - """ - return bool(getattr(agent, "HAS_CUSTOM_TRAINING_LOOP", False)) - - -def _run_custom_training_loop(agent: CustomTrainingLoopAgent, agent_type: str) -> int: - """ - Drive an agent's self-contained training loop and return a process-style exit code. - - ``shutdown()`` runs inside its own ``try/except`` so a faulty teardown cannot - suppress the exit code from ``train()`` nor propagate out of this helper: - ``handle_dse_job`` relies on the returned ``rc`` to accumulate ``err |= rc`` - and continue with the remaining test runs. - """ - logging.info(f"Agent {agent_type} drives its own training loop; delegating to agent.train().") - rc = 0 - try: - agent.train() - except Exception: - logging.exception(f"Custom training loop failed for agent {agent_type}.") - rc = 1 - finally: - shutdown = getattr(agent, "shutdown", None) - if callable(shutdown): - try: - shutdown() - except Exception: - logging.exception(f"Shutdown failed for agent {agent_type}.") - rc = 1 - return rc - - def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int: registry = Registry() @@ -211,21 +157,7 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int: agent = agent_class(env, agent_config) - if _has_custom_training_loop(agent): - err |= _run_custom_training_loop(agent, agent_type) - continue - - for step in range(agent.max_steps): - result = agent.select_action() - if result is None: - break - step, action = result - env.test_run.step = step - logging.info(f"Running step {step} (of {agent.max_steps}) with action {action}") - observation, reward, *_ = env.step(action) - feedback = {"trial_index": step, "value": reward} - agent.update_policy(feedback) - logging.info(f"Step {step}: Observation: {[round(obs, 4) for obs in observation]}, Reward: {reward:.4f}") + err |= agent.run() if args.mode == "run": runner.runner.test_scenario.test_runs = original_test_runs diff --git a/src/cloudai/configurator/base_agent.py b/src/cloudai/configurator/base_agent.py index f7fafbd99..e9dc5c2f3 100644 --- a/src/cloudai/configurator/base_agent.py +++ b/src/cloudai/configurator/base_agent.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging from abc import ABC, abstractmethod from typing import Any, Dict, Literal @@ -105,3 +106,29 @@ def update_policy(self, _feedback: Dict[str, Any]) -> None: feedback (Dict[str, Any]): Feedback information from the environment. """ pass + + def run(self) -> int: + """ + Orchestrate this agent's exploration over ``self.env``. + + Default: a step loop driven by the dispatcher (``select_action`` → + ``env.step`` → ``update_policy`` per trial). Agents that drive their + own training loop (e.g. RLlib-based agents calling ``algo.train()``) + override this method. + + Returns: + int: Process-style return code (``0`` success, non-zero failure). + ``handle_dse_job`` accumulates this via ``err |= agent.run()``. + """ + for _ in range(self.max_steps): + result = self.select_action() + if result is None: + break + step, action = result + logging.info(f"Running step {step} (of {self.max_steps}) with action {action}") + observation, reward, *_ = self.env.step(action) + self.update_policy({"trial_index": step, "value": reward}) + logging.info( + f"Step {step}: Observation: {[round(obs, 4) for obs in observation]}, Reward: {reward:.4f}" + ) + return 0 diff --git a/tests/test_handlers.py b/tests/test_handlers.py index fec9f2eff..174dfa9a5 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -15,7 +15,6 @@ # limitations under the License. import argparse -import logging from pathlib import Path from typing import Any, ClassVar, Iterator, Optional from unittest.mock import MagicMock @@ -25,7 +24,6 @@ from pydantic import Field from cloudai.cli.handlers import ( - _run_custom_training_loop, handle_dse_job, verify_system_configs, verify_test_configs, @@ -263,152 +261,85 @@ def test_verify_test_scenarios_logs_failure_details(tmp_path: Path, caplog: pyte assert "1 out of 1 test scenarios have issues." in caplog.text -class CustomLoopStubAgentConfig(BaseAgentConfig): +class CustomRunStubAgentConfig(BaseAgentConfig): pass -class CustomLoopStubAgent(BaseAgent): - """Stub agent that opts into the custom-training-loop dispatch path.""" +class CustomRunStubAgent(BaseAgent): + """Stub agent that overrides ``run()`` to drive its own training (e.g. RLlib-like).""" - HAS_CUSTOM_TRAINING_LOOP: ClassVar[bool] = True + run_calls: ClassVar[int] = 0 + run_returns: ClassVar[int] = 0 + run_raises: ClassVar[Optional[BaseException]] = None - train_calls: ClassVar[int] = 0 - shutdown_calls: ClassVar[int] = 0 - train_raises: ClassVar[Optional[BaseException]] = None - - def __init__(self, env, config: CustomLoopStubAgentConfig): + def __init__(self, env, config: CustomRunStubAgentConfig): self.env = env self.config = config self.max_steps = 0 @staticmethod - def get_config_class() -> type[CustomLoopStubAgentConfig]: - return CustomLoopStubAgentConfig + def get_config_class() -> type[CustomRunStubAgentConfig]: + return CustomRunStubAgentConfig def configure(self, config: dict[str, Any]) -> None: raise NotImplementedError def select_action(self) -> tuple[int, dict[str, Any]]: # pragma: no cover - never called - raise AssertionError("select_action must not be called when HAS_CUSTOM_TRAINING_LOOP is True") + raise AssertionError("select_action must not be called when run() is overridden") def update_policy(self, _feedback: dict[str, Any]) -> None: return - def train(self) -> None: - CustomLoopStubAgent.train_calls += 1 - if CustomLoopStubAgent.train_raises is not None: - raise CustomLoopStubAgent.train_raises - - def shutdown(self) -> None: - CustomLoopStubAgent.shutdown_calls += 1 + def run(self) -> int: + CustomRunStubAgent.run_calls += 1 + if CustomRunStubAgent.run_raises is not None: + raise CustomRunStubAgent.run_raises + return CustomRunStubAgent.run_returns @pytest.fixture -def custom_loop_agent_name() -> Iterator[str]: +def custom_run_agent_name() -> Iterator[str]: registry = Registry() - agent_name = "test_handlers_custom_loop_agent" + agent_name = "test_handlers_custom_run_agent" old_agent = registry.agents_map.get(agent_name) - registry.update_agent(agent_name, CustomLoopStubAgent) - CustomLoopStubAgent.train_calls = 0 - CustomLoopStubAgent.shutdown_calls = 0 - CustomLoopStubAgent.train_raises = None + registry.update_agent(agent_name, CustomRunStubAgent) + CustomRunStubAgent.run_calls = 0 + CustomRunStubAgent.run_returns = 0 + CustomRunStubAgent.run_raises = None yield agent_name - CustomLoopStubAgent.train_calls = 0 - CustomLoopStubAgent.shutdown_calls = 0 - CustomLoopStubAgent.train_raises = None + CustomRunStubAgent.run_calls = 0 + CustomRunStubAgent.run_returns = 0 + CustomRunStubAgent.run_raises = None if old_agent is None: del registry.agents_map[agent_name] else: registry.update_agent(agent_name, old_agent) -def test_run_custom_training_loop_calls_train_and_shutdown() -> None: - agent = MagicMock() - agent.train = MagicMock() - agent.shutdown = MagicMock() - - assert _run_custom_training_loop(agent, "mock_agent") == 0 - agent.train.assert_called_once_with() - agent.shutdown.assert_called_once_with() - - -def test_run_custom_training_loop_returns_error_and_still_shuts_down( - caplog: pytest.LogCaptureFixture, -) -> None: - agent = MagicMock() - agent.train = MagicMock(side_effect=RuntimeError("boom")) - agent.shutdown = MagicMock() - - with caplog.at_level(logging.ERROR): - assert _run_custom_training_loop(agent, "mock_agent") == 1 - - agent.shutdown.assert_called_once_with() - assert "boom" in caplog.text - - -def test_run_custom_training_loop_tolerates_missing_shutdown() -> None: - agent = MagicMock(spec=["train"]) - agent.train = MagicMock() - - assert _run_custom_training_loop(agent, "mock_agent") == 0 - agent.train.assert_called_once_with() - - -def test_run_custom_training_loop_reports_shutdown_failure( - caplog: pytest.LogCaptureFixture, -) -> None: - """shutdown() raising must not suppress the exit code or propagate the exception.""" - agent = MagicMock() - agent.train = MagicMock() - agent.shutdown = MagicMock(side_effect=RuntimeError("teardown blew up")) - - with caplog.at_level(logging.ERROR): - assert _run_custom_training_loop(agent, "mock_agent") == 1 - - agent.train.assert_called_once_with() - agent.shutdown.assert_called_once_with() - assert "teardown blew up" in caplog.text - - -def test_run_custom_training_loop_reports_combined_train_and_shutdown_failures( - caplog: pytest.LogCaptureFixture, -) -> None: - """When both train() and shutdown() raise, the helper still returns 1 and logs both.""" - agent = MagicMock() - agent.train = MagicMock(side_effect=RuntimeError("training boom")) - agent.shutdown = MagicMock(side_effect=RuntimeError("teardown boom")) - - with caplog.at_level(logging.ERROR): - assert _run_custom_training_loop(agent, "mock_agent") == 1 - - agent.shutdown.assert_called_once_with() - assert "training boom" in caplog.text - assert "teardown boom" in caplog.text - - -def test_handle_dse_job_dispatches_to_custom_training_loop( +def test_handle_dse_job_invokes_agent_run( slurm_system: SlurmSystem, dse_tr: TestRun, - custom_loop_agent_name: str, + custom_run_agent_name: str, ) -> None: - dse_tr.test.agent = custom_loop_agent_name + """``handle_dse_job`` must delegate orchestration to ``agent.run()`` (polymorphism).""" + dse_tr.test.agent = custom_run_agent_name test_scenario = TestScenario(name="test_scenario", test_runs=[dse_tr]) runner = Runner(mode="dry-run", system=slurm_system, test_scenario=test_scenario) assert handle_dse_job(runner, argparse.Namespace(mode="dry-run")) == 0 - assert CustomLoopStubAgent.train_calls == 1 - assert CustomLoopStubAgent.shutdown_calls == 1 + assert CustomRunStubAgent.run_calls == 1 -def test_handle_dse_job_propagates_custom_loop_failure( +def test_handle_dse_job_propagates_agent_run_nonzero_rc( slurm_system: SlurmSystem, dse_tr: TestRun, - custom_loop_agent_name: str, + custom_run_agent_name: str, ) -> None: - CustomLoopStubAgent.train_raises = RuntimeError("training blew up") - dse_tr.test.agent = custom_loop_agent_name + """A non-zero rc from ``agent.run()`` must flow through to the caller via ``err |= rc``.""" + CustomRunStubAgent.run_returns = 1 + dse_tr.test.agent = custom_run_agent_name test_scenario = TestScenario(name="test_scenario", test_runs=[dse_tr]) runner = Runner(mode="dry-run", system=slurm_system, test_scenario=test_scenario) assert handle_dse_job(runner, argparse.Namespace(mode="dry-run")) == 1 - assert CustomLoopStubAgent.shutdown_calls == 1 + assert CustomRunStubAgent.run_calls == 1