Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 73 additions & 19 deletions pyrit/setup/configuration_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import TYPE_CHECKING, Any, Optional, Union

from pyrit.common.path import DEFAULT_CONFIG_PATH
from pyrit.common.utils import verify_and_resolve_path
from pyrit.common.yaml_loadable import YamlLoadable
from pyrit.identifiers.class_name_utils import class_name_to_snake_case
from pyrit.setup.initialization import (
Expand Down Expand Up @@ -100,6 +101,8 @@ class ConfigurationLoader(YamlLoadable):
silent: bool = False
operator: Optional[str] = None
operation: Optional[str] = None
_initialization_scripts_base_path: Optional[pathlib.Path] = field(default=None, init=False, repr=False)
_env_files_base_path: Optional[pathlib.Path] = field(default=None, init=False, repr=False)

def __post_init__(self) -> None:
"""Validate and normalize the configuration after loading."""
Expand Down Expand Up @@ -179,6 +182,48 @@ def from_dict(cls, data: dict[str, Any]) -> "ConfigurationLoader":
filtered_data = {k: v for k, v in data.items() if v is not None}
return cls(**filtered_data)

@classmethod
def from_yaml_file(cls, file: pathlib.Path | str) -> "ConfigurationLoader":
"""
Create a ConfigurationLoader from a YAML file and preserve its base directory.

Relative initialization script and env file paths should resolve from the
configuration file directory rather than the caller's working directory.

Returns:
A new ConfigurationLoader instance with per-field path resolution bases.
"""
resolved_file = verify_and_resolve_path(file)
config = YamlLoadable.from_yaml_file.__func__(cls, resolved_file)
config._set_path_resolution_base_paths(
initialization_scripts_base_path=resolved_file.parent,
env_files_base_path=resolved_file.parent,
)
return config

def _set_path_resolution_base_paths(
self,
*,
initialization_scripts_base_path: Optional[pathlib.Path],
env_files_base_path: Optional[pathlib.Path],
) -> None:
"""Set per-field base paths for resolving relative configuration paths."""
self._initialization_scripts_base_path = initialization_scripts_base_path
self._env_files_base_path = env_files_base_path

@staticmethod
def _resolve_config_path(path_str: str, base_path: Optional[pathlib.Path]) -> pathlib.Path:
"""
Resolve config-provided relative paths against an optional base directory.

Returns:
An absolute path when a relative base is available, or the original absolute path.
"""
config_path = pathlib.Path(path_str)
if config_path.is_absolute():
return config_path
return (base_path or pathlib.Path.cwd()) / config_path

@staticmethod
def load_with_overrides(
config_file: Optional[pathlib.Path] = None,
Expand Down Expand Up @@ -217,6 +262,8 @@ def load_with_overrides(
import logging

logger = logging.getLogger(__name__)
initialization_scripts_base_path: Optional[pathlib.Path] = None
env_files_base_path: Optional[pathlib.Path] = None

# Start with defaults - None means "use defaults", [] means "load nothing"
config_data: dict[str, Any] = {
Expand All @@ -239,6 +286,8 @@ def load_with_overrides(
# Preserve None vs [] distinction from config file
config_data["initialization_scripts"] = default_config.initialization_scripts
config_data["env_files"] = default_config.env_files
initialization_scripts_base_path = default_config._initialization_scripts_base_path
env_files_base_path = default_config._env_files_base_path
if default_config.operator:
config_data["operator"] = default_config.operator
if default_config.operation:
Expand All @@ -259,6 +308,8 @@ def load_with_overrides(
# Preserve None vs [] distinction from config file
config_data["initialization_scripts"] = explicit_config.initialization_scripts
config_data["env_files"] = explicit_config.env_files
initialization_scripts_base_path = explicit_config._initialization_scripts_base_path
env_files_base_path = explicit_config._env_files_base_path
if explicit_config.operator:
config_data["operator"] = explicit_config.operator
if explicit_config.operation:
Expand All @@ -280,11 +331,18 @@ def load_with_overrides(

if initialization_scripts is not None:
config_data["initialization_scripts"] = list(initialization_scripts)
initialization_scripts_base_path = None

if env_files is not None:
config_data["env_files"] = list(env_files)
env_files_base_path = None

return ConfigurationLoader.from_dict(config_data)
config = ConfigurationLoader.from_dict(config_data)
config._set_path_resolution_base_paths(
initialization_scripts_base_path=initialization_scripts_base_path,
env_files_base_path=env_files_base_path,
)
return config

@classmethod
def get_default_config_path(cls) -> pathlib.Path:
Expand Down Expand Up @@ -325,8 +383,12 @@ def _resolve_initializers(self) -> Sequence["PyRITInitializer"]:
f"Initializer '{config.name}' not found in registry.\nAvailable initializers: {available}"
)

# Instantiate with args if provided
instance = initializer_class(**config.args) if config.args else initializer_class()
# Instantiate and set params if provided
instance = initializer_class()
if config.args:
instance.set_params_from_args(args=config.args)
# Validate params early against supported_parameters to fail fast
instance._validate_params(params=instance.params)

resolved.append(instance)

Expand All @@ -348,14 +410,10 @@ def _resolve_initialization_scripts(self) -> Optional[Sequence[pathlib.Path]]:
if len(self.initialization_scripts) == 0:
return []

resolved: list[pathlib.Path] = []
for script_str in self.initialization_scripts:
script_path = pathlib.Path(script_str)
if not script_path.is_absolute():
script_path = pathlib.Path.cwd() / script_path
resolved.append(script_path)

return resolved
return [
self._resolve_config_path(script_str, self._initialization_scripts_base_path)
for script_str in self.initialization_scripts
]

def _resolve_env_files(self) -> Optional[Sequence[pathlib.Path]]:
"""
Expand All @@ -373,14 +431,10 @@ def _resolve_env_files(self) -> Optional[Sequence[pathlib.Path]]:
if len(self.env_files) == 0:
return []

resolved: list[pathlib.Path] = []
for env_str in self.env_files:
env_path = pathlib.Path(env_str)
if not env_path.is_absolute():
env_path = pathlib.Path.cwd() / env_path
resolved.append(env_path)

return resolved
return [
self._resolve_config_path(env_str, self._env_files_base_path)
for env_str in self.env_files
]

async def initialize_pyrit_async(self) -> None:
"""
Expand Down
64 changes: 64 additions & 0 deletions tests/unit/setup/test_configuration_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,26 @@ def test_resolve_initialization_scripts_relative_path(self):
# Check path ends with expected components (works on both Unix and Windows)
assert resolved[0].parts[-2:] == ("relative", "script.py")

@pytest.mark.parametrize(
("field_name", "resolver_name", "relative_path"),
[
("initialization_scripts", "_resolve_initialization_scripts", "scripts/init.py"),
("env_files", "_resolve_env_files", "env/local.env"),
],
)
def test_from_yaml_file_resolves_relative_paths_from_config_directory(
self, tmp_path, field_name, resolver_name, relative_path
):
"""Test relative paths from YAML are resolved from the config file directory."""
config_path = tmp_path / "configs" / "pyrit.yaml"
config_path.parent.mkdir()
config_path.write_text(f"{field_name}:\n - ./{relative_path}\n", encoding="utf-8")

config = ConfigurationLoader.from_yaml_file(config_path)
resolved = getattr(config, resolver_name)()

assert resolved == [config_path.parent / relative_path]

def test_resolve_env_files_none_returns_none(self):
"""Test that None (default) returns None to signal 'use defaults'."""
config = ConfigurationLoader()
Expand Down Expand Up @@ -421,6 +441,50 @@ def test_load_with_overrides_env_files_override(self, mock_default_path):

assert config.env_files == ["/path/to/.env"]

@mock.patch("pyrit.setup.configuration_loader.DEFAULT_CONFIG_PATH")
@pytest.mark.parametrize(
("field_name", "resolver_name", "relative_path"),
[
("initialization_scripts", "_resolve_initialization_scripts", "scripts/init.py"),
("env_files", "_resolve_env_files", "env/local.env"),
],
)
def test_load_with_overrides_resolves_relative_paths_from_config_directory(
self, mock_default_path, tmp_path, field_name, resolver_name, relative_path
):
"""Test config file relative paths are resolved from the config file directory."""
mock_default_path.exists.return_value = False
config_path = tmp_path / "configs" / "pyrit.yaml"
config_path.parent.mkdir()
config_path.write_text(f"{field_name}:\n - ./{relative_path}\n", encoding="utf-8")

config = ConfigurationLoader.load_with_overrides(config_file=config_path)
resolved = getattr(config, resolver_name)()

assert resolved == [config_path.parent / relative_path]

@mock.patch("pyrit.setup.configuration_loader.DEFAULT_CONFIG_PATH")
@pytest.mark.parametrize(
("field_name", "resolver_name", "relative_path"),
[
("initialization_scripts", "_resolve_initialization_scripts", "scripts/override.py"),
("env_files", "_resolve_env_files", "env/override.env"),
],
)
def test_load_with_overrides_cli_relative_paths_use_cwd(
self, mock_default_path, tmp_path, field_name, resolver_name, relative_path
):
"""Test CLI path overrides keep resolving relative paths from the current directory."""
mock_default_path.exists.return_value = False
config_path = tmp_path / "configs" / "pyrit.yaml"
config_path.parent.mkdir()
config_path.write_text(f"{field_name}:\n - ./from-config-placeholder\n", encoding="utf-8")

config = ConfigurationLoader.load_with_overrides(config_file=config_path, **{field_name: [relative_path]})
resolved = getattr(config, resolver_name)()

assert resolved == [pathlib.Path.cwd() / relative_path]

@mock.patch("pyrit.setup.configuration_loader.DEFAULT_CONFIG_PATH")
def test_load_with_overrides_converts_sequence_to_list(self, mock_default_path):
"""Test that Sequence inputs are converted to list for dataclass compatibility."""
Expand Down