Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 121 additions & 0 deletions infra/lambda/cloudwatch_logs_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
"""CloudWatch Logs tool for AgentCore Gateway.

Handles cloudwatch_start_query, cloudwatch_get_query_results, and cloudwatch_stop_query
tool calls routed via Gateway.
"""

import json
import logging
import os
from datetime import datetime, timezone

import boto3
from botocore.exceptions import ClientError

logs_client = boto3.client("logs", region_name=os.environ.get("AWS_REGION", "us-east-1"))

logger = logging.getLogger()
logger.setLevel(logging.INFO)

DELIMITER = "___"


def _create_error_response(error_message):
return {
"success": False,
"error": error_message,
"timestamp": datetime.now(timezone.utc).isoformat(),
}


def lambda_handler(event, context):
logger.info(f"Event: {json.dumps(event)}")

original_tool_name = context.client_context.custom["bedrockAgentCoreToolName"]
tool_name = original_tool_name[original_tool_name.index(DELIMITER) + len(DELIMITER) :]

if tool_name == "cloudwatch_start_query":
return cloudwatch_start_query(event)
elif tool_name == "cloudwatch_get_query_results":
return cloudwatch_get_query_results(event)
elif tool_name == "cloudwatch_stop_query":
return cloudwatch_stop_query(event)
else:
return {"success": False, "error": f"Unknown tool: {tool_name}"}


def cloudwatch_start_query(event):
log_group_name = event.get("log_group_name", "")
query_string = event.get("query_string", "")
start_time = event.get("start_time", 0)
end_time = event.get("end_time", 0)

if not all([log_group_name, query_string, start_time, end_time]):
return _create_error_response("Missing required parameters: log_group_name, query_string, start_time, end_time")

try:
response = logs_client.start_query(
logGroupName=log_group_name,
startTime=int(start_time),
endTime=int(end_time),
queryString=query_string,
)
return {
"success": True,
"query_id": response["queryId"],
"timestamp": datetime.now(timezone.utc).isoformat(),
}
except ClientError as e:
error_code = e.response["Error"]["Code"]
error_message = e.response["Error"]["Message"]
logger.error(f"CloudWatch StartQuery error: {error_code} - {error_message}")
return _create_error_response(f"{error_code}: {error_message}")


def cloudwatch_get_query_results(event):
query_id = event.get("query_id", "")

if not query_id:
return _create_error_response("Missing required parameter: query_id")

try:
response = logs_client.get_query_results(queryId=query_id)
results = []
for row in response.get("results", []):
result_row = {}
for field in row:
result_row[field["field"]] = field["value"]
results.append(result_row)

return {
"success": True,
"status": response.get("status"),
"results": results,
"statistics": response.get("statistics", {}),
"timestamp": datetime.now(timezone.utc).isoformat(),
}
except ClientError as e:
error_code = e.response["Error"]["Code"]
error_message = e.response["Error"]["Message"]
logger.error(f"CloudWatch GetQueryResults error: {error_code} - {error_message}")
return _create_error_response(f"{error_code}: {error_message}")


def cloudwatch_stop_query(event):
query_id = event.get("query_id", "")

if not query_id:
return _create_error_response("Missing required parameter: query_id")

try:
logs_client.stop_query(queryId=query_id)
return {
"success": True,
"query_id": query_id,
"timestamp": datetime.now(timezone.utc).isoformat(),
}
except ClientError as e:
error_code = e.response["Error"]["Code"]
error_message = e.response["Error"]["Message"]
logger.error(f"CloudWatch StopQuery error: {error_code} - {error_message}")
return _create_error_response(f"{error_code}: {error_message}")
130 changes: 130 additions & 0 deletions infra/lambda/cloudwatch_metrics_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
"""CloudWatch Metrics tool for AgentCore Gateway.

Handles cloudwatch_get_metric_statistics and cloudwatch_list_metrics
tool calls routed via Gateway.
"""

import json
import logging
import os
from datetime import datetime, timezone

