def increment_step(self) -> int:
    """
    Move this run on to its next trial.

    ``self.step`` is bumped by exactly one, in place.

    Returns:
        int: The value of ``self.step`` after the bump.
    """
    self.step = self.step + 1
    return self.step
def run(self) -> int:
    """
    Drive the default exploration loop against ``self.env``.

    At most ``self.max_steps`` trials are attempted. Each trial asks
    ``select_action`` for a ``(step, action)`` pair, feeds the action to
    ``self.env.step``, and reports the reward back through ``update_policy``
    under the keys ``trial_index`` and ``value``. The loop ends early as soon
    as ``select_action`` returns ``None``.

    Agents that own their training loop are expected to override this method.

    Returns:
        int: Always ``0`` here; this implementation has no failure return
        path, so exceptions raised by the agent or the environment propagate
        to the caller.
    """
    for _trial_slot in range(self.max_steps):
        selection = self.select_action()
        if selection is None:
            # Nothing left to try: same outcome as exhausting the step budget.
            return 0

        trial, chosen = selection
        logging.info(f"Running step {trial} (of {self.max_steps}) with action {chosen}")

        # Only the first two items of the env result are used; the rest are ignored.
        obs_values, score, *_extras = self.env.step(chosen)
        feedback = dict(trial_index=trial, value=score)
        self.update_policy(feedback)

        # Rounding happens after the policy update, matching the reporting order.
        rounded = [round(value, 4) for value in obs_values]
        logging.info(f"Step {trial}: Observation: {rounded}, Reward: {score:.4f}")
    return 0
""" + self.test_run.increment_step() self.test_run = self.test_run.apply_params_set(action) cached_result = self.get_cached_trajectory_result(action) diff --git a/tests/test_cloudaigym.py b/tests/test_cloudaigym.py index ecb9eb0a5..995eafc0e 100644 --- a/tests/test_cloudaigym.py +++ b/tests/test_cloudaigym.py @@ -152,10 +152,13 @@ def test_tr_output_path(setup_env: tuple[TestRun, BaseRunner]): agent = GridSearchAgent(env, GridSearchAgent.get_config_class()()) _, action = agent.select_action() - env.test_run.step = 42 + env.test_run.step = 41 env.step(action) - assert env.test_run.output_path.name == "42" + assert env.test_run.output_path.name == "42", ( + "CloudAIGymEnv.step() must advance test_run.step before computing output_path; " + "starting at 41, step #42's artifacts must land in dir '42'." + ) @pytest.mark.parametrize( @@ -401,7 +404,7 @@ def test_cached_step_appends_trajectory_row(nemorun: NeMoRunTestDefinition, tmp_ env.test_run.current_iteration = 0 env.trajectory = {0: [TrajectoryEntry(step=1, action=cached_action, reward=0.42, observation=[0.84])]} - env.test_run.step = 5 + env.test_run.step = 4 obs, reward, done, _info = env.step(cached_action) runner.run.assert_not_called() @@ -410,7 +413,10 @@ def test_cached_step_appends_trajectory_row(nemorun: NeMoRunTestDefinition, tmp_ assert done is False rows = env.trajectory[0] assert len(rows) == 2 - assert rows[-1].step == 5 + assert rows[-1].step == 5, ( + "CloudAIGymEnv.step() advances test_run.step before recording the trajectory row; " + "the cached row must be tagged with the advanced trial index, not the pre-step value." 
class CustomRunStubAgent(BaseAgent):
    """
    Test double whose ``run()`` replaces the default step loop.

    Behaviour is steered through class-level knobs so a test can set them
    before the handler instantiates the agent.
    """

    # Number of times ``run()`` has been entered.
    run_calls: ClassVar[int] = 0
    # Return code handed back by ``run()`` when nothing is raised.
    run_returns: ClassVar[int] = 0
    # When not ``None``, ``run()`` raises this instead of returning.
    run_raises: ClassVar[Optional[BaseException]] = None

    def __init__(self, env, config: CustomRunStubAgentConfig):
        # ``max_steps`` is zero: this stub never relies on a step budget.
        self.max_steps = 0
        self.config = config
        self.env = env

    @staticmethod
    def get_config_class() -> type[CustomRunStubAgentConfig]:
        """Return the config model paired with this stub."""
        return CustomRunStubAgentConfig

    def configure(self, config: dict[str, Any]) -> None:
        """Not supported by the stub."""
        raise NotImplementedError

    def select_action(self) -> tuple[int, dict[str, Any]]:  # pragma: no cover - never called
        """Fail loudly: reaching this means the overridden ``run()`` was bypassed."""
        raise AssertionError("select_action must not be called when run() is overridden")

    def update_policy(self, _feedback: dict[str, Any]) -> None:
        """Accept and discard feedback."""
        return None

    def run(self) -> int:
        """Count the call, then raise ``run_raises`` if set, else return ``run_returns``."""
        CustomRunStubAgent.run_calls += 1
        failure = CustomRunStubAgent.run_raises
        if failure is None:
            return CustomRunStubAgent.run_returns
        raise failure
class TestIncrementStep:
    """Behavioural checks for ``TestRun.increment_step``."""

    def _make_tr(self, tdef: TestDefinition) -> TestRun:
        """Build a minimal single-node ``TestRun`` around ``tdef``."""
        return TestRun(name="incr_tr", test=tdef, num_nodes=1, nodes=[])

    def test_starts_at_zero_and_advances_to_one(self, tdef: TestDefinition) -> None:
        run = self._make_tr(tdef)
        assert run.step == 0

        returned = run.increment_step()

        assert returned == 1
        assert run.step == 1

    def test_is_monotonic_across_repeated_calls(self, tdef: TestDefinition) -> None:
        run = self._make_tr(tdef)

        observed = []
        for _ in range(5):
            observed.append(run.increment_step())

        assert observed == [1, 2, 3, 4, 5]
        assert run.step == 5

    def test_resumes_from_pre_existing_value(self, tdef: TestDefinition) -> None:
        """A counter seeded to an arbitrary value continues from that value."""
        run = self._make_tr(tdef)
        run.step = 42

        returned = run.increment_step()

        assert returned == 43
        assert run.step == 43