Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions src/cloudai/cli/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,8 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int:
agent.update_policy(feedback)
logging.info(f"Step {step}: Observation: {[round(obs, 4) for obs in observation]}, Reward: {reward:.4f}")

if args.mode == "run":
runner.runner.test_scenario.test_runs = original_test_runs
generate_reports(runner.runner.system, runner.runner.test_scenario, runner.runner.scenario_root)
runner.runner.test_scenario.test_runs = original_test_runs
generate_reports(runner.runner.system, runner.runner.test_scenario, runner.runner.scenario_root)

logging.info("All jobs are complete.")
return err
Expand Down
39 changes: 37 additions & 2 deletions tests/test_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@

class StubAgentConfig(BaseAgentConfig):
knob: int = 0
max_steps: int = 0
start_step: int = 0
action: dict[str, Any] = Field(default_factory=dict)
payload: dict[str, Any] = Field(default_factory=dict)


Expand All @@ -48,7 +51,8 @@ class StubAgent(BaseAgent):
def __init__(self, env, config: StubAgentConfig):
self.env = env
self.config = config
self.max_steps = 0
self._next_step = 0
self.max_steps = config.max_steps
StubAgent.received_configs.append(config)

@staticmethod
Expand All @@ -59,7 +63,12 @@ def configure(self, config: dict[str, Any]) -> None:
raise NotImplementedError

def select_action(self) -> tuple[int, dict[str, Any]]:
raise NotImplementedError
if self._next_step >= self.config.max_steps:
return None # type: ignore

step = self.config.start_step + self._next_step
self._next_step += 1
return step, self.config.action or {}

def update_policy(self, _feedback: dict[str, Any]) -> None:
return
Expand Down Expand Up @@ -207,3 +216,29 @@ def _job_output_path(tr: TestRun, create: bool = True):
pd.testing.assert_frame_equal(actual_trajectory, expected_trajectory)

assert [tr.step for tr in reporter.trs] == [1, 3]


def test_dse_dry_run_reports(slurm_system: SlurmSystem, dse_tr: TestRun, tmp_path: Any, stub_agent_name: str):
dse_tr.test.agent = stub_agent_name
dse_tr.iterations = 1
dse_tr.test.agent_config = {
"max_steps": 1,
"start_step": 1,
"action": {"extra_env_vars.VAR1": "value1"},
}

slurm_system.output_path = tmp_path
slurm_system.reports = {
"per_test": ReportConfig(enable=False),
"status": ReportConfig(enable=True),
"dse": ReportConfig(enable=True),
}
test_scenario = TestScenario(name="test_scenario", test_runs=[dse_tr])
runner = Runner(mode="dry-run", system=slurm_system, test_scenario=test_scenario)

assert handle_dse_job(runner, argparse.Namespace(mode="dry-run")) == 0

assert (runner.runner.scenario_root / "test_scenario.html").exists()
assert (runner.runner.scenario_root / dse_tr.name / "0" / "trajectory.csv").exists()
assert not (runner.runner.scenario_root / "test_scenario-dse-report.html").exists()
assert not (runner.runner.scenario_root / dse_tr.name / "0" / f"{dse_tr.name}.toml").exists()
Loading