Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ jobs:
- name: Run pre-commit
run: uv run pre-commit run --all-files

- name: Run mypy
run: uv run mypy src/

test:
name: Test Python ${{ matrix.python-version }}
runs-on: ubuntu-latest
Expand Down
47 changes: 47 additions & 0 deletions src/bedrock_agentcore/_utils/snake_case.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""Utilities for wrapping boto3 methods to accept snake_case kwargs."""

import functools
import re
from typing import Any, Callable, Dict

_VALID_SNAKE_RE = re.compile(r"^[a-z][a-z0-9]*(_[a-z0-9]+)*$")


def snake_to_camel(name: str) -> str:
"""Convert a snake_case string to camelCase.

Already-camelCase strings pass through unchanged (no underscores to split on).
Raises ValueError for malformed snake_case (e.g. leading/trailing underscores,
consecutive underscores, uppercase characters).
"""
if "_" not in name:
return name
if not _VALID_SNAKE_RE.match(name):
raise ValueError(f"Invalid parameter name: '{name}'")
parts = name.split("_")
return parts[0] + "".join(p.title() for p in parts[1:])


def accept_snake_case_kwargs(method: Callable[..., Any]) -> Callable[..., Any]:
"""Wrap a boto3 method to accept both snake_case and camelCase kwargs.

Converts all snake_case kwargs to camelCase before forwarding.
Raises TypeError if both forms are provided (e.g. memory_id and memoryId).
"""

@functools.wraps(method)
def wrapper(*args: Any, **kwargs: Any) -> Any:
converted: Dict[str, Any] = {}
original_keys: Dict[str, str] = {}
for key, value in kwargs.items():
camel_key = snake_to_camel(key)
if camel_key in converted:
raise TypeError(
f"Got both '{original_keys[camel_key]}' and '{key}' for the same parameter. "
f"Use one or the other."
)
original_keys[camel_key] = key
converted[camel_key] = value
return method(*args, **converted)

return wrapper
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def _is_valid_adot_document(item: Any) -> bool:
return isinstance(item, dict) and "scope" in item and "traceId" in item and "spanId" in item


def _validate_spans(spans):
def _validate_spans(spans: Any) -> bool:
"""Validate spans are OpenTelemetry Span objects."""
if not spans:
return False
Expand Down Expand Up @@ -127,14 +127,14 @@ def evaluate(self, evaluation_case: EvaluationData[InputT, OutputT]) -> List[Eva
]

# Check if spans are already in ADOT format or need conversion
if _is_adot_format(evaluation_case.actual_trajectory):
if _is_adot_format(evaluation_case.actual_trajectory): # type: ignore[arg-type]
# Already in ADOT format (fetched from CloudWatch), use as-is
spans = evaluation_case.actual_trajectory
else:
# Raw OTel spans from in-memory exporter, validate and convert
if not _validate_spans(evaluation_case.actual_trajectory):
return [EvaluationOutput(score=0.0, test_pass=False, reason="Invalid span objects")]
spans = convert_strands_to_adot(evaluation_case.actual_trajectory)
spans = convert_strands_to_adot(evaluation_case.actual_trajectory) # type: ignore[arg-type]

request_payload = {"evaluatorId": self.evaluator_id, "evaluationInput": {"sessionSpans": spans}}

Expand Down Expand Up @@ -165,7 +165,7 @@ async def evaluate_async(self, evaluation_case: EvaluationData[InputT, OutputT])
return await asyncio.to_thread(self.evaluate, evaluation_case)


def create_strands_evaluator(evaluator_id: str, **kwargs) -> StrandsEvalsAgentCoreEvaluator:
def create_strands_evaluator(evaluator_id: str, **kwargs: Any) -> StrandsEvalsAgentCoreEvaluator:
"""Create Strands-compatible evaluator backed by AgentCore Evaluation API.

