Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions unstract/sdk1/src/unstract/sdk1/adapters/base1.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,11 +674,13 @@ def _resolve_bedrock_aws_credentials(
if auth_type == "iam_role":
_drop_bedrock_access_keys(validated)
validated.pop(_BEDROCK_BEARER_TOKEN_FIELD, None)
validated.pop(_BEDROCK_LITELLM_BEARER_KWARG, None)
return validated

if auth_type == "access_keys":
_require_bedrock_access_keys(validated)
validated.pop(_BEDROCK_BEARER_TOKEN_FIELD, None)
validated.pop(_BEDROCK_LITELLM_BEARER_KWARG, None)
return validated

if auth_type == "bearer_token":
Expand All @@ -689,10 +691,16 @@ def _resolve_bedrock_aws_credentials(
# No auth_type: strip blank access keys (boto3 chain takes over) and
# drop any bearer token — bearer auth must be opted into explicitly
# via auth_type='bearer_token' rather than promoted from this branch.
# A non-blank `api_key` is preserved here to support `LLM.complete()`'s
# re-validation pass, where bearer-mode kwargs round-trip without their
# original `auth_type`. A blank `api_key` (Pydantic's `None` default
# for an unset field) is dropped so LiteLLM doesn't see `api_key=None`.
for key in _BEDROCK_AWS_KEY_FIELDS:
if not validated.get(key):
validated.pop(key, None)
validated.pop(_BEDROCK_BEARER_TOKEN_FIELD, None)
if not validated.get(_BEDROCK_LITELLM_BEARER_KWARG):
validated.pop(_BEDROCK_LITELLM_BEARER_KWARG, None)
return validated


Expand All @@ -703,6 +711,9 @@ class AWSBedrockLLMParameters(BaseChatCompletionParameters):
aws_secret_access_key: str | None = None
# AWS_BEARER_TOKEN_BEDROCK; resolver translates to LiteLLM's `api_key`.
aws_bearer_token: str | None = None
# Declared so it survives `LLM.complete()`'s re-validation of self.kwargs;
# otherwise Pydantic would drop it as an unknown field.
api_key: str | None = None
aws_region_name: str | None = None
aws_profile_name: str | None = None # For AWS SSO authentication
model_id: str | None = None # For Application Inference Profile (cost tracking)
Expand Down Expand Up @@ -1196,6 +1207,9 @@ class AWSBedrockEmbeddingParameters(BaseEmbeddingParameters):
aws_secret_access_key: str | None = None
# AWS_BEARER_TOKEN_BEDROCK; resolver translates to LiteLLM's `api_key`.
aws_bearer_token: str | None = None
# Declared so it survives Pydantic re-validation if the kwargs ever round-
# trip through validate() (parity with the LLM param class).
api_key: str | None = None
aws_region_name: str | None

@staticmethod
Expand Down
40 changes: 40 additions & 0 deletions unstract/sdk1/tests/test_bedrock_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,30 @@ def test_llm_bearer_token_strips_surrounding_whitespace() -> None:
assert out["api_key"] == "bedrock-key-abc"


def test_llm_bearer_token_survives_revalidation() -> None:
"""Bearer-mode kwargs must round-trip through a second validate() call.

``LLM.complete()`` re-runs ``validate({**self.kwargs, **kwargs})`` on
every call. The second pass has no ``auth_type`` and no
``aws_bearer_token`` (both stripped on the first pass), so the resolver
can't re-translate. ``api_key`` must survive Pydantic's
``model_dump()`` on the round-trip — otherwise LiteLLM falls through
to SigV4 signing and 401s with "Unable to locate credentials".
"""
first = AWSBedrockLLMParameters.validate(
{
"auth_type": "bearer_token",
"model": "anthropic.claude-3-haiku-20240307-v1:0",
"region_name": "us-east-1",
"aws_bearer_token": "bedrock-key-abc",
}
)
assert first["api_key"] == "bedrock-key-abc"

second = AWSBedrockLLMParameters.validate({**first, "max_tokens": 100})
assert second["api_key"] == "bedrock-key-abc"


def test_llm_iam_role_drops_stale_bearer_token() -> None:
out = AWSBedrockLLMParameters.validate(
{
Expand Down Expand Up @@ -489,6 +513,22 @@ def test_embedding_bearer_token_strips_surrounding_whitespace() -> None:
assert out["api_key"] == "bedrock-key-abc"


def test_embedding_bearer_token_survives_revalidation() -> None:
"""Defensive parity with the LLM round-trip test."""
first = AWSBedrockEmbeddingParameters.validate(
{
"auth_type": "bearer_token",
"model": "amazon.titan-embed-text-v2:0",
"region_name": "us-east-1",
"aws_bearer_token": "bedrock-key-abc",
}
)
assert first["api_key"] == "bedrock-key-abc"

second = AWSBedrockEmbeddingParameters.validate({**first})
assert second["api_key"] == "bedrock-key-abc"


def test_embedding_iam_role_drops_stale_bearer_token() -> None:
out = AWSBedrockEmbeddingParameters.validate(
{
Expand Down