Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/cloudai/_core/test_scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,11 @@ def has_more_iterations(self) -> bool:
"""
return self.current_iteration + 1 < self.iterations

def increment_step(self) -> int:
    """
    Bump the trial counter stored in ``self.step`` by one.

    Returns:
        int: The value of ``self.step`` after the increment.
    """
    advanced = self.step + 1
    self.step = advanced
    return advanced

@property
def metric_reporter(self) -> Optional[Type[ReportGenerationStrategy]]:
if not self.reports:
Expand Down
12 changes: 1 addition & 11 deletions src/cloudai/cli/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,17 +157,7 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int:

agent = agent_class(env, agent_config)

for step in range(agent.max_steps):
result = agent.select_action()
if result is None:
break
step, action = result
env.test_run.step = step
logging.info(f"Running step {step} (of {agent.max_steps}) with action {action}")
observation, reward, *_ = env.step(action)
feedback = {"trial_index": step, "value": reward}
agent.update_policy(feedback)
logging.info(f"Step {step}: Observation: {[round(obs, 4) for obs in observation]}, Reward: {reward:.4f}")
err |= agent.run()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Guard agent.run() failures to preserve DSE batch execution.

The `agent.run()` call on line 160 can propagate an exception out of `handle_dse_job` if an agent raises, which aborts the remaining test runs and bypasses the intended `err` accumulation. Convert exceptions into a non-zero return code and continue with the next test run.

Suggested patch
-        err |= agent.run()
+        try:
+            err |= agent.run()
+        except Exception:
+            logging.exception("Agent %s failed during run().", agent_type)
+            err |= 1
+            continue
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/cloudai/cli/handlers.py` at line 160, handle_dse_job currently calls
agent.run() directly which can throw and abort the whole DSE batch; wrap the
call to agent.run() in a try/except that catches any exception, sets/updates the
existing err variable to a non-zero return code (e.g., err |= 1 or err = 1) and
continues processing remaining runs instead of re-raising; ensure the catch logs
the exception (including agent identity) for debugging and references the
agent.run() call and err variable so the change is applied in the handle_dse_job
function.


if args.mode == "run":
runner.runner.test_scenario.test_runs = original_test_runs
Expand Down
27 changes: 27 additions & 0 deletions src/cloudai/configurator/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from abc import ABC, abstractmethod
from typing import Any, Dict, Literal

Expand Down Expand Up @@ -105,3 +106,29 @@ def update_policy(self, _feedback: Dict[str, Any]) -> None:
feedback (Dict[str, Any]): Feedback information from the environment.
"""
pass

def run(self) -> int:
    """
    Run this agent's default exploration loop against ``self.env``.

    At most ``self.max_steps`` trials are attempted. Each trial asks
    ``select_action`` for a ``(step, action)`` pair, forwards the action to
    ``self.env.step``, and reports the resulting reward back through
    ``update_policy``. The loop ends early as soon as ``select_action``
    returns ``None``.

    Agents that drive their own training loop (e.g. RLlib-based agents
    calling ``algo.train()``) are expected to override this method.

    Returns:
        int: Process-style return code (``0`` success, non-zero failure).
            This default implementation always returns ``0``;
            ``handle_dse_job`` accumulates the value via ``err |= agent.run()``.
    """
    for _attempt in range(self.max_steps):
        selection = self.select_action()
        if selection is None:
            # select_action returned no proposal: stop before max_steps is reached.
            break

        trial_index, chosen_action = selection
        logging.info(f"Running step {trial_index} (of {self.max_steps}) with action {chosen_action}")

        # Only the first two items of the env result are used here; the rest are ignored.
        obs, reward_value, *_ = self.env.step(chosen_action)

        feedback = {"trial_index": trial_index, "value": reward_value}
        self.update_policy(feedback)

        rounded_obs = [round(component, 4) for component in obs]
        logging.info(f"Step {trial_index}: Observation: {rounded_obs}, Reward: {reward_value:.4f}")
    return 0
