Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions pyrit/score/scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,11 +413,10 @@ async def score_prompts_batch_async(
ValueError: If objectives is empty or if the number of objectives doesn't match
the number of messages.
"""
if not objectives:
if objectives is None:
objectives = [""] * len(messages)

elif len(objectives) != len(messages):
raise ValueError("The number of tasks must match the number of messages.")
raise ValueError("The number of objectives must match the number of messages.")

if len(messages) == 0:
return []
Expand Down Expand Up @@ -456,7 +455,7 @@ async def score_image_batch_async(
Raises:
ValueError: If the number of objectives does not match the number of image_paths.
"""
if objectives and len(objectives) != len(image_paths):
if objectives is not None and len(objectives) != len(image_paths):
raise ValueError("The number of objectives must match the number of image_paths.")

if len(image_paths) == 0:
Expand All @@ -465,10 +464,10 @@ async def score_image_batch_async(
prompt_target = getattr(self, "_prompt_target", None)
results = await batch_task_async(
task_func=self.score_image_async,
task_arguments=["image_path", "objective"] if objectives else ["image_path"],
task_arguments=["image_path", "objective"] if objectives is not None else ["image_path"],
prompt_target=prompt_target,
batch_size=batch_size,
items_to_batch=[image_paths, objectives] if objectives else [image_paths],
items_to_batch=[image_paths, objectives] if objectives is not None else [image_paths],
)

return [score for sublist in results for score in sublist]
Expand Down
19 changes: 19 additions & 0 deletions tests/unit/score/test_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,25 @@ async def test_scorer_score_responses_batch_async(patch_central_database):
assert len(fake_scores) == 2


@pytest.mark.asyncio
async def test_score_prompts_batch_async_rejects_explicit_empty_objectives():
    """An explicitly empty objectives list must raise for a non-empty message batch."""
    mock_scorer = MockScorer()
    msg = MessagePiece(role="user", original_value="Hello user", sequence=1).to_message()

    # Passing objectives=[] (not None) with one message is a length mismatch and must fail.
    with pytest.raises(ValueError, match="objectives"):
        await mock_scorer.score_prompts_batch_async(messages=[msg], objectives=[])


@pytest.mark.asyncio
async def test_score_image_batch_async_rejects_explicit_empty_objectives():
    """An explicitly empty objectives list must raise for a non-empty image batch."""
    mock_scorer = MockScorer()

    # One image path paired with zero objectives is a length mismatch and must fail.
    with pytest.raises(ValueError, match="objectives"):
        await mock_scorer.score_image_batch_async(image_paths=["test_image.png"], objectives=[])


@pytest.mark.asyncio
async def test_score_response_async_empty_scorers():
"""Test that score_response_async returns empty list when no scorers provided."""
Expand Down