Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions pyrit/prompt_normalizer/prompt_normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,15 @@ async def send_prompt_async(

# handling empty responses message list and None responses
if not responses or not any(responses):
return None
empty_response = construct_response_from_request(
request=request.message_pieces[0],
response_text_pieces=[""],
response_type="text",
error="empty",
)
await self._calc_hash(request=empty_response)
self._memory.add_message_to_memory(request=empty_response)
return empty_response

# Process all response messages (targets return list[Message])
# Only apply response converters to the last message (final response)
Expand Down Expand Up @@ -191,7 +199,7 @@ async def send_prompt_batch_to_target_async(
"conversation_id",
]

responses = await batch_task_async(
return await batch_task_async(
prompt_target=target,
batch_size=batch_size,
items_to_batch=batch_items,
Expand All @@ -202,9 +210,6 @@ async def send_prompt_batch_to_target_async(
attack_identifier=attack_identifier,
)

# Filter out None responses (e.g., from empty responses)
return [response for response in responses if response is not None]

async def convert_values(
self,
converter_configurations: list[PromptConverterConfiguration],
Expand Down
81 changes: 58 additions & 23 deletions tests/unit/prompt_normalizer/test_prompt_normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,17 +118,22 @@ async def test_send_prompt_async_multiple_converters(mock_memory_instance, seed_

@pytest.mark.asyncio
async def test_send_prompt_async_no_response_adds_memory(mock_memory_instance, seed_group):
prompt_target = AsyncMock()
prompt_target = MagicMock()
prompt_target.send_prompt_async = AsyncMock(return_value=None)
prompt_target.get_identifier.return_value = get_mock_target_identifier("MockTarget")

normalizer = PromptNormalizer()
message = Message.from_prompt(prompt=seed_group.prompts[0].value, role="user")

await normalizer.send_prompt_async(message=message, target=prompt_target)
assert mock_memory_instance.add_message_to_memory.call_count == 1
response = await normalizer.send_prompt_async(message=message, target=prompt_target)
assert mock_memory_instance.add_message_to_memory.call_count == 2

request = mock_memory_instance.add_message_to_memory.call_args[1]["request"]
assert_message_piece_hashes_set(request)
assert response.message_pieces[0].response_error == "empty"
assert response.message_pieces[0].original_value == ""
assert response.message_pieces[0].original_value_data_type == "text"
assert_message_piece_hashes_set(response)


@pytest.mark.asyncio
Expand Down Expand Up @@ -184,34 +189,29 @@ async def test_send_prompt_async_request_response_added_to_memory(mock_memory_in

@pytest.mark.asyncio
async def test_send_prompt_async_exception(mock_memory_instance, seed_group):
prompt_target = AsyncMock()
prompt_target = MagicMock()
prompt_target.send_prompt_async = AsyncMock(side_effect=ValueError("test_exception"))
prompt_target.get_identifier.return_value = get_mock_target_identifier("MockTarget")

seed_prompt_value = seed_group.prompts[0].value

normalizer = PromptNormalizer()
message = Message.from_prompt(prompt=seed_prompt_value, role="user")

with patch("pyrit.models.construct_response_from_request") as mock_construct:
mock_construct.return_value = "test"
with pytest.raises(Exception, match="Error sending prompt with conversation ID"):
await normalizer.send_prompt_async(message=message, target=prompt_target)

try:
await normalizer.send_prompt_async(message=message, target=prompt_target)
except ValueError:
assert mock_memory_instance.add_message_to_memory.call_count == 2
assert mock_memory_instance.add_message_to_memory.call_count == 2

# Validate that first request is added to memory, then exception is added to memory
assert (
seed_prompt_value
== mock_memory_instance.add_message_to_memory.call_args_list[0][1]["request"]
.message_pieces[0]
.original_value
)
assert (
mock_memory_instance.add_message_to_memory.call_args_list[1][1]["request"]
.message_pieces[0]
.original_value
== "test_exception"
)
# Validate that first request is added to memory, then exception is added to memory
assert (
seed_prompt_value
== mock_memory_instance.add_message_to_memory.call_args_list[0][1]["request"].message_pieces[0].original_value
)
assert (
"test_exception"
in mock_memory_instance.add_message_to_memory.call_args_list[1][1]["request"].message_pieces[0].original_value
)


@pytest.mark.asyncio
Expand Down Expand Up @@ -383,6 +383,41 @@ async def test_prompt_normalizer_send_prompt_batch_async_throws(
assert len(results) == 1


@pytest.mark.asyncio
async def test_prompt_normalizer_send_prompt_batch_async_preserves_empty_response_alignment(
mock_memory_instance,
):
prompt_target = MagicMock()
prompt_target._max_requests_per_minute = None
prompt_target.get_identifier.return_value = get_mock_target_identifier("MockTarget")
prompt_target.send_prompt_async = AsyncMock(
side_effect=[
[MessagePiece(role="assistant", original_value="response 1", conversation_id="conv-1").to_message()],
None,
]
)

normalizer = PromptNormalizer()
requests = [
NormalizerRequest(
message=Message.from_prompt(prompt="prompt 1", role="user"),
conversation_id="conv-1",
),
NormalizerRequest(
message=Message.from_prompt(prompt="prompt 2", role="user"),
conversation_id="conv-2",
),
]

results = await normalizer.send_prompt_batch_to_target_async(requests=requests, target=prompt_target, batch_size=2)

assert len(results) == 2
assert results[0].message_pieces[0].original_value == "response 1"
assert results[1].message_pieces[0].response_error == "empty"
assert results[1].message_pieces[0].original_value == ""
assert results[1].message_pieces[0].conversation_id == "conv-2"


@pytest.mark.asyncio
async def test_build_message(mock_memory_instance, seed_group):
# This test is obsolete since _build_message was removed and message preparation
Expand Down