import boto3
from botocore.exceptions import ClientError

cloudwatch_client = boto3.client("cloudwatch", region_name=os.environ.get("AWS_REGION", "us-east-1"))

logger = logging.getLogger()
logger.setLevel(logging.INFO)

DELIMITER = "___"


def _create_error_response(error_message):
return {
"success": False,
"error": error_message,
"timestamp": datetime.now(timezone.utc).isoformat(),
}


def lambda_handler(event, context):
logger.info(f"Event: {json.dumps(event)}")

original_tool_name = context.client_context.custom["bedrockAgentCoreToolName"]
tool_name = original_tool_name[original_tool_name.index(DELIMITER) + len(DELIMITER) :]

if tool_name == "cloudwatch_get_metric_statistics":
return cloudwatch_get_metric_statistics(event)
elif tool_name == "cloudwatch_list_metrics":
return cloudwatch_list_metrics(event)
else:
return {"success": False, "error": f"Unknown tool: {tool_name}"}


def cloudwatch_get_metric_statistics(event):
namespace = event.get("namespace", "")
metric_name = event.get("metric_name", "")
start_time = event.get("start_time", "")
end_time = event.get("end_time", "")
period = event.get("period", 0)
statistics = event.get("statistics", [])
dimensions = event.get("dimensions", [])

if not all([namespace, metric_name, start_time, end_time, period, statistics]):
return _create_error_response(
"Missing required parameters: namespace, metric_name, start_time, end_time, period, statistics"
)

try:
params = {
"Namespace": namespace,
"MetricName": metric_name,
"StartTime": start_time,
"EndTime": end_time,
"Period": int(period),
"Statistics": statistics,
}
if dimensions:
params["Dimensions"] = [{"Name": d["name"], "Value": d["value"]} for d in dimensions]

response = cloudwatch_client.get_metric_statistics(**params)
datapoints = []
for dp in response.get("Datapoints", []):
point = {"timestamp": dp["Timestamp"].isoformat()}
for stat in ["Average", "Sum", "Minimum", "Maximum", "SampleCount"]:
if stat in dp:
point[stat.lower()] = dp[stat]
datapoints.append(point)

datapoints.sort(key=lambda x: x["timestamp"])

return {
"success": True,
"namespace": namespace,
"metric_name": metric_name,
"datapoints": datapoints,
"timestamp": datetime.now(timezone.utc).isoformat(),
}
except ClientError as e:
error_code = e.response["Error"]["Code"]
error_message = e.response["Error"]["Message"]
logger.error(f"CloudWatch GetMetricStatistics error: {error_code} - {error_message}")
return _create_error_response(f"{error_code}: {error_message}")


def cloudwatch_list_metrics(event):
namespace = event.get("namespace", "")
metric_name = event.get("metric_name", "")
dimensions = event.get("dimensions", [])

try:
params = {}
if namespace:
params["Namespace"] = namespace
if metric_name:
params["MetricName"] = metric_name
if dimensions:
params["Dimensions"] = [{"Name": d["name"], "Value": d.get("value", "")} for d in dimensions]

response = cloudwatch_client.list_metrics(**params)
metrics = []
for m in response.get("Metrics", []):
metrics.append(
{
"namespace": m.get("Namespace"),
"metric_name": m.get("MetricName"),
"dimensions": [{"name": d["Name"], "value": d["Value"]} for d in m.get("Dimensions", [])],
}
)

return {
"success": True,
"metrics": metrics,
"timestamp": datetime.now(timezone.utc).isoformat(),
}
except ClientError as e:
error_code = e.response["Error"]["Code"]
error_message = e.response["Error"]["Message"]
logger.error(f"CloudWatch ListMetrics error: {error_code} - {error_message}")
return _create_error_response(f"{error_code}: {error_message}")
143 changes: 143 additions & 0 deletions infra/lambda/interceptor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
"""Gateway interceptor for request filtering and response redaction."""

import json
import logging
import re

