Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions pyrit/datasets/seed_datasets/remote/promptintel_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ def __init__(
self,
*,
api_key: Optional[str] = None,
severity: Optional[PromptIntelSeverity] = None,
categories: Optional[list[PromptIntelCategory]] = None,
severity: Optional[PromptIntelSeverity | str] = None,
categories: Optional[list[PromptIntelCategory | str]] = None,
search: Optional[str] = None,
max_prompts: Optional[int] = None,
) -> None:
Expand All @@ -93,6 +93,8 @@ def __init__(
ValueError: If an invalid severity or category is provided.
"""
self._api_key = api_key
normalized_severity: Optional[PromptIntelSeverity] = None
normalized_categories: Optional[list[PromptIntelCategory]] = None

if severity is not None:
valid_severities = {s.value for s in PromptIntelSeverity}
Expand All @@ -101,20 +103,28 @@ def __init__(
raise ValueError(
f"Invalid severity: {sev_value}. Valid values: {[s.value for s in PromptIntelSeverity]}"
)
normalized_severity = (
severity if isinstance(severity, PromptIntelSeverity) else PromptIntelSeverity(sev_value)
)

if categories is not None:
valid_categories = {c.value for c in PromptIntelCategory}
category_values = [cat.value if isinstance(cat, PromptIntelCategory) else cat for cat in categories]
invalid_categories = {
cat.value if isinstance(cat, PromptIntelCategory) else cat for cat in categories
} - valid_categories
category_value for category_value in category_values if category_value not in valid_categories
}
if invalid_categories:
raise ValueError(
f"Invalid categories: {', '.join(str(c) for c in invalid_categories)}. "
f"Valid values: {[c.value for c in PromptIntelCategory]}"
)
normalized_categories = [
category if isinstance(category, PromptIntelCategory) else PromptIntelCategory(category)
for category in categories
]

self._severity = severity
self._categories = categories
self._severity = normalized_severity
self._categories = normalized_categories
self._search = search
self._max_prompts = max_prompts
self.source = "https://promptintel.novahunting.ai"
Expand Down
22 changes: 22 additions & 0 deletions tests/unit/datasets/test_promptintel_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,17 @@ async def test_severity_filter_passed_to_api(self, api_key, mock_promptintel_res
call_kwargs = mock_get.call_args
assert call_kwargs.kwargs["params"]["severity"] == "critical"

@pytest.mark.asyncio
async def test_string_severity_filter_passed_to_api(self, api_key, mock_promptintel_response):
    """A severity given as a plain string is sent as the ``severity`` request param."""
    dataset = _PromptIntelDataset(api_key=api_key, severity="critical")
    fake_response = _make_mock_response(json_data=mock_promptintel_response)

    # Intercept the HTTP call so the outgoing query parameters can be inspected.
    with patch("requests.get", return_value=fake_response) as get_mock:
        await dataset.fetch_dataset()

    sent_params = get_mock.call_args.kwargs["params"]
    assert sent_params["severity"] == "critical"

@pytest.mark.asyncio
async def test_category_filter_passed_to_api(self, api_key, mock_promptintel_response):
loader = _PromptIntelDataset(api_key=api_key, categories=[PromptIntelCategory.MANIPULATION])
Expand All @@ -358,6 +369,17 @@ async def test_category_filter_passed_to_api(self, api_key, mock_promptintel_res
call_kwargs = mock_get.call_args
assert call_kwargs.kwargs["params"]["category"] == "manipulation"

@pytest.mark.asyncio
async def test_string_category_filter_passed_to_api(self, api_key, mock_promptintel_response):
    """A category given as a plain string is sent as the ``category`` request param."""
    dataset = _PromptIntelDataset(api_key=api_key, categories=["manipulation"])
    fake_response = _make_mock_response(json_data=mock_promptintel_response)

    # Intercept the HTTP call so the outgoing query parameters can be inspected.
    with patch("requests.get", return_value=fake_response) as get_mock:
        await dataset.fetch_dataset()

    sent_params = get_mock.call_args.kwargs["params"]
    assert sent_params["category"] == "manipulation"

@pytest.mark.asyncio
async def test_multiple_categories_make_separate_api_calls(self, api_key):
manipulation_response = {
Expand Down