-
Notifications
You must be signed in to change notification settings - Fork 42
Agent params #792
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Agent params #792
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,5 +1,5 @@ | ||||||||||||||||||||||||||||||||||||
| # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES | ||||||||||||||||||||||||||||||||||||
| # Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||||||||||||||||||||||||||||||||||||
| # Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||||||||||||||||||||||||||||||||||||
| # SPDX-License-Identifier: Apache-2.0 | ||||||||||||||||||||||||||||||||||||
| # | ||||||||||||||||||||||||||||||||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||||||||||||||||||||||||||||||||
|
|
@@ -21,11 +21,12 @@ | |||||||||||||||||||||||||||||||||||
| import signal | ||||||||||||||||||||||||||||||||||||
| from contextlib import contextmanager | ||||||||||||||||||||||||||||||||||||
| from pathlib import Path | ||||||||||||||||||||||||||||||||||||
| from typing import Callable, List, Optional | ||||||||||||||||||||||||||||||||||||
| from typing import Any, Callable, List, Optional | ||||||||||||||||||||||||||||||||||||
| from unittest.mock import Mock | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| import toml | ||||||||||||||||||||||||||||||||||||
| import yaml | ||||||||||||||||||||||||||||||||||||
| from pydantic import ValidationError | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| from cloudai.core import ( | ||||||||||||||||||||||||||||||||||||
| BaseInstaller, | ||||||||||||||||||||||||||||||||||||
|
|
@@ -40,6 +41,11 @@ | |||||||||||||||||||||||||||||||||||
| TestParser, | ||||||||||||||||||||||||||||||||||||
| TestScenario, | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| from cloudai.models.agent_config import ( | ||||||||||||||||||||||||||||||||||||
| BayesianOptimizationConfig, | ||||||||||||||||||||||||||||||||||||
| GeneticAlgorithmConfig, | ||||||||||||||||||||||||||||||||||||
| MultiArmedBanditConfig, | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| from cloudai.models.scenario import ReportConfig | ||||||||||||||||||||||||||||||||||||
| from cloudai.models.workload import TestDefinition | ||||||||||||||||||||||||||||||||||||
| from cloudai.parser import HOOK_ROOT | ||||||||||||||||||||||||||||||||||||
|
|
@@ -145,7 +151,19 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int: | |||||||||||||||||||||||||||||||||||
| continue | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| env = CloudAIGymEnv(test_run=test_run, runner=runner.runner) | ||||||||||||||||||||||||||||||||||||
| agent = agent_class(env) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||
| agent_overrides = validate_agent_overrides(agent_type, test_run.test.agent_config) | ||||||||||||||||||||||||||||||||||||
| except ValidationError as e: | ||||||||||||||||||||||||||||||||||||
| logging.error(f"Invalid agent_config for agent '{agent_type}':") | ||||||||||||||||||||||||||||||||||||
| for error in e.errors(): | ||||||||||||||||||||||||||||||||||||
| field = ".".join(str(loc) for loc in error["loc"]) | ||||||||||||||||||||||||||||||||||||
| logging.error(f" - {field}: {error['msg']}") | ||||||||||||||||||||||||||||||||||||
| err = 1 | ||||||||||||||||||||||||||||||||||||
| continue | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| agent = agent_class(env, **agent_overrides) if agent_overrides else agent_class(env) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| for step in range(agent.max_steps): | ||||||||||||||||||||||||||||||||||||
| result = agent.select_action() | ||||||||||||||||||||||||||||||||||||
| if result is None: | ||||||||||||||||||||||||||||||||||||
|
|
@@ -166,6 +184,31 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int: | |||||||||||||||||||||||||||||||||||
| return err | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| def validate_agent_overrides(agent_type: str, agent_config: Optional[dict[str, Any]]) -> dict[str, Any]: | ||||||||||||||||||||||||||||||||||||
| """Validate and process agent configuration overrides.""" | ||||||||||||||||||||||||||||||||||||
| if not agent_config: | ||||||||||||||||||||||||||||||||||||
| return {} | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| config_class_map = { | ||||||||||||||||||||||||||||||||||||
| "ga": GeneticAlgorithmConfig, | ||||||||||||||||||||||||||||||||||||
| "bo": BayesianOptimizationConfig, | ||||||||||||||||||||||||||||||||||||
| "mab": MultiArmedBanditConfig, | ||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| config_class = config_class_map.get(agent_type) | ||||||||||||||||||||||||||||||||||||
| if not config_class: | ||||||||||||||||||||||||||||||||||||
| logging.debug(f"No config validation available for agent type '{agent_type}', using defaults.") | ||||||||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If |
||||||||||||||||||||||||||||||||||||
| return {} | ||||||||||||||||||||||||||||||||||||
|
Comment on lines
+198
to
+201
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Warn when overrides are ignored for unsupported agent types. 🛠 Proposed change- if not config_class:
- logging.debug(f"No config validation available for agent type '{agent_type}', using defaults.")
- return {}
+ if not config_class:
+ logging.warning(
+ f"agent_config provided for agent type '{agent_type}', but no validation model is registered; "
+ "overrides are ignored."
+ )
+ return {}📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents
Comment on lines
+200
to
+201
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Silent fallback for unknown agent types could hide config typos. Users won't get any indication their config was ignored except a debug log. Consider logging at warning level when
Suggested change
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| validated_config = config_class.model_validate(agent_config) | ||||||||||||||||||||||||||||||||||||
| agent_kwargs = validated_config.model_dump(exclude_none=True) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| if agent_kwargs: | ||||||||||||||||||||||||||||||||||||
| logging.info(f"Applying agent config overrides for '{agent_type}': {agent_kwargs}") | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| return agent_kwargs | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| def generate_reports(system: System, test_scenario: TestScenario, result_dir: Path) -> None: | ||||||||||||||||||||||||||||||||||||
| registry = Registry() | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,73 @@ | ||||||||||||||||||||||||||||||||||||
| # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES | ||||||||||||||||||||||||||||||||||||
| # Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||||||||||||||||||||||||||||||||||||
| # SPDX-License-Identifier: Apache-2.0 | ||||||||||||||||||||||||||||||||||||
| # | ||||||||||||||||||||||||||||||||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||||||||||||||||||||||||||||||||
| # you may not use this file except in compliance with the License. | ||||||||||||||||||||||||||||||||||||
| # You may obtain a copy of the License at | ||||||||||||||||||||||||||||||||||||
| # | ||||||||||||||||||||||||||||||||||||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||||||||||||||||||||||||||||||||||||
| # | ||||||||||||||||||||||||||||||||||||
| # Unless required by applicable law or agreed to in writing, software | ||||||||||||||||||||||||||||||||||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||||||||||||||||||||||||||||||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||||||||||||||||||||||||||||||||
| # See the License for the specific language governing permissions and | ||||||||||||||||||||||||||||||||||||
| # limitations under the License. | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| from abc import ABC | ||||||||||||||||||||||||||||||||||||
| from typing import Any, Optional | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| from pydantic import BaseModel, ConfigDict, Field | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| class AgentConfig(BaseModel, ABC): | ||||||||||||||||||||||||||||||||||||
| """Base configuration for agent overrides.""" | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| model_config = ConfigDict(extra="forbid") | ||||||||||||||||||||||||||||||||||||
| random_seed: Optional[int] = Field(default=None, description="Random seed for reproducibility") | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| class GeneticAlgorithmConfig(AgentConfig): | ||||||||||||||||||||||||||||||||||||
| """Configuration overrides for Genetic Algorithm agent.""" | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| population_size: Optional[int] = Field(default=None, ge=2, description="Population size for the genetic algorithm") | ||||||||||||||||||||||||||||||||||||
alexmanle marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||
| n_offsprings: Optional[int] = Field(default=None, ge=1, description="Number of offsprings per generation") | ||||||||||||||||||||||||||||||||||||
| crossover_prob: Optional[float] = Field(default=None, ge=0.0, le=1.0, description="Crossover probability") | ||||||||||||||||||||||||||||||||||||
| mutation_prob: Optional[float] = Field(default=None, ge=0.0, le=1.0, description="Mutation probability") | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| class BayesianOptimizationConfig(AgentConfig): | ||||||||||||||||||||||||||||||||||||
| """Configuration overrides for Bayesian Optimization agent.""" | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| sobol_num_trials: Optional[int] = Field(default=None, ge=1, description="Number of SOBOL initialization trials") | ||||||||||||||||||||||||||||||||||||
| botorch_num_trials: Optional[int] = Field( | ||||||||||||||||||||||||||||||||||||
| default=None, description="Number of BoTorch trials (-1 for unlimited until max_steps)" | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
|
Comment on lines
+42
to
+45
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain🏁 Script executed: # First, let's examine the actual file structure and find agent_config.py
fd -t f "agent_config.py" | head -20Repository: NVIDIA/cloudai Length of output: 92 🏁 Script executed: # Read the agent_config.py file to see current state
cat -n src/cloudai/models/agent_config.pyRepository: NVIDIA/cloudai Length of output: 3888 🏁 Script executed: # Search for botorch_num_trials usage throughout the codebase
rg "botorch_num_trials" -B 2 -A 2Repository: NVIDIA/cloudai Length of output: 499 🏁 Script executed: # Search for files that use BayesianOptimizationConfig or botorch_num_trials
rg "BayesianOptimizationConfig|botorch_num_trials" -t py --files-with-matchesRepository: NVIDIA/cloudai Length of output: 120 🏁 Script executed: # Look for the BO agent implementation
fd -t f ".*agent.*" -path "*/bayesian*" -o -path "**/optimization*" | grep -i agentRepository: NVIDIA/cloudai Length of output: 226 🏁 Script executed: # Broader search for agent implementations
fd -t f "agent" -type fRepository: NVIDIA/cloudai Length of output: 228 🏁 Script executed: # Search more broadly for agent implementations
fd -e py "agent" | head -30Repository: NVIDIA/cloudai Length of output: 152 🏁 Script executed: # Also search for where BayesianOptimizationConfig is instantiated
rg "BayesianOptimizationConfig" -B 3 -A 10 -t pyRepository: NVIDIA/cloudai Length of output: 2822 🏁 Script executed: # Search for botorch usage in the codebase
rg "botorch" -i -B 3 -A 3Repository: NVIDIA/cloudai Length of output: 673 🏁 Script executed: # Find the base_agent.py file and search for BO implementation
cat src/cloudai/configurator/base_agent.py | head -100Repository: NVIDIA/cloudai Length of output: 2224 🏁 Script executed: # Search for where botorch_num_trials is actually used in the agent
rg "botorch_num_trials" -A 5 -B 5 src/cloudai/configurator/Repository: NVIDIA/cloudai Length of output: 40 🏁 Script executed: # Look for any BO agent or related files
fd -e py | xargs rg -l "class.*Bayesian|botorch_num_trials"Repository: NVIDIA/cloudai Length of output: 92 🏁 Script executed: # Search for actual agent implementations/classes
rg "class.*Agent.*:" -A 5 | head -100Repository: NVIDIA/cloudai Length of output: 4155 🏁 Script executed: # Look at test_agents.py to see how these configs are tested
cat tests/test_agents.pyRepository: NVIDIA/cloudai Length of output: 2749 🏁 Script executed: # Check if there are any actual BO agent implementations elsewhere
fd -e py -x rg -l "num_trials" {} \;Repository: NVIDIA/cloudai Length of output: 94 🏁 Script executed: # Look for any BO agent implementations or references
fd -e py | xargs rg -l "BayesianOptimization|BoTorch" 2>/dev/nullRepository: NVIDIA/cloudai Length of output: 120 🏁 Script executed: # Check if there's a bayesian agent file
fd -e py -path "*/agent*" -o -path "*/bayesian*"Repository: NVIDIA/cloudai Length of output: 226 🏁 Script executed: # Search for where the agent configs are actually used
rg "botorch_num_trials|sobol_num_trials" -B 5 -A 5Repository: NVIDIA/cloudai Length of output: 905 🏁 Script executed: # Check what agent types are actually implemented
ls -la src/cloudai/configurator/Repository: NVIDIA/cloudai Length of output: 559 🏁 Script executed: # Check if BO agent exists or is planned
fd "\.py$" | xargs rg -l "class.*Bayesian" 2>/dev/nullRepository: NVIDIA/cloudai Length of output: 92 🏁 Script executed: # Look at handlers.py more carefully to see how config is used
cat src/cloudai/cli/handlers.py | grep -A 20 "BayesianOptimizationConfig"Repository: NVIDIA/cloudai Length of output: 1467 🏁 Script executed: # Check if there's external documentation or plugin system for agents
rg "entry_points\|entrypoint" -A 5 -B 5Repository: NVIDIA/cloudai Length of output: 40 🏁 Script executed: # Search for any test or documentation mentioning BO agent
rg "bayesian|bo agent" -i -B 2 -A 2Repository: NVIDIA/cloudai Length of output: 1069 Add validation constraint to 🛠 Proposed fix-from pydantic import BaseModel, ConfigDict, Field
+from pydantic import BaseModel, ConfigDict, Field, field_validator
@@
class BayesianOptimizationConfig(AgentConfig):
"""Configuration overrides for Bayesian Optimization agent."""
@@
botorch_num_trials: Optional[int] = Field(
default=None, description="Number of BoTorch trials (-1 for unlimited until max_steps)"
)
+
+ `@field_validator`("botorch_num_trials")
+ `@classmethod`
+ def _validate_botorch_num_trials(cls, v: Optional[int]) -> Optional[int]:
+ if v is None:
+ return v
+ if v != -1 and v < 1:
+ raise ValueError("botorch_num_trials must be -1 or >= 1")
+ return v📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @alexmanle would be good to consider this. wdyt?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also would be good to test it with BO agent w/ AIConfigurator along with GA. |
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| class MultiArmedBanditConfig(AgentConfig): | ||||||||||||||||||||||||||||||||||||
| """Configuration overrides for Multi-Armed Bandit agent.""" | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| algorithm: Optional[str] = Field( | ||||||||||||||||||||||||||||||||||||
| default=None, | ||||||||||||||||||||||||||||||||||||
| description="MAB algorithm: ucb1, ts (thompson_sampling), epsilon_greedy, softmax, or random", | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
|
Comment on lines
+51
to
+54
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧹 Nitpick | 🔵 Trivial Validate ♻️ Suggested tightening (if only these values are supported)-from typing import Any, Optional
+from typing import Any, Optional, Literal
@@
- algorithm: Optional[str] = Field(
+ algorithm: Optional[Literal["ucb1", "ts", "thompson_sampling", "epsilon_greedy", "softmax", "random"]] = Field(
default=None,
description="MAB algorithm: ucb1, ts (thompson_sampling), epsilon_greedy, softmax, or random",
)📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||
| algorithm_params: Optional[dict[str, Any]] = Field( | ||||||||||||||||||||||||||||||||||||
| default=None, description="Algorithm-specific parameters (e.g., alpha for UCB1, epsilon for epsilon_greedy)" | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| seed_parameters: Optional[dict[str, Any]] = Field( | ||||||||||||||||||||||||||||||||||||
| default=None, description="Initial seed configuration to evaluate first" | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| max_arms: Optional[int] = Field(default=None, ge=1, description="Maximum number of arms in the action space") | ||||||||||||||||||||||||||||||||||||
| warm_start_size: Optional[int] = Field( | ||||||||||||||||||||||||||||||||||||
| default=None, ge=0, description="Number of arms to randomly explore initially" | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| epsilon_override: Optional[float] = Field( | ||||||||||||||||||||||||||||||||||||
| default=None, ge=0.0, le=1.0, description="Epsilon value for exploration (overrides algorithm epsilon)" | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| max_explore_steps: Optional[int] = Field( | ||||||||||||||||||||||||||||||||||||
| default=None, ge=0, description="Maximum steps for epsilon exploration (None for unlimited)" | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| prefer_unseen_random: Optional[bool] = Field( | ||||||||||||||||||||||||||||||||||||
| default=None, description="Prefer unseen arms during random exploration (epsilon)" | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Conditional passes kwargs only when non-empty, but empty dict is falsy. If
agent_overrides = {}, the condition evaluates False and takes the else branch. Consider usingis not Noneor explicit length check for clarity:However, given
validate_agent_overridesalways returns a dict, the current logic works but is subtle.Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!