Args:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class SpanParser:
"""

@staticmethod
def extract_metadata(span) -> SpanMetadata:
def extract_metadata(span: Any) -> SpanMetadata:
"""Extract core span metadata."""
if not hasattr(span, "context") or not span.context:
raise ValueError(f"Span '{getattr(span, 'name', 'unknown')}' missing required context")
Expand All @@ -96,7 +96,7 @@ def extract_metadata(span) -> SpanMetadata:
)

@staticmethod
def extract_resource_info(span) -> ResourceInfo:
def extract_resource_info(span: Any) -> ResourceInfo:
"""Extract resource and scope information."""
resource_attrs = {}
if hasattr(span, "resource") and span.resource and hasattr(span.resource, "attributes"):
Expand All @@ -115,7 +115,7 @@ def extract_resource_info(span) -> ResourceInfo:
)

@staticmethod
def get_span_attributes(span) -> Dict[str, Any]:
def get_span_attributes(span: Any) -> Dict[str, Any]:
"""Safely extract span attributes."""
return dict(span.attributes) if hasattr(span, "attributes") and span.attributes else {}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,13 @@ def extract_tool_execution(cls, events: List[Any]) -> Optional[ToolExecution]:
class StrandsToADOTConverter:
"""Convert Strands OTel spans to ADOT format."""

def __init__(self):
def __init__(self) -> None:
"""Initialize converter with parsers and builder."""
self.span_parser = SpanParser()
self.event_parser = StrandsEventParser()
self.doc_builder = ADOTDocumentBuilder()

def convert_span(self, span) -> List[Dict[str, Any]]:
def convert_span(self, span: Any) -> List[Dict[str, Any]]:
"""Convert a single span to ADOT documents."""
documents = []

Expand Down
10 changes: 5 additions & 5 deletions src/bedrock_agentcore/identity/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def _get_iam_jwt_token(region: str) -> str:
try:
response = sts_client.get_web_identity_token(**params)
logger.info("Successfully obtained AWS IAM JWT token")
return response["WebIdentityToken"]
return response["WebIdentityToken"] # type: ignore[no-any-return]
except ClientError as e:
error_code = e.response.get("Error", {}).get("Code", "")
if error_code in ["FeatureDisabledException", "FeatureDisabled"]:
Expand Down Expand Up @@ -231,7 +231,7 @@ def requires_api_key(*, provider_name: str, into: str = "api_key") -> Callable:
def decorator(func: Callable) -> Callable:
client = IdentityClient(_get_region())

async def _get_api_key():
async def _get_api_key() -> str:
return await client.get_api_key(
provider_name=provider_name,
agent_identity_token=await _get_workload_access_token(client),
Expand Down Expand Up @@ -268,7 +268,7 @@ def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
return decorator


def _get_oauth2_callback_url(user_provided_oauth2_callback_url: Optional[str]):
def _get_oauth2_callback_url(user_provided_oauth2_callback_url: Optional[str]) -> Optional[str]:
if user_provided_oauth2_callback_url:
return user_provided_oauth2_callback_url

Expand Down Expand Up @@ -298,7 +298,7 @@ async def _set_up_local_auth(client: IdentityClient) -> str:

config_path = Path(".agentcore.json")
workload_identity_name = None
config = {}
config: dict[str, str] = {}
if config_path.exists():
try:
with open(config_path, "r", encoding="utf-8") as file:
Expand Down Expand Up @@ -327,7 +327,7 @@ async def _set_up_local_auth(client: IdentityClient) -> str:
except Exception:
print("Warning: could not write the created workload identity to file")

return client.get_workload_access_token(workload_identity_name, user_id=user_id)["workloadAccessToken"]
return client.get_workload_access_token(workload_identity_name, user_id=user_id)["workloadAccessToken"] # type: ignore[no-any-return]


def _get_region() -> str:
Expand Down
37 changes: 21 additions & 16 deletions src/bedrock_agentcore/memory/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from botocore.config import Config
from botocore.exceptions import ClientError

from bedrock_agentcore._utils.snake_case import accept_snake_case_kwargs
from bedrock_agentcore._utils.user_agent import build_user_agent_suffix

from .constants import (
Expand Down Expand Up @@ -98,7 +99,7 @@ def __init__(
self.gmdp_client.meta.region_name,
)

def __getattr__(self, name: str):
def __getattr__(self, name: str) -> Any:
"""Dynamically forward method calls to the appropriate boto3 client.

This method enables access to all boto3 client methods without explicitly
Expand Down Expand Up @@ -126,12 +127,12 @@ def __getattr__(self, name: str):
if name in self._ALLOWED_GMDP_METHODS and hasattr(self.gmdp_client, name):
method = getattr(self.gmdp_client, name)
logger.debug("Forwarding method '%s' to gmdp_client", name)
return method
return accept_snake_case_kwargs(method)

if name in self._ALLOWED_GMCP_METHODS and hasattr(self.gmcp_client, name):
method = getattr(self.gmcp_client, name)
logger.debug("Forwarding method '%s' to gmcp_client", name)
return method
return accept_snake_case_kwargs(method)

