Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyrit/models/seeds/seed_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def from_messages(
current_sequence = starting_sequence

for message in messages:
role: ChatMessageRole = "assistant" if message.api_role == "assistant" else "user"
role: ChatMessageRole = message.api_role

for piece in message.message_pieces:
seed_prompt = SeedPrompt(
Expand Down
19 changes: 19 additions & 0 deletions tests/unit/models/test_seed.py
Original file line number Diff line number Diff line change
Expand Up @@ -1226,6 +1226,25 @@ def test_from_messages_multiple_messages():
assert result[2].sequence == 2


@pytest.mark.parametrize(
("role", "expected_role"),
[
("system", "system"),
("developer", "developer"),
("tool", "tool"),
("simulated_assistant", "assistant"),
],
)
def test_from_messages_preserves_supported_roles(role, expected_role):
"""Test from_messages preserves supported API roles instead of collapsing to user."""
message = Message(message_pieces=[MessagePiece(role=role, original_value=f"{role} message")])

result = SeedPrompt.from_messages([message])

assert len(result) == 1
assert result[0].role == expected_role


def test_from_messages_multipart_message():
"""Test from_messages with a multipart message (e.g., text + image)."""
conv_id = str(uuid.uuid4())
Expand Down
Loading