logger = logging.getLogger()
logger.setLevel(logging.INFO)

BLOCKED_PATTERNS = [
re.compile(p, re.IGNORECASE)
for p in [
r"ignore\s+(previous|all)\s+instructions",
r"disregard\s+(previous|all)\s+instructions",
r"forget\s+(previous|all)\s+instructions",
r"you\s+are\s+now\s+in\s+developer\s+mode",
r"jailbreak",
r"bypass\s+(security|filter|restriction)",
r"<script[^>]*>",
r"javascript:",
r"eval\s*\(",
r"exec\s*\(",
r"__import__",
r"os\.system",
r"subprocess",
]
]

PII_PATTERNS = [
(re.compile(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"), "[EMAIL_REDACTED]"),
(re.compile(r"\b\d{3}[-.]?\d{3}[-.]?\d{4}\b"), "[PHONE_REDACTED]"),
(re.compile(r"\b\d{3}-\d{2}-\d{4}\b"), "[SSN_REDACTED]"),
(re.compile(r"\b\d{4}[-\s]?\d{4}[-\s]?\d{4}[-\s]?\d{4}\b"), "[CC_REDACTED]"),
(re.compile(r"\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b"), "[IP_REDACTED]"),
]


def check_malicious_content(text):
for pattern in BLOCKED_PATTERNS:
if pattern.search(text):
return True, pattern.pattern
return False, ""


def redact_pii(text):
for pattern, replacement in PII_PATTERNS:
text = pattern.sub(replacement, text)
return text


def redact_dict(obj):
if isinstance(obj, str):
return redact_pii(obj)
elif isinstance(obj, dict):
return {k: redact_dict(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [redact_dict(item) for item in obj]
return obj


def lambda_handler(event, context):
logger.info(f"Interceptor event: {json.dumps(event)}")

mcp_data = event.get("mcp", {})
gateway_response = mcp_data.get("gatewayResponse")

if gateway_response is not None:
return handle_response(mcp_data, gateway_response)
else:
return handle_request(mcp_data)


def extract_session_id_from_meta(request_body):
"""Extract session_id from OTEL baggage in _meta field of tools/call requests."""
params = request_body.get("params", {})
if not isinstance(params, dict):
return None
meta = params.get("_meta", {})
if not isinstance(meta, dict):
return None
baggage_str = meta.get("baggage", "")
for pair in baggage_str.split(","):
kv = pair.strip().split("=", 1)
if len(kv) == 2 and kv[0] == "session.id":
return kv[1]
return None


def handle_request(mcp_data):
gateway_request = mcp_data.get("gatewayRequest", {})
request_body = gateway_request.get("body", {})
content_to_check = json.dumps(request_body)

is_malicious, matched_pattern = check_malicious_content(content_to_check)

if is_malicious:
logger.warning(f"Blocked malicious request matching: {matched_pattern}")
return {
"interceptorOutputVersion": "1.0",
"mcp": {
"transformedGatewayResponse": {
"statusCode": 403,
"body": {
"jsonrpc": "2.0",
"id": request_body.get("id", 1),
"error": {
"code": -32600,
"message": "Request blocked: potentially harmful content detected",
},
},
}
},
}

if request_body.get("method") == "tools/call":
session_id = extract_session_id_from_meta(request_body)
if session_id:
params = request_body.get("params", {})
arguments = params.get("arguments", {})
arguments["session_id"] = session_id
params["arguments"] = arguments
request_body["params"] = params

return {
"interceptorOutputVersion": "1.0",
"mcp": {"transformedGatewayRequest": {"body": request_body}},
}


def handle_response(mcp_data, gateway_response):
response_body = gateway_response.get("body") or {}
status_code = gateway_response.get("statusCode", 200)
redacted_body = redact_dict(response_body)

return {
"interceptorOutputVersion": "1.0",
"mcp": {
"transformedGatewayResponse": {
"statusCode": status_code,
"body": redacted_body,
}
},
}
Loading
Loading