Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ jobs:
- name: Run pre-commit
run: uv run pre-commit run --all-files

- name: Run mypy
run: uv run mypy src/

test:
name: Test Python ${{ matrix.python-version }}
runs-on: ubuntu-latest
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def _is_valid_adot_document(item: Any) -> bool:
return isinstance(item, dict) and "scope" in item and "traceId" in item and "spanId" in item


def _validate_spans(spans):
def _validate_spans(spans: Any) -> bool:
"""Validate spans are OpenTelemetry Span objects."""
if not spans:
return False
Expand Down Expand Up @@ -127,14 +127,14 @@ def evaluate(self, evaluation_case: EvaluationData[InputT, OutputT]) -> List[Eva
]

# Check if spans are already in ADOT format or need conversion
if _is_adot_format(evaluation_case.actual_trajectory):
if _is_adot_format(evaluation_case.actual_trajectory): # type: ignore[arg-type]
# Already in ADOT format (fetched from CloudWatch), use as-is
spans = evaluation_case.actual_trajectory
else:
# Raw OTel spans from in-memory exporter, validate and convert
if not _validate_spans(evaluation_case.actual_trajectory):
return [EvaluationOutput(score=0.0, test_pass=False, reason="Invalid span objects")]
spans = convert_strands_to_adot(evaluation_case.actual_trajectory)
spans = convert_strands_to_adot(evaluation_case.actual_trajectory) # type: ignore[arg-type]

request_payload = {"evaluatorId": self.evaluator_id, "evaluationInput": {"sessionSpans": spans}}

Expand Down Expand Up @@ -165,7 +165,7 @@ async def evaluate_async(self, evaluation_case: EvaluationData[InputT, OutputT])
return await asyncio.to_thread(self.evaluate, evaluation_case)


def create_strands_evaluator(evaluator_id: str, **kwargs) -> StrandsEvalsAgentCoreEvaluator:
def create_strands_evaluator(evaluator_id: str, **kwargs: Any) -> StrandsEvalsAgentCoreEvaluator:
"""Create Strands-compatible evaluator backed by AgentCore Evaluation API.

Args:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class SpanParser:
"""

@staticmethod
def extract_metadata(span) -> SpanMetadata:
def extract_metadata(span: Any) -> SpanMetadata:
"""Extract core span metadata."""
if not hasattr(span, "context") or not span.context:
raise ValueError(f"Span '{getattr(span, 'name', 'unknown')}' missing required context")
Expand All @@ -96,7 +96,7 @@ def extract_metadata(span) -> SpanMetadata:
)

@staticmethod
def extract_resource_info(span) -> ResourceInfo:
def extract_resource_info(span: Any) -> ResourceInfo:
"""Extract resource and scope information."""
resource_attrs = {}
if hasattr(span, "resource") and span.resource and hasattr(span.resource, "attributes"):
Expand All @@ -115,7 +115,7 @@ def extract_resource_info(span) -> ResourceInfo:
)

@staticmethod
def get_span_attributes(span) -> Dict[str, Any]:
def get_span_attributes(span: Any) -> Dict[str, Any]:
"""Safely extract span attributes."""
return dict(span.attributes) if hasattr(span, "attributes") and span.attributes else {}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,13 @@ def extract_tool_execution(cls, events: List[Any]) -> Optional[ToolExecution]:
class StrandsToADOTConverter:
"""Convert Strands OTel spans to ADOT format."""

def __init__(self):
def __init__(self) -> None:
"""Initialize converter with parsers and builder."""
self.span_parser = SpanParser()
self.event_parser = StrandsEventParser()
self.doc_builder = ADOTDocumentBuilder()

def convert_span(self, span) -> List[Dict[str, Any]]:
def convert_span(self, span: Any) -> List[Dict[str, Any]]:
"""Convert a single span to ADOT documents."""
documents = []