1 change: 1 addition & 0 deletions src/cloudai/configurator/cloudai_gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def step(self, action: Any) -> Tuple[list, float, bool, dict]:
- done (bool): Whether the episode is done.
- info (dict): Additional info for debugging.
"""
self.test_run.increment_step()
self.test_run = self.test_run.apply_params_set(action)

cached_result = self.get_cached_trajectory_result(action)
Expand Down
14 changes: 10 additions & 4 deletions tests/test_cloudaigym.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,13 @@ def test_tr_output_path(setup_env: tuple[TestRun, BaseRunner]):
agent = GridSearchAgent(env, GridSearchAgent.get_config_class()())

_, action = agent.select_action()
env.test_run.step = 42
env.test_run.step = 41
env.step(action)

assert env.test_run.output_path.name == "42"
assert env.test_run.output_path.name == "42", (
"CloudAIGymEnv.step() must advance test_run.step before computing output_path; "
"starting at 41, step #42's artifacts must land in dir '42'."
)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -401,7 +404,7 @@ def test_cached_step_appends_trajectory_row(nemorun: NeMoRunTestDefinition, tmp_
env.test_run.current_iteration = 0
env.trajectory = {0: [TrajectoryEntry(step=1, action=cached_action, reward=0.42, observation=[0.84])]}

env.test_run.step = 5
env.test_run.step = 4
obs, reward, done, _info = env.step(cached_action)

runner.run.assert_not_called()
Expand All @@ -410,7 +413,10 @@ def test_cached_step_appends_trajectory_row(nemorun: NeMoRunTestDefinition, tmp_
assert done is False
rows = env.trajectory[0]
assert len(rows) == 2
assert rows[-1].step == 5
assert rows[-1].step == 5, (
"CloudAIGymEnv.step() advances test_run.step before recording the trajectory row; "
"the cached row must be tagged with the advanced trial index, not the pre-step value."
)
assert rows[-1].reward == 0.42
assert rows[-1].action == cached_action

Expand Down
93 changes: 91 additions & 2 deletions tests/test_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,19 @@

import argparse
from pathlib import Path
from typing import Any, ClassVar, Iterator
from typing import Any, ClassVar, Iterator, Optional
from unittest.mock import MagicMock

import pandas as pd
import pytest
from pydantic import Field

from cloudai.cli.handlers import handle_dse_job, verify_system_configs, verify_test_configs, verify_test_scenarios
from cloudai.cli.handlers import (
handle_dse_job,
verify_system_configs,
verify_test_configs,
verify_test_scenarios,
)
from cloudai.core import (
BaseAgent,
BaseAgentConfig,
Expand Down Expand Up @@ -254,3 +259,87 @@ def test_verify_test_scenarios_logs_failure_details(tmp_path: Path, caplog: pyte
assert str(broken_scenario) in caplog.text
assert "duplicate TOML key 'name'" in caplog.text
assert "1 out of 1 test scenarios have issues." in caplog.text


class CustomRunStubAgentConfig(BaseAgentConfig):
    """Config type for ``CustomRunStubAgent``; declares nothing beyond ``BaseAgentConfig``."""


class CustomRunStubAgent(BaseAgent):
    """
    Test double that overrides ``run()`` to drive its own training (e.g. RLlib-like).

    The outcome of ``run()`` is controlled through class-level attributes, so a
    test can configure and inspect it without holding the instance that
    ``handle_dse_job`` creates internally.
    """

    # Number of ``run()`` invocations recorded on the class.
    run_calls: ClassVar[int] = 0
    # Return code produced by ``run()`` when ``run_raises`` is unset.
    run_returns: ClassVar[int] = 0
    # When not ``None``, ``run()`` raises this object instead of returning.
    run_raises: ClassVar[Optional[BaseException]] = None

    def __init__(self, env, config: CustomRunStubAgentConfig):
        """Store ``env`` and ``config``; ``max_steps`` is fixed at zero."""
        self.env, self.config = env, config
        self.max_steps = 0

    @staticmethod
    def get_config_class() -> type[CustomRunStubAgentConfig]:
        """Return the config class paired with this stub."""
        return CustomRunStubAgentConfig

    def configure(self, config: dict[str, Any]) -> None:
        """Not supported by this stub."""
        raise NotImplementedError

    def select_action(self) -> tuple[int, dict[str, Any]]:  # pragma: no cover - never called
        """Fail loudly: the overridden ``run()`` never asks for an action."""
        raise AssertionError("select_action must not be called when run() is overridden")

    def update_policy(self, _feedback: dict[str, Any]) -> None:
        """Ignore feedback."""
        return None

    def run(self) -> int:
        """Count the call, then raise ``run_raises`` if set, else return ``run_returns``."""
        stub = CustomRunStubAgent
        stub.run_calls += 1
        pending_error = stub.run_raises
        if pending_error is None:
            return stub.run_returns
        raise pending_error


@pytest.fixture
def custom_run_agent_name() -> Iterator[str]:
    """
    Register ``CustomRunStubAgent`` in the ``Registry`` for the duration of a test.

    The stub's class-level state is reset both before and after the test. On
    teardown the registry entry is restored to whatever was registered under
    the same name beforehand, or removed if nothing was.

    Yields:
        str: The registry key under which the stub agent is registered.
    """

    def _reset_stub_state() -> None:
        CustomRunStubAgent.run_calls = 0
        CustomRunStubAgent.run_returns = 0
        CustomRunStubAgent.run_raises = None

    name = "test_handlers_custom_run_agent"
    registry = Registry()
    previous_agent = registry.agents_map.get(name)

    registry.update_agent(name, CustomRunStubAgent)
    _reset_stub_state()

    yield name

    _reset_stub_state()
    if previous_agent is not None:
        registry.update_agent(name, previous_agent)
    else:
        del registry.agents_map[name]


def test_handle_dse_job_invokes_agent_run(
    slurm_system: SlurmSystem,
    dse_tr: TestRun,
    custom_run_agent_name: str,
) -> None:
    """``handle_dse_job`` must hand orchestration over to ``agent.run()`` exactly once."""
    # Arrange: point the DSE test run at the stub agent.
    dse_tr.test.agent = custom_run_agent_name
    scenario = TestScenario(name="test_scenario", test_runs=[dse_tr])
    runner = Runner(mode="dry-run", system=slurm_system, test_scenario=scenario)
    cli_args = argparse.Namespace(mode="dry-run")

    # Act
    rc = handle_dse_job(runner, cli_args)

    # Assert
    assert rc == 0
    assert CustomRunStubAgent.run_calls == 1


def test_handle_dse_job_propagates_agent_run_nonzero_rc(
    slurm_system: SlurmSystem,
    dse_tr: TestRun,
    custom_run_agent_name: str,
) -> None:
    """A non-zero return code from ``agent.run()`` must reach the caller of ``handle_dse_job``."""
    # Arrange: make the stub report failure, then point the DSE test run at it.
    CustomRunStubAgent.run_returns = 1
    dse_tr.test.agent = custom_run_agent_name
    scenario = TestScenario(name="test_scenario", test_runs=[dse_tr])
    runner = Runner(mode="dry-run", system=slurm_system, test_scenario=scenario)
    cli_args = argparse.Namespace(mode="dry-run")

    # Act
    rc = handle_dse_job(runner, cli_args)

    # Assert
    assert rc == 1
    assert CustomRunStubAgent.run_calls == 1
26 changes: 26 additions & 0 deletions tests/test_test_scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,32 @@ def test_total_time_limit_with_empty_hooks():
assert result == "01:00:00"


class TestIncrementStep:
    """Checks for ``TestRun.increment_step``, the mutator for the trial counter."""

    def _make_tr(self, tdef: TestDefinition) -> TestRun:
        """Build a minimal single-node ``TestRun`` around ``tdef``."""
        return TestRun(name="incr_tr", test=tdef, num_nodes=1, nodes=[])

    def test_starts_at_zero_and_advances_to_one(self, tdef: TestDefinition) -> None:
        """A fresh run starts at step 0; the first increment returns and stores 1."""
        run = self._make_tr(tdef)
        assert run.step == 0

        returned = run.increment_step()

        assert returned == 1
        assert run.step == 1

    def test_is_monotonic_across_repeated_calls(self, tdef: TestDefinition) -> None:
        """Five consecutive increments return 1..5 and leave the counter at 5."""
        run = self._make_tr(tdef)

        observed = [run.increment_step() for _ in range(5)]

        assert observed == [1, 2, 3, 4, 5]
        assert run.step == 5

    def test_resumes_from_pre_existing_value(self, tdef: TestDefinition) -> None:
        """Recovery / batch-unroll callers may seed ``step`` to a historical value."""
        run = self._make_tr(tdef)
        run.step = 42

        returned = run.increment_step()

        assert returned == 43
        assert run.step == 43


class TestInScenario:
@pytest.mark.parametrize("missing_arg", ["test_template_name", "name", "description"])
def test_without_base(self, missing_arg: str):
Expand Down