Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion pyrit/setup/initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,12 @@ def _load_initializers_from_scripts(
obj = getattr(module, name)
# Check if it's a class, is a subclass of PyRITInitializer,
# and is not the base class itself
if isinstance(obj, type) and issubclass(obj, PyRITInitializer) and obj is not PyRITInitializer:
if (
isinstance(obj, type)
and issubclass(obj, PyRITInitializer)
and obj is not PyRITInitializer
and obj.__module__ == module.__name__
):
try:
# Instantiate the initializer class
initializer = obj()
Expand Down
53 changes: 53 additions & 0 deletions tests/unit/setup/test_initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,59 @@ def test_script_not_found_raises_error(self):
with pytest.raises(FileNotFoundError):
_load_initializers_from_scripts(script_paths=["nonexistent_script.py"])

def test_ignores_imported_initializer_classes(self):
"""Test that imported initializer classes are not instantiated from the script."""
with tempfile.TemporaryDirectory() as temp_dir:
temp_path = pathlib.Path(temp_dir)
helper_path = temp_path / "helper_init.py"
script_path = temp_path / "script_init.py"

helper_path.write_text(
"""
from pyrit.setup.initializers import PyRITInitializer

class ImportedInitializer(PyRITInitializer):
@property
def name(self) -> str:
return "Imported"

@property
def description(self) -> str:
return "Imported initializer"

async def initialize_async(self) -> None:
pass
"""
)

script_path.write_text(
f"""
import sys

sys.path.insert(0, {temp_dir!r})

from helper_init import ImportedInitializer
from pyrit.setup.initializers import PyRITInitializer

class LocalInitializer(PyRITInitializer):
@property
def name(self) -> str:
return "Local"

@property
def description(self) -> str:
return "Local initializer"

async def initialize_async(self) -> None:
pass
"""
)

initializers = _load_initializers_from_scripts(script_paths=[script_path])

assert len(initializers) == 1
assert initializers[0].name == "Local"


class TestInitializePyrit:
"""Tests for initialize_pyrit_async function - basic orchestration tests."""
Expand Down