Expand Down
10 changes: 5 additions & 5 deletions src/bedrock_agentcore/identity/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def _get_iam_jwt_token(region: str) -> str:
try:
response = sts_client.get_web_identity_token(**params)
logger.info("Successfully obtained AWS IAM JWT token")
return response["WebIdentityToken"]
return response["WebIdentityToken"] # type: ignore[no-any-return]
except ClientError as e:
error_code = e.response.get("Error", {}).get("Code", "")
if error_code in ["FeatureDisabledException", "FeatureDisabled"]:
Expand Down Expand Up @@ -231,7 +231,7 @@ def requires_api_key(*, provider_name: str, into: str = "api_key") -> Callable:
def decorator(func: Callable) -> Callable:
client = IdentityClient(_get_region())

async def _get_api_key():
async def _get_api_key() -> str:
return await client.get_api_key(
provider_name=provider_name,
agent_identity_token=await _get_workload_access_token(client),
Expand Down Expand Up @@ -268,7 +268,7 @@ def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
return decorator


def _get_oauth2_callback_url(user_provided_oauth2_callback_url: Optional[str]):
def _get_oauth2_callback_url(user_provided_oauth2_callback_url: Optional[str]) -> Optional[str]:
if user_provided_oauth2_callback_url:
return user_provided_oauth2_callback_url

Expand Down Expand Up @@ -298,7 +298,7 @@ async def _set_up_local_auth(client: IdentityClient) -> str:

config_path = Path(".agentcore.json")
workload_identity_name = None
config = {}
config: dict[str, str] = {}
if config_path.exists():
try:
with open(config_path, "r", encoding="utf-8") as file:
Expand Down Expand Up @@ -327,7 +327,7 @@ async def _set_up_local_auth(client: IdentityClient) -> str:
except Exception:
print("Warning: could not write the created workload identity to file")

return client.get_workload_access_token(workload_identity_name, user_id=user_id)["workloadAccessToken"]
return client.get_workload_access_token(workload_identity_name, user_id=user_id)["workloadAccessToken"] # type: ignore[no-any-return]


def _get_region() -> str:
Expand Down
32 changes: 18 additions & 14 deletions src/bedrock_agentcore/memory/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def __init__(
self.gmdp_client.meta.region_name,
)

def __getattr__(self, name: str):
def __getattr__(self, name: str) -> Any:
"""Dynamically forward method calls to the appropriate boto3 client.

This method enables access to all boto3 client methods without explicitly
Expand Down Expand Up @@ -203,7 +203,7 @@ def create_or_get_memory(
try:
memory = self.create_memory_and_wait(
name=name,
strategies=strategies,
strategies=strategies, # type: ignore[arg-type]
description=description,
event_expiry_days=event_expiry_days,
memory_execution_role_arn=memory_execution_role_arn,
Expand All @@ -213,7 +213,7 @@ def create_or_get_memory(
except ClientError as e:
if e.response["Error"]["Code"] == "ValidationException" and "already exists" in str(e):
memories = self.list_memories()
memory = next((m for m in memories if m["id"].startswith(name)), None)
memory = next((m for m in memories if m["id"].startswith(name)), None) # type: ignore[arg-type]
logger.info("Memory already exists. Using existing memory ID: %s", memory["id"])
return memory
else:
Expand Down Expand Up @@ -338,7 +338,7 @@ def retrieve_memories(
memoryId=memory_id, namespace=namespace, searchCriteria={"searchQuery": query, "topK": top_k}
)

memories = response.get("memoryRecordSummaries", [])
memories: list[Dict[str, Any]] = response.get("memoryRecordSummaries", [])
logger.info("Retrieved %d memories from namespace: %s", len(memories), namespace)
return memories

Expand Down Expand Up @@ -473,7 +473,7 @@ def create_event(

response = self.gmdp_client.create_event(**params)

event = response["event"]
event: Dict[str, Any] = response["event"]
logger.info("Created event: %s", event["eventId"])

return event
Expand Down Expand Up @@ -539,7 +539,7 @@ def create_blob_event(

response = self.gmdp_client.create_event(**params)

event = response["event"]
event: Dict[str, Any] = response["event"]
logger.info("Created blob event: %s", event["eventId"])

return event
Expand Down Expand Up @@ -635,7 +635,7 @@ def save_conversation(

response = self.gmdp_client.create_event(**params)

event = response["event"]
event: Dict[str, Any] = response["event"]
logger.info("Created event: %s", event["eventId"])

return event
Expand Down Expand Up @@ -777,7 +777,7 @@ def list_events(
)
"""
try:
all_events = []
all_events: List[Dict[str, Any]] = []
next_token = None

while len(all_events) < max_results:
Expand All @@ -793,7 +793,7 @@ def list_events(
params["nextToken"] = next_token

# Build filter map
filter_map = {}
filter_map: Dict[str, Any] = {}

# Add branch filter if specified (but not for "main")
if branch_name and branch_name != "main":
Expand Down Expand Up @@ -937,7 +937,7 @@ def list_branch_events(
params["filter"] = {"branch": {"name": branch_name, "includeParentBranches": include_parent_branches}}

response = self.gmdp_client.list_events(**params)
events = response.get("events", [])
events: list[Dict[str, Any]] = response.get("events", [])

# Handle pagination
next_token = response.get("nextToken")
Expand Down Expand Up @@ -992,7 +992,11 @@ def get_conversation_tree(self, memory_id: str, actor_id: str, session_id: str)
break

# Build tree structure
tree = {"session_id": session_id, "actor_id": actor_id, "main_branch": {"events": [], "branches": {}}}
tree: Dict[str, Any] = {
"session_id": session_id,
"actor_id": actor_id,
"main_branch": {"events": [], "branches": {}},
}

# Group events by branch
for event in all_events:
Expand Down Expand Up @@ -1094,7 +1098,7 @@ def get_last_k_turns(
Returns:
List of turns, where each turn is a list of message dictionaries
"""
base_params = {
base_params: Dict[str, Any] = {
"memoryId": memory_id,
"actorId": actor_id,
"sessionId": session_id,
Expand Down Expand Up @@ -1222,7 +1226,7 @@ def get_memory_status(self, memory_id: str) -> str:
"""Get current memory status."""
try:
response = self.gmcp_client.get_memory(memoryId=memory_id) # Input uses old field name
return response["memory"]["status"]
return response["memory"]["status"] # type: ignore[no-any-return]
except ClientError as e:
logger.error("Failed to get memory status: %s", e)
raise
Expand Down Expand Up @@ -1265,7 +1269,7 @@ def list_memories(self, max_results: int = 100) -> List[Dict[str, Any]]:
def delete_memory(self, memory_id: str) -> Dict[str, Any]:
"""Delete a memory resource."""
try:
response = self.gmcp_client.delete_memory(
response: Dict[str, Any] = self.gmcp_client.delete_memory(
memoryId=memory_id, clientToken=str(uuid.uuid4())
) # Input uses old field name
logger.info("Deleted memory: %s", memory_id)
Expand Down
2 changes: 1 addition & 1 deletion src/bedrock_agentcore/memory/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ class ConversationalMessage:
text: str
role: MessageRole

def __post_init__(self):
def __post_init__(self) -> None:
"""Validate message fields after initialization."""
if not isinstance(self.text, str):
raise ValueError("ConversationalMessage.text must be a string")
Expand Down
15 changes: 8 additions & 7 deletions src/bedrock_agentcore/memory/controlplane.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def create_memory(

try:
response = self.client.create_memory(**params)
memory = response["memory"]
memory: Dict[str, Any] = response["memory"]
memory_id = memory["id"]

logger.info("Created memory: %s", memory_id)
Expand All @@ -118,7 +118,7 @@ def get_memory(self, memory_id: str, include_strategies: bool = True) -> Dict[st
"""
try:
response = self.client.get_memory(memoryId=memory_id)
memory = response["memory"]
memory: Dict[str, Any] = response["memory"]

# Add strategy count
strategies = memory.get("strategies", [])
Expand All @@ -144,7 +144,7 @@ def list_memories(self, max_results: int = 100) -> List[Dict[str, Any]]:
List of memory summaries
"""
try:
memories = []
memories: List[Dict[str, Any]] = []
next_token = None

while len(memories) < max_results:
Expand Down Expand Up @@ -239,7 +239,7 @@ def update_memory(

try:
response = self.client.update_memory(**params)
memory = response["memory"]
memory: Dict[str, Any] = response["memory"]
logger.info("Updated memory: %s", memory_id)

if wait_for_active:
Expand Down Expand Up @@ -300,7 +300,7 @@ def delete_memory(
logger.warning("Error waiting for strategies to become ACTIVE: %s", e)

# Now delete the memory
response = self.client.delete_memory(memoryId=memory_id, clientToken=str(uuid.uuid4()))
response: Dict[str, Any] = self.client.delete_memory(memoryId=memory_id, clientToken=str(uuid.uuid4()))

logger.info("Initiated deletion of memory: %s", memory_id)

Expand Down Expand Up @@ -399,7 +399,8 @@ def get_strategy(self, memory_id: str, strategy_id: str) -> Dict[str, Any]:

for strategy in strategies:
if strategy.get("strategyId") == strategy_id:
return strategy
result: Dict[str, Any] = strategy
return result

raise ValueError(f"Strategy {strategy_id} not found in memory {memory_id}")

Expand Down Expand Up @@ -567,7 +568,7 @@ def _wait_for_status(

start_time = time.time()
last_memory_status = None
strategy_statuses = {}
strategy_statuses: Dict[str, str] = {}

while time.time() - start_time < max_wait:
try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def events_to_messages(events: list[dict[str, Any]]) -> list[SessionMessage]:
if "conversational" in payload_item:
conv = payload_item["conversational"]
session_msg = SessionMessage.from_dict(json.loads(conv["content"]["text"]))
session_msg.message = AgentCoreMemoryConverter._filter_empty_text(session_msg.message)
session_msg.message = AgentCoreMemoryConverter._filter_empty_text(session_msg.message) # type: ignore[assignment, arg-type]
if session_msg.message.get("content"):
messages.append(session_msg)
elif "blob" in payload_item:
Expand All @@ -88,7 +88,7 @@ def events_to_messages(events: list[dict[str, Any]]) -> list[SessionMessage]:
if isinstance(blob_data, (tuple, list)) and len(blob_data) == 2:
try:
session_msg = SessionMessage.from_dict(json.loads(blob_data[0]))
session_msg.message = AgentCoreMemoryConverter._filter_empty_text(session_msg.message)
session_msg.message = AgentCoreMemoryConverter._filter_empty_text(session_msg.message) # type: ignore[assignment, arg-type]
if session_msg.message.get("content"):
messages.append(session_msg)
except (json.JSONDecodeError, ValueError):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def _bedrock_to_openai(message: dict) -> dict:
}
)

result: dict[str, Any] = {"role": role}
result: dict[str, Any] = {"role": role} # type: ignore[no-redef]

if tool_calls:
result["content"] = "\n".join(text_parts) if text_parts else None
Expand Down Expand Up @@ -144,7 +144,7 @@ def message_to_payload(session_message: SessionMessage) -> list[Tuple[str, str]]
if not has_non_empty:
return []

openai_msg = _bedrock_to_openai(message)
openai_msg = _bedrock_to_openai(message) # type: ignore[arg-type]
role = openai_msg.get("role", "user")
return [(json.dumps(openai_msg), role)]

Expand Down Expand Up @@ -177,7 +177,7 @@ def events_to_messages(events: list[dict[str, Any]]) -> list[SessionMessage]:
if openai_msg and isinstance(openai_msg, dict):
bedrock_msg = _openai_to_bedrock(openai_msg)
if bedrock_msg.get("content"):
session_msg = SessionMessage(message=bedrock_msg, message_id=0)
session_msg = SessionMessage(message=bedrock_msg, message_id=0) # type: ignore[arg-type]
messages.append(session_msg)

return messages
Expand Down
Loading
Loading