Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion pyrit/message_normalizer/conversation_context_normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@ class ConversationContextNormalizer(MessageStringNormalizer):
...
"""

_ROLE_LABELS = {
"user": "User",
"assistant": "Assistant",
"tool": "Tool",
"developer": "Developer",
}

async def normalize_string_async(self, messages: list[Message]) -> str:
"""
Normalize a list of messages into a turn-based context string.
Expand Down Expand Up @@ -55,7 +62,7 @@ async def normalize_string_async(self, messages: list[Message]) -> str:

# Format the piece content
content = self._format_piece_content(piece)
role_label = "User" if piece.api_role == "user" else "Assistant"
role_label = self._ROLE_LABELS.get(piece.api_role, piece.api_role.capitalize())
context_parts.append(f"{role_label}: {content}")

return "\n".join(context_parts)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,31 @@ async def test_shows_original_if_different_from_converted(self):

assert "converted text" in result
assert "(original: original text)" in result

@pytest.mark.asyncio
async def test_preserves_tool_role_label(self):
"""Test that tool messages keep the Tool label in context output."""
normalizer = ConversationContextNormalizer()
messages = [
_make_message("user", "Call the weather tool"),
_make_message("tool", "72F and sunny"),
]

result = await normalizer.normalize_string_async(messages)

assert "Tool: 72F and sunny" in result
assert "Assistant: 72F and sunny" not in result

@pytest.mark.asyncio
async def test_preserves_developer_role_label(self):
"""Test that developer messages keep the Developer label in context output."""
normalizer = ConversationContextNormalizer()
messages = [
_make_message("user", "Use concise units"),
_make_message("developer", "Prefer metric units"),
]

result = await normalizer.normalize_string_async(messages)

assert "Developer: Prefer metric units" in result
assert "Assistant: Prefer metric units" not in result