Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ requires-python = ">=3.10"
"import-linter~=2.10",
"pytest-deadfixtures~=3.1",
"taplo~=0.9.3",
"gymnasium~=1.2",
]
rl = ["gymnasium~=1.2"]
docs = [
"sphinx~=8.1",
"nvidia-sphinx-theme~=0.0.8",
Expand Down
14 changes: 14 additions & 0 deletions src/cloudai/_core/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class Registry(metaclass=Singleton):
scenario_reports: ClassVar[dict[str, type[Reporter]]] = {}
report_configs: ClassVar[dict[str, ReportConfig]] = {}
reward_functions_map: ClassVar[dict[str, RewardFunction]] = {}
env_factories_map: ClassVar[dict[str, Callable]] = {}
command_gen_strategies_map: ClassVar[dict[tuple[Type[System], Type[TestDefinition]], Type[CommandGenStrategy]]] = {}
json_gen_strategies_map: ClassVar[dict[tuple[Type[System], Type[TestDefinition]], Type[JsonGenStrategy]]] = {}
grading_strategies_map: ClassVar[dict[Tuple[Type[System], Type[TestDefinition]], Type[GradingStrategy]]] = {}
Expand Down Expand Up @@ -249,6 +250,19 @@ def get_reward_function(self, name: str) -> RewardFunction:
)
return self.reward_functions_map[name]

def add_env_factory(self, name: str, factory: Callable) -> None:
    """Register a new environment factory under ``name``.

    Args:
        name: Lookup key for the factory.
        factory: Callable that builds the environment.

    Raises:
        ValueError: If ``name`` is already registered; use
            ``update_env_factory`` to replace an existing entry.
    """
    is_duplicate = name in self.env_factories_map
    if is_duplicate:
        raise ValueError(f"Duplicating implementation for '{name}', use 'update()' for replacement.")
    self.update_env_factory(name, factory)

def update_env_factory(self, name: str, factory: Callable) -> None:
    """Insert or replace the environment factory registered as ``name``."""
    self.env_factories_map.update({name: factory})

def get_env_factory(self, name: str) -> Callable:
    """Return the environment factory registered as ``name``.

    Raises:
        KeyError: If no factory was registered under ``name``; the message
            lists the names that are available.
    """
    try:
        # EAFP: a single dict lookup on the success path.
        return self.env_factories_map[name]
    except KeyError:
        available = list(self.env_factories_map.keys())
        raise KeyError(f"Env factory '{name}' not found. Available: {available}") from None

def add_command_gen_strategy(
self, system_type: Type[System], tdef_type: Type[TestDefinition], value: Type[CommandGenStrategy]
) -> None:
Expand Down
40 changes: 35 additions & 5 deletions src/cloudai/cli/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,20 @@ def prepare_installation(
return installables, installer


def _run_custom_training_loop(agent: object, agent_type: str) -> int:
"""Delegate to an agent's own training loop (e.g. RLlib PPO)."""
logging.info(f"Agent {agent_type} uses a custom training loop, delegating to agent.train()")
try:
agent.train() # type: ignore[union-attr]
return 0
except Exception as e:
logging.error(f"Agent training failed for {agent_type}: {e}", exc_info=True)
return 1
finally:
if hasattr(agent, "shutdown"):
agent.shutdown() # type: ignore[union-attr]


def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int:
registry = Registry()

Expand All @@ -132,6 +146,7 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int:
err = 0
for tr in runner.runner.test_scenario.test_runs:
test_run = copy.deepcopy(tr)
test_run.output_path = runner.runner.get_job_output_path(test_run)

agent_type = test_run.test.agent
agent_class = registry.agents_map.get(agent_type)
Expand All @@ -151,15 +166,29 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int:

agent = agent_class(env, agent_config)

if getattr(agent, "HAS_CUSTOM_TRAINING_LOOP", False):
err |= _run_custom_training_loop(agent, agent_type)
continue

observation, _ = env.reset()

for step in range(agent.max_steps):
result = agent.select_action()
result = agent.select_action(observation=observation)
if result is None:
break
step, action = result
env.test_run.step = step
logging.info(f"Running step {step} (of {agent.max_steps}) with action {action}")
observation, reward, *_ = env.step(action)
feedback = {"trial_index": step, "value": reward}
prev_obs = observation
observation, reward, done, *_ = env.step(action)
feedback = {
"trial_index": step,
"value": reward,
"observation": observation,
"prev_observation": prev_obs,
"action": action,
"done": done,
}
agent.update_policy(feedback)
logging.info(f"Step {step}: Observation: {[round(obs, 4) for obs in observation]}, Reward: {reward:.4f}")

Expand Down Expand Up @@ -304,11 +333,12 @@ def handle_dry_run_and_run(args: argparse.Namespace) -> int:
logging.info(f"Scenario results will be stored at: {runner.runner.scenario_root}")

has_dse = any(tr.is_dse_job for tr in test_scenario.test_runs)
if args.single_sbatch or not has_dse: # in this mode cases are unrolled using grid search
has_live_rl = any(getattr(tr.test.cmd_args, "live_rl_mode", False) for tr in test_scenario.test_runs)
if args.single_sbatch or (not has_dse and not has_live_rl):
handle_non_dse_job(runner, args)
return 0

if all(tr.is_dse_job for tr in test_scenario.test_runs):
if all(tr.is_dse_job or getattr(tr.test.cmd_args, "live_rl_mode", False) for tr in test_scenario.test_runs):
return handle_dse_job(runner, args)

logging.error("Mixing DSE and non-DSE jobs is not allowed.")
Expand Down
5 changes: 4 additions & 1 deletion src/cloudai/configurator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@

from .base_agent import BaseAgent
from .base_gym import BaseGym
from .cloudai_gym import CloudAIGymEnv, TrajectoryEntry
from .cloudai_gym import CloudAIGymEnv, GymServer, TrajectoryEntry
from .grid_search import GridSearchAgent
from .gymnasium_adapter import GymnasiumAdapter

__all__ = [
"BaseAgent",
"BaseGym",
"CloudAIGymEnv",
"GridSearchAgent",
"GymServer",
"GymnasiumAdapter",
"TrajectoryEntry",
]
7 changes: 5 additions & 2 deletions src/cloudai/configurator/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Any, Dict, Literal
from typing import Any, Dict, Literal, Optional

from pydantic import BaseModel, ConfigDict

Expand Down Expand Up @@ -68,10 +68,13 @@ def configure(self, config: dict[str, Any]) -> None:
pass

@abstractmethod
def select_action(self) -> tuple[int, dict[str, Any]]:
def select_action(self, observation: Optional[list] = None) -> tuple[int, dict[str, Any]]:
"""
Select an action from the action space.

Args:
observation: Optional environment observation from the previous step.

Returns:
Tuple[int, Dict[str, Any]]: The current step index and a dictionary mapping action keys to selected values.
"""
Expand Down
Loading