Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 66 additions & 10 deletions src/strands/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,28 @@ def get_config(self) -> BedrockConfig:
"""
return resolve_config_metadata(self.config, self.config.get("model_id", ""))

def format_request(
self,
messages: Messages,
tool_specs: list[ToolSpec] | None = None,
system_prompt_content: list[SystemContentBlock] | None = None,
tool_choice: ToolChoice | None = None,
) -> dict[str, Any]:
"""Format a Bedrock converse stream request.

Args:
messages: List of message objects to be processed by the model.
tool_specs: List of tool specifications to make available to the model.
tool_choice: Selection strategy for tool invocation.
system_prompt_content: System prompt content blocks to provide context to the model.

Returns:
A Bedrock converse stream request.
"""
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
return self._format_request(messages, tool_specs, system_prompt_content, tool_choice)

def _format_request(
self,
messages: Messages,
Expand All @@ -243,6 +265,9 @@ def _format_request(
) -> dict[str, Any]:
"""Format a Bedrock converse stream request.

.. deprecated::
Use :meth:`format_request` instead. This will be removed in September 2026.

Args:
messages: List of message objects to be processed by the model.
tool_specs: List of tool specifications to make available to the model.
Expand All @@ -252,6 +277,12 @@ def _format_request(
Returns:
A Bedrock converse stream request.
"""
warnings.warn(
"_format_request is on the deprecation path, use format_request instead. "
"This will be removed in September 2026.",
DeprecationWarning,
stacklevel=2,
)
if not tool_specs:
has_tool_content = any(
any("toolUse" in block or "toolResult" in block for block in msg.get("content", [])) for msg in messages
Expand Down Expand Up @@ -830,7 +861,9 @@ async def count_tokens(
if system_prompt and system_prompt_content is None:
system_prompt_content = [{"text": system_prompt}]

request = self._format_request(messages, tool_specs, system_prompt_content)
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
request = self._format_request(messages, tool_specs, system_prompt_content)
converse_input: dict[str, Any] = {}
if "messages" in request:
converse_input["messages"] = request["messages"]
Expand All @@ -852,13 +885,9 @@ async def count_tokens(
logger.debug("model_id=<%s>, total_tokens=<%d> | native token count", self.config["model_id"], total_tokens)
return total_tokens
except Exception as e:
if (
isinstance(e, ClientError)
and e.response.get("Error", {}).get("Code") == "AccessDeniedException"
):
if isinstance(e, ClientError) and e.response.get("Error", {}).get("Code") == "AccessDeniedException":
logger.warning(
"model_id=<%s> | bedrock:CountTokens permission denied,"
" falling back to heuristic estimation: %s",
"model_id=<%s> | bedrock:CountTokens permission denied, falling back to heuristic estimation: %s",
model_id,
e,
)
Expand Down Expand Up @@ -964,7 +993,9 @@ def _stream(
"""
try:
logger.debug("formatting request")
request = self._format_request(messages, tool_specs, system_prompt_content, tool_choice)
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
request = self._format_request(messages, tool_specs, system_prompt_content, tool_choice)
logger.debug("request=<%s>", request)

logger.debug("invoking model")
Expand All @@ -988,8 +1019,10 @@ def _stream(

else:
response = self.client.converse(**request)
for event in self._convert_non_streaming_to_streaming(response):
callback(event)
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
for event in self._convert_non_streaming_to_streaming(response):
callback(event)

if (
"trace" in response
Expand Down Expand Up @@ -1044,15 +1077,38 @@ def _stream(
callback()
logger.debug("finished streaming response from model")

def convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Iterable[StreamEvent]:
"""Convert a non-streaming response to the streaming format.

Args:
response: The non-streaming response from the Bedrock model.

Returns:
An iterable of response events in the streaming format.
"""
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
yield from self._convert_non_streaming_to_streaming(response)

def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Iterable[StreamEvent]:
"""Convert a non-streaming response to the streaming format.

.. deprecated::
Use :meth:`convert_non_streaming_to_streaming` instead. This will be removed in September 2026.

Args:
response: The non-streaming response from the Bedrock model.

Returns:
An iterable of response events in the streaming format.
"""
warnings.warn(
"_convert_non_streaming_to_streaming is on the deprecation path, "
"use convert_non_streaming_to_streaming instead. "
"This will be removed in September 2026.",
DeprecationWarning,
stacklevel=2,
)
# Yield messageStart event
yield {"messageStart": {"role": response["output"]["message"]["role"]}}

Expand Down
76 changes: 74 additions & 2 deletions tests/strands/models/test_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@
from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException
from strands.types.tools import ToolSpec

pytestmark = [
pytest.mark.filterwarnings("ignore:_format_request is on the deprecation path:DeprecationWarning"),
pytest.mark.filterwarnings(
"ignore:_convert_non_streaming_to_streaming is on the deprecation path:DeprecationWarning"
),
]

FORMATTED_DEFAULT_MODEL_ID = DEFAULT_BEDROCK_MODEL_ID


Expand Down Expand Up @@ -2424,14 +2431,16 @@ def test_tool_choice_supported_no_warning(model, messages, tool_spec, captured_w
tool_choice = {"auto": {}}
model._format_request(messages, [tool_spec], tool_choice=tool_choice)

assert len(captured_warnings) == 0
non_deprecation_warnings = [w for w in captured_warnings if not issubclass(w.category, DeprecationWarning)]
assert len(non_deprecation_warnings) == 0


def test_tool_choice_none_no_warning(model, messages, captured_warnings):
"""Test that None toolChoice doesn't emit warning."""
model._format_request(messages, tool_choice=None)

assert len(captured_warnings) == 0
non_deprecation_warnings = [w for w in captured_warnings if not issubclass(w.category, DeprecationWarning)]
assert len(non_deprecation_warnings) == 0


def test_get_default_model_with_warning_supported_regions_shows_no_warning(captured_warnings):
Expand Down Expand Up @@ -3620,3 +3629,66 @@ def test_format_request_cache_tools_string_backward_compat(model, messages, mode

exp_cache_point = {"cachePoint": {"type": cache_type}}
assert tru_request["toolConfig"]["tools"][-1] == exp_cache_point


def test_format_request_delegates_to_private(model, messages):
"""Test that format_request delegates to _format_request."""
with unittest.mock.patch.object(model, "_format_request", wraps=model._format_request) as mock_private:
result = model.format_request(messages)
mock_private.assert_called_once_with(messages, None, None, None)
assert result == model.format_request(messages)


def test_format_request_passes_all_arguments(model, messages):
"""Test that format_request passes all arguments to _format_request."""
tool_specs = [{"name": "test_tool", "description": "A test tool", "inputSchema": {"json": {}}}]
system_prompt_content = [{"text": "system prompt"}]
tool_choice = {"auto": {}}

with unittest.mock.patch.object(model, "_format_request", wraps=model._format_request) as mock_private:
model.format_request(messages, tool_specs, system_prompt_content, tool_choice)
mock_private.assert_called_once_with(messages, tool_specs, system_prompt_content, tool_choice)


def test_convert_non_streaming_to_streaming_delegates_to_private(model):
"""Test that convert_non_streaming_to_streaming delegates to _convert_non_streaming_to_streaming."""
response = {
"output": {"message": {"role": "assistant", "content": [{"text": "hello"}]}},
"stopReason": "end_turn",
}
with unittest.mock.patch.object(
model, "_convert_non_streaming_to_streaming", wraps=model._convert_non_streaming_to_streaming
) as mock_private:
result = list(model.convert_non_streaming_to_streaming(response))
mock_private.assert_called_once_with(response)
assert len(result) > 0


def test_convert_non_streaming_to_streaming_passes_all_arguments(model):
"""Test that convert_non_streaming_to_streaming passes the response to _convert_non_streaming_to_streaming."""
response = {
"output": {"message": {"role": "assistant", "content": [{"text": "hello"}]}},
"stopReason": "end_turn",
}
with unittest.mock.patch.object(
model, "_convert_non_streaming_to_streaming", wraps=model._convert_non_streaming_to_streaming
) as mock_private:
list(model.convert_non_streaming_to_streaming(response))
call_args = mock_private.call_args
assert call_args.args[0] is response


def test_format_request_private_emits_deprecation_warning(model, messages):
"""Test that _format_request emits a DeprecationWarning when called directly."""
with pytest.warns(DeprecationWarning, match="_format_request is on the deprecation path"):
model._format_request(messages)


def test_convert_non_streaming_to_streaming_private_emits_deprecation_warning(model):
"""Test that _convert_non_streaming_to_streaming emits a DeprecationWarning when called directly."""
response = {
"output": {"message": {"role": "assistant", "content": [{"text": "hello"}]}},
"stopReason": "end_turn",
}
with pytest.warns(DeprecationWarning, match="_convert_non_streaming_to_streaming is on the deprecation path"):
list(model._convert_non_streaming_to_streaming(response))