Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 29 additions & 5 deletions pyrit/message_normalizer/generic_system_squash.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import copy

from pyrit.message_normalizer.message_normalizer import MessageListNormalizer
from pyrit.models import Message
from pyrit.models import Message, MessagePiece


class GenericSystemSquashNormalizer(MessageListNormalizer[Message]):
Expand Down Expand Up @@ -43,12 +44,35 @@ async def normalize_async(self, messages: list[Message]) -> list[Message]:
# Only system message, convert to user message
return [Message.from_prompt(prompt=first_piece.converted_value, role="user")]

# Combine system with first user message
user_message_index = next(
(i for i, message in enumerate(messages[1:], start=1) if message.api_role == "user"),
-1,
)
if user_message_index == -1:
# Preserve the instruction content without rewriting non-user messages.
return [Message.from_prompt(prompt=first_piece.converted_value, role="user")] + list(messages[1:])

# Combine system with the first user message
system_content = first_piece.converted_value
user_piece = messages[1].get_piece()
user_piece = messages[user_message_index].get_piece()
user_content = user_piece.converted_value

combined_content = f"### Instructions ###\n\n{system_content}\n\n######\n\n{user_content}"
squashed_message = Message.from_prompt(prompt=combined_content, role="user")
squashed_message = copy.deepcopy(messages[user_message_index])

if squashed_message.message_pieces[0].converted_value_data_type == "text":
squashed_message.message_pieces[0].original_value = combined_content
squashed_message.message_pieces[0].converted_value = combined_content
else:
squashed_message.message_pieces.insert(
0,
MessagePiece(
role="user",
original_value=combined_content,
conversation_id=user_piece.conversation_id,
sequence=user_piece.sequence,
),
)

# Return any messages that precede the first user message, then the squashed message, then the remaining messages
return [squashed_message] + list(messages[2:])
return list(messages[1:user_message_index]) + [squashed_message] + list(messages[user_message_index + 1 :])
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,56 @@ async def test_generic_squash_normalize_to_dicts_async():
assert "### Instructions ###" in result[0]["converted_value"]
assert "System message" in result[0]["converted_value"]
assert "User message" in result[0]["converted_value"]


@pytest.mark.asyncio
async def test_generic_squash_preserves_multipart_user_message():
    """Check that a user message with a text piece and an image piece keeps both pieces after squashing."""
    conv_id = "conv-1"

    # Build the conversation: one system message followed by a two-piece user message.
    system_message = _make_message("system", "System message")
    text_piece = MessagePiece(
        role="user",
        original_value="User message",
        conversation_id=conv_id,
        sequence=0,
    )
    image_piece = MessagePiece(
        role="user",
        original_value="/tmp/example.png",
        original_value_data_type="image_path",
        conversation_id=conv_id,
        sequence=0,
    )
    multipart_user_message = Message(message_pieces=[text_piece, image_piece])

    normalizer = GenericSystemSquashNormalizer()
    normalized = await normalizer.normalize_async([system_message, multipart_user_message])

    # The system message is folded away, leaving a single user message.
    assert len(normalized) == 1
    squashed = normalized[0]
    assert squashed.api_role == "user"

    # Both pieces survive: the text piece carries the combined prompt ...
    assert len(squashed.message_pieces) == 2
    assert squashed.get_value() == "### Instructions ###\n\nSystem message\n\n######\n\nUser message"

    # ... and the image piece is left untouched.
    trailing_piece = squashed.message_pieces[1]
    assert trailing_piece.converted_value == "/tmp/example.png"
    assert trailing_piece.converted_value_data_type == "image_path"


@pytest.mark.asyncio
async def test_generic_squash_uses_first_user_message_instead_of_rewriting_assistant():
    """Check that the system text is merged into the first user message when an assistant message precedes it."""
    turns = (
        ("system", "System message"),
        ("assistant", "Assistant message"),
        ("user", "User message"),
    )
    conversation = [_make_message(role, text) for role, text in turns]

    normalizer = GenericSystemSquashNormalizer()
    normalized = await normalizer.normalize_async(conversation)

    # Only the system message disappears; assistant and user messages remain in order.
    assert len(normalized) == 2
    assistant_message, squashed_user_message = normalized

    assert assistant_message.api_role == "assistant"
    assert assistant_message.get_value() == "Assistant message"
    assert squashed_user_message.api_role == "user"
    assert squashed_user_message.get_value() == "### Instructions ###\n\nSystem message\n\n######\n\nUser message"