From f00d3e9ea96a149a81d6d7e2915f723b49454235 Mon Sep 17 00:00:00 2001 From: Ivan Podkidyshev Date: Thu, 2 Apr 2026 17:59:52 +0200 Subject: [PATCH] generating reports for dse dry-run --- src/cloudai/cli/handlers.py | 5 ++--- tests/test_handlers.py | 39 +++++++++++++++++++++++++++++++++++-- 2 files changed, 39 insertions(+), 5 deletions(-) diff --git a/src/cloudai/cli/handlers.py b/src/cloudai/cli/handlers.py index 1337976c2..6507d36fc 100644 --- a/src/cloudai/cli/handlers.py +++ b/src/cloudai/cli/handlers.py @@ -163,9 +163,8 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int: agent.update_policy(feedback) logging.info(f"Step {step}: Observation: {[round(obs, 4) for obs in observation]}, Reward: {reward:.4f}") - if args.mode == "run": - runner.runner.test_scenario.test_runs = original_test_runs - generate_reports(runner.runner.system, runner.runner.test_scenario, runner.runner.scenario_root) + runner.runner.test_scenario.test_runs = original_test_runs + generate_reports(runner.runner.system, runner.runner.test_scenario, runner.runner.scenario_root) logging.info("All jobs are complete.") return err diff --git a/tests/test_handlers.py b/tests/test_handlers.py index e495162da..b67276142 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -39,6 +39,9 @@ class StubAgentConfig(BaseAgentConfig): knob: int = 0 + max_steps: int = 0 + start_step: int = 0 + action: dict[str, Any] = Field(default_factory=dict) payload: dict[str, Any] = Field(default_factory=dict) @@ -48,7 +51,8 @@ class StubAgent(BaseAgent): def __init__(self, env, config: StubAgentConfig): self.env = env self.config = config - self.max_steps = 0 + self._next_step = 0 + self.max_steps = config.max_steps StubAgent.received_configs.append(config) @staticmethod @@ -59,7 +63,12 @@ def configure(self, config: dict[str, Any]) -> None: raise NotImplementedError def select_action(self) -> tuple[int, dict[str, Any]]: - raise NotImplementedError + if self._next_step >= self.config.max_steps: + return None # type: ignore + + step = self.config.start_step + self._next_step + self._next_step += 1 + return step, self.config.action or {} def update_policy(self, _feedback: dict[str, Any]) -> None: return @@ -207,3 +216,29 @@ def _job_output_path(tr: TestRun, create: bool = True): pd.testing.assert_frame_equal(actual_trajectory, expected_trajectory) assert [tr.step for tr in reporter.trs] == [1, 3] + + +def test_dse_dry_run_reports(slurm_system: SlurmSystem, dse_tr: TestRun, tmp_path: Any, stub_agent_name: str): + dse_tr.test.agent = stub_agent_name + dse_tr.iterations = 1 + dse_tr.test.agent_config = { + "max_steps": 1, + "start_step": 1, + "action": {"extra_env_vars.VAR1": "value1"}, + } + + slurm_system.output_path = tmp_path + slurm_system.reports = { + "per_test": ReportConfig(enable=False), + "status": ReportConfig(enable=True), + "dse": ReportConfig(enable=True), + } + test_scenario = TestScenario(name="test_scenario", test_runs=[dse_tr]) + runner = Runner(mode="dry-run", system=slurm_system, test_scenario=test_scenario) + + assert handle_dse_job(runner, argparse.Namespace(mode="dry-run")) == 0 + + assert (runner.runner.scenario_root / "test_scenario.html").exists() + assert (runner.runner.scenario_root / dse_tr.name / "0" / "trajectory.csv").exists() + assert not (runner.runner.scenario_root / "test_scenario-dse-report.html").exists() + assert not (runner.runner.scenario_root / dse_tr.name / "0" / f"{dse_tr.name}.toml").exists()