Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 30 additions & 4 deletions src/strands/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ class BedrockConfig(TypedDict, total=False):
guardrail_redact_input_message: If a Bedrock Input guardrail triggers, replace the input with this message.
guardrail_redact_output: Flag to redact output if guardrail is triggered. Defaults to False.
guardrail_redact_output_message: If a Bedrock Output guardrail triggers, replace output with this message.
guardrail_last_turn_only: Flag to send only the last turn to guardrails instead of full conversation.
Defaults to False.
max_tokens: Maximum number of tokens to generate in the response
model_id: The Bedrock model ID (e.g., "us.anthropic.claude-sonnet-4-20250514-v1:0")
include_tool_result_status: Flag to include status field in tool results.
Expand All @@ -105,6 +107,7 @@ class BedrockConfig(TypedDict, total=False):
guardrail_redact_input_message: Optional[str]
guardrail_redact_output: Optional[bool]
guardrail_redact_output_message: Optional[str]
guardrail_last_turn_only: Optional[bool]
max_tokens: Optional[int]
model_id: str
include_tool_result_status: Optional[Literal["auto"] | bool]
Expand Down Expand Up @@ -206,9 +209,12 @@ def _format_request(
Returns:
A Bedrock converse stream request.
"""
messages_for_request = messages

if not tool_specs:
has_tool_content = any(
any("toolUse" in block or "toolResult" in block for block in msg.get("content", [])) for msg in messages
any("toolUse" in block or "toolResult" in block for block in msg.get("content", []))
for msg in messages_for_request
)
if has_tool_content:
tool_specs = [noop_tool.tool_spec]
Expand All @@ -224,7 +230,10 @@ def _format_request(

return {
"modelId": self.config["model_id"],
"messages": self._format_bedrock_messages(messages),
"messages": self._format_bedrock_messages(
messages_for_request,
guardrail_last_turn_only=bool(self.config.get("guardrail_last_turn_only", False)),
),
"system": system_blocks,
**(
{
Expand Down Expand Up @@ -295,16 +304,20 @@ def _format_request(
),
}

def _format_bedrock_messages(self, messages: Messages) -> list[dict[str, Any]]:
def _format_bedrock_messages(
self, messages: Messages, guardrail_last_turn_only: bool = False
) -> list[dict[str, Any]]:
"""Format messages for Bedrock API compatibility.

This function ensures messages conform to Bedrock's expected format by:
- Filtering out SDK_UNKNOWN_MEMBER content blocks
- Eagerly filtering content blocks to only include Bedrock-supported fields
- Ensuring all message content blocks are properly formatted for the Bedrock API
- Optionally wrapping the last user message in guardrailConverseContent blocks

Args:
messages: List of messages to format
guardrail_last_turn_only: If True, wrap the last user message content in guardrailConverseContent blocks

Returns:
Messages formatted for Bedrock API compatibility
Expand All @@ -321,7 +334,15 @@ def _format_bedrock_messages(self, messages: Messages) -> list[dict[str, Any]]:
filtered_unknown_members = False
dropped_deepseek_reasoning_content = False

for message in messages:
# Find the index of the last user message if wrapping is enabled
last_user_idx = -1
if guardrail_last_turn_only:
for i in range(len(messages) - 1, -1, -1):
if messages[i]["role"] == "user":
last_user_idx = i
break

for idx, message in enumerate(messages):
cleaned_content: list[dict[str, Any]] = []

for content_block in message["content"]:
Expand All @@ -338,6 +359,11 @@ def _format_bedrock_messages(self, messages: Messages) -> list[dict[str, Any]]:

# Format content blocks for Bedrock API compatibility
formatted_content = self._format_request_message_content(content_block)

# Wrap text content in guardrailConverseContent if this is the last user message
if guardrail_last_turn_only and idx == last_user_idx and "text" in formatted_content:
formatted_content = {"guardContent": {"text": {"text": formatted_content["text"]}}}

cleaned_content.append(formatted_content)

# Create new message with cleaned content (skip if empty)
Expand Down
79 changes: 79 additions & 0 deletions tests/strands/models/test_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -2196,3 +2196,82 @@ async def test_citations_content_preserves_tagged_union_structure(bedrock_client
"(documentChar, documentPage, documentChunk, searchResultLocation, or web) "
"with the location fields nested inside."
)


@pytest.mark.asyncio
async def test_format_request_with_guardrail_last_turn_only(model):
"""Test _format_request passes apply_last_turn flag correctly."""
model.update_config(guardrail_id="test-guardrail", guardrail_version="DRAFT", guardrail_last_turn_only=True)

messages = [
{"role": "user", "content": [{"text": "First message"}]},
{"role": "assistant", "content": [{"text": "First response"}]},
{"role": "user", "content": [{"text": "Latest message"}]},
]

request = model._format_request(messages)

# All messages should be in the request
formatted_messages = request["messages"]
assert len(formatted_messages) == 3

# Last user message should be wrapped
assert "guardContent" in formatted_messages[2]["content"][0]
assert formatted_messages[2]["content"][0]["guardContent"]["text"]["text"] == "Latest message"

# First user message should NOT be wrapped
assert "text" in formatted_messages[0]["content"][0]
assert formatted_messages[0]["content"][0]["text"] == "First message"


def test_format_bedrock_messages_multimodal_content(model):
"""Test that only text blocks are wrapped, not images."""
messages = [
{
"role": "user",
"content": [
{"text": "Look at this image"},
{"image": {"format": "png", "source": {"bytes": b"fake_image_data"}}},
],
}
]

result = model._format_bedrock_messages(messages, guardrail_last_turn_only=True)

# Should have 2 content blocks
assert len(result[0]["content"]) == 2

# Text should be wrapped
assert "guardContent" in result[0]["content"][0]
assert result[0]["content"][0]["guardContent"]["text"]["text"] == "Look at this image"

# Image should NOT be wrapped
assert "image" in result[0]["content"][1]


def test_format_bedrock_messages_wraps_last_user_text(model):
"""Test that only the last user message text is wrapped in guardContent."""
messages = [
{"role": "user", "content": [{"text": "First message"}]},
{"role": "assistant", "content": [{"text": "First response"}]},
{"role": "user", "content": [{"text": "Latest message"}]},
]

result = model._format_bedrock_messages(messages, guardrail_last_turn_only=True)

# All messages should be present
assert len(result) == 3

# First user message should NOT be wrapped
assert result[0]["role"] == "user"
assert "text" in result[0]["content"][0]
assert result[0]["content"][0]["text"] == "First message"

# Assistant message should be unchanged
assert result[1]["role"] == "assistant"
assert result[1]["content"][0]["text"] == "First response"

# Last user message should be wrapped in guardContent
assert result[2]["role"] == "user"
assert "guardContent" in result[2]["content"][0]
assert result[2]["content"][0]["guardContent"]["text"]["text"] == "Latest message"
87 changes: 87 additions & 0 deletions tests_integ/test_bedrock_guardrails.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,93 @@ def list_users() -> str:
assert tool_result["content"][0]["text"] == INPUT_REDACT_MESSAGE


def test_guardrail_last_turn_only(boto_session, bedrock_guardrail):
"""Test that guardrail_last_turn_only only sends the last turn to guardrails."""
bedrock_model = BedrockModel(
guardrail_id=bedrock_guardrail,
guardrail_version="DRAFT",
guardrail_last_turn_only=True,
boto_session=boto_session,
)

agent = Agent(model=bedrock_model, system_prompt="You are a helpful assistant.", callback_handler=None)

# First conversation turn - should not trigger guardrail
response1 = agent("Hello, how are you?")
assert response1.stop_reason != "guardrail_intervened"

# Second conversation turn with blocked word - should trigger guardrail
# Since guardrail_last_turn_only=True, only this message and the previous assistant response
# should be evaluated by the guardrail, not the entire conversation history
response2 = agent("CACTUS")
assert response2.stop_reason == "guardrail_intervened"
assert str(response2).strip() == BLOCKED_INPUT


def test_guardrail_last_turn_only_recovery_scenario(boto_session, bedrock_guardrail):
"""Test guardrail recovery: blocked content followed by normal question.

This tests the key benefit of guardrail_last_turn_only:
1. First turn: blocked content triggers guardrail
2. Second turn: normal question should work because only last turn is analyzed
"""
bedrock_model = BedrockModel(
guardrail_id=bedrock_guardrail,
guardrail_version="DRAFT",
guardrail_last_turn_only=True,
boto_session=boto_session,
)

agent = Agent(model=bedrock_model, system_prompt="You are a helpful assistant.", callback_handler=None)

# First turn - should be blocked by guardrail
response1 = agent("CACTUS")
assert response1.stop_reason == "guardrail_intervened"
assert str(response1).strip() == BLOCKED_INPUT

# Second turn - should work normally with last turn only
# This is the key test: normal questions should work after blocked content
response2 = agent("What is the weather like today?")
assert response2.stop_reason != "guardrail_intervened"
assert str(response2).strip() != BLOCKED_INPUT

# Verify the conversation has both messages
assert len(agent.messages) == 4 # 2 user + 2 assistant messages


def test_guardrail_last_turn_only_output_intervention(boto_session, bedrock_guardrail):
"""Test that guardrail_last_turn_only works with OUTPUT guardrails.

This tests that when the assistant tries to output blocked content,
the OUTPUT guardrail intervenes, even with guardrail_last_turn_only=True.
Then verifies that subsequent normal responses work correctly.
"""
bedrock_model = BedrockModel(
guardrail_id=bedrock_guardrail,
guardrail_version="DRAFT",
guardrail_last_turn_only=True,
guardrail_stream_processing_mode="sync",
boto_session=boto_session,
)

agent = Agent(
model=bedrock_model,
system_prompt="When asked to say the word, say CACTUS. Otherwise respond normally.",
callback_handler=None,
load_tools_from_directory=False,
)

# First turn - assistant tries to output "CACTUS", should be blocked by OUTPUT guardrail
response1 = agent("Say the word.")
assert response1.stop_reason == "guardrail_intervened"
assert BLOCKED_OUTPUT in str(response1)

# Second turn - normal question should work fine
response2 = agent("What is 2+2?")
assert response2.stop_reason != "guardrail_intervened"
assert BLOCKED_OUTPUT not in str(response2)


def test_guardrail_input_intervention_properly_redacts_in_session(boto_session, bedrock_guardrail, temp_dir):
bedrock_model = BedrockModel(
guardrail_id=bedrock_guardrail,
Expand Down