Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions pyrit/score/batch_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,11 @@ async def score_responses_by_filters_async(
converted_value_sha256=converted_value_sha256,
)

if not message_pieces:
raise ValueError("No entries match the provided filters. Please check your filters.")

message_pieces = self._remove_duplicates(message_pieces)

if not message_pieces:
raise ValueError("No entries match the provided filters. Please check your filters.")

Expand Down
40 changes: 40 additions & 0 deletions tests/unit/score/test_batch_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,3 +325,43 @@ async def test_score_responses_by_filters_groups_by_sequence_within_conversation
assert len(messages[0].message_pieces) == 1
assert len(messages[1].message_pieces) == 2
assert len(messages[2].message_pieces) == 1

@pytest.mark.asyncio
async def test_score_responses_by_filters_removes_duplicate_message_pieces(self) -> None:
    """Verify that duplicate message pieces are dropped before batch scoring.

    Memory is mocked to return two pieces for the same conversation and
    sequence: an original, and a second piece whose ``original_prompt_id``
    refers back to the original's id. The scorer must receive exactly one
    message containing only the original piece.
    """
    original_piece_id = uuid.uuid4()

    original_piece = MessagePiece(
        id=original_piece_id,
        role="assistant",
        conversation_id="conv1",
        sequence=1,
        original_value="Original response",
    )
    duplicate_piece = MessagePiece(
        role="assistant",
        conversation_id="conv1",
        sequence=1,
        original_value="Duplicate response copy",
        original_prompt_id=original_piece_id,
    )

    memory = MagicMock()
    memory.get_message_pieces.return_value = [original_piece, duplicate_piece]

    with patch.object(CentralMemory, "get_memory_instance", return_value=memory):
        scorer = MagicMock()
        scorer.score_prompts_batch_async = AsyncMock(return_value=[])

        # Construct the scorer while the memory patch is active.
        batch_scorer = BatchScorer()
        await batch_scorer.score_responses_by_filters_async(scorer=scorer, conversation_id="conv1")

        # Inspect what was actually handed to the scorer.
        scored_messages = scorer.score_prompts_batch_async.call_args.kwargs["messages"]

        assert len(scored_messages) == 1
        surviving_pieces = scored_messages[0].message_pieces
        assert len(surviving_pieces) == 1
        assert surviving_pieces[0].id == original_piece_id