# Method not found on either client
raise AttributeError(
Expand Down Expand Up @@ -203,7 +204,7 @@ def create_or_get_memory(
try:
memory = self.create_memory_and_wait(
name=name,
strategies=strategies,
strategies=strategies, # type: ignore[arg-type]
description=description,
event_expiry_days=event_expiry_days,
memory_execution_role_arn=memory_execution_role_arn,
Expand All @@ -213,7 +214,7 @@ def create_or_get_memory(
except ClientError as e:
if e.response["Error"]["Code"] == "ValidationException" and "already exists" in str(e):
memories = self.list_memories()
memory = next((m for m in memories if m["id"].startswith(name)), None)
memory = next((m for m in memories if m["id"].startswith(name)), None) # type: ignore[arg-type]
logger.info("Memory already exists. Using existing memory ID: %s", memory["id"])
return memory
else:
Expand Down Expand Up @@ -338,7 +339,7 @@ def retrieve_memories(
memoryId=memory_id, namespace=namespace, searchCriteria={"searchQuery": query, "topK": top_k}
)

memories = response.get("memoryRecordSummaries", [])
memories: list[Dict[str, Any]] = response.get("memoryRecordSummaries", [])
logger.info("Retrieved %d memories from namespace: %s", len(memories), namespace)
return memories

Expand Down Expand Up @@ -473,7 +474,7 @@ def create_event(

response = self.gmdp_client.create_event(**params)

event = response["event"]
event: Dict[str, Any] = response["event"]
logger.info("Created event: %s", event["eventId"])

return event
Expand Down Expand Up @@ -539,7 +540,7 @@ def create_blob_event(

response = self.gmdp_client.create_event(**params)

event = response["event"]
event: Dict[str, Any] = response["event"]
logger.info("Created blob event: %s", event["eventId"])

return event
Expand Down Expand Up @@ -635,7 +636,7 @@ def save_conversation(

response = self.gmdp_client.create_event(**params)

event = response["event"]
event: Dict[str, Any] = response["event"]
logger.info("Created event: %s", event["eventId"])

return event
Expand Down Expand Up @@ -777,7 +778,7 @@ def list_events(
)
"""
try:
all_events = []
all_events: List[Dict[str, Any]] = []
next_token = None

while len(all_events) < max_results:
Expand All @@ -793,7 +794,7 @@ def list_events(
params["nextToken"] = next_token

# Build filter map
filter_map = {}
filter_map: Dict[str, Any] = {}

# Add branch filter if specified (but not for "main")
if branch_name and branch_name != "main":
Expand Down Expand Up @@ -937,7 +938,7 @@ def list_branch_events(
params["filter"] = {"branch": {"name": branch_name, "includeParentBranches": include_parent_branches}}

response = self.gmdp_client.list_events(**params)
events = response.get("events", [])
events: list[Dict[str, Any]] = response.get("events", [])

# Handle pagination
next_token = response.get("nextToken")
Expand Down Expand Up @@ -992,7 +993,11 @@ def get_conversation_tree(self, memory_id: str, actor_id: str, session_id: str)
break

# Build tree structure
tree = {"session_id": session_id, "actor_id": actor_id, "main_branch": {"events": [], "branches": {}}}
tree: Dict[str, Any] = {
"session_id": session_id,
"actor_id": actor_id,
"main_branch": {"events": [], "branches": {}},
}

# Group events by branch
for event in all_events:
Expand Down Expand Up @@ -1094,7 +1099,7 @@ def get_last_k_turns(
Returns:
List of turns, where each turn is a list of message dictionaries
"""
base_params = {
base_params: Dict[str, Any] = {
"memoryId": memory_id,
"actorId": actor_id,
"sessionId": session_id,
Expand Down Expand Up @@ -1222,7 +1227,7 @@ def get_memory_status(self, memory_id: str) -> str:
"""Get current memory status."""
try:
response = self.gmcp_client.get_memory(memoryId=memory_id) # Input uses old field name
return response["memory"]["status"]
return response["memory"]["status"] # type: ignore[no-any-return]
except ClientError as e:
logger.error("Failed to get memory status: %s", e)
raise
Expand Down Expand Up @@ -1265,7 +1270,7 @@ def list_memories(self, max_results: int = 100) -> List[Dict[str, Any]]:
def delete_memory(self, memory_id: str) -> Dict[str, Any]:
"""Delete a memory resource."""
try:
response = self.gmcp_client.delete_memory(
response: Dict[str, Any] = self.gmcp_client.delete_memory(
memoryId=memory_id, clientToken=str(uuid.uuid4())
) # Input uses old field name
logger.info("Deleted memory: %s", memory_id)
Expand Down
2 changes: 1 addition & 1 deletion src/bedrock_agentcore/memory/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ class ConversationalMessage:
text: str
role: MessageRole

def __post_init__(self):
def __post_init__(self) -> None:
"""Validate message fields after initialization."""
if not isinstance(self.text, str):
raise ValueError("ConversationalMessage.text must be a string")
Expand Down
Loading
Loading