Skip to content

Commit f2d33a5

Browse files
SkyWalker-CMDSherif Riad
andauthored
Implement 3LO Server on localhost:8081 to handle generating OAuth2 tokens (#282)
* Implement 3LO Server on localhost:8081 to handle generating OAuth2 tokens * Update uv.lock with the latest AgentCore SDK * Add more unit test coverage for the new OAuth2 callback server * Add more unit test coverage to the OAuth2 3LO Server * Bump BedrockAgent Core SDK client to version 1.0.3 * Add more unit test coverage for the OAuth2 callback server * Add more unit test coverage for the local OAuth2 callback server * Add more unit test coverage for local OAuth2 callback server --------- Co-authored-by: Sherif Riad <sherifri@amazon.com>
1 parent ec7880e commit f2d33a5

File tree

10 files changed

+377
-23
lines changed

10 files changed

+377
-23
lines changed

pyproject.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@ classifiers = [
2626
"Topic :: Software Development :: Libraries :: Python Modules",
2727
]
2828
dependencies = [
29-
"boto3>=1.40.35",
30-
"botocore>=1.40.35",
31-
"bedrock-agentcore>=0.1.7",
29+
"boto3>=1.40.51",
30+
"botocore>=1.40.51",
31+
"bedrock-agentcore>=1.0.3",
3232
"docstring_parser>=0.15,<1.0",
3333
"httpx>=0.28.1",
3434
"jinja2>=3.1.6",

src/bedrock_agentcore_starter_toolkit/cli/runtime/commands.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import logging
55
import os
66
from pathlib import Path
7+
from threading import Thread
78
from typing import List, Optional
89

910
import typer
@@ -12,6 +13,7 @@
1213
from rich.panel import Panel
1314
from rich.syntax import Syntax
1415

16+
from ...operations.identity.oauth2_callback_server import start_oauth2_callback_server
1517
from ...operations.runtime import (
1618
configure_bedrock_agentcore,
1719
destroy_bedrock_agentcore,
@@ -575,12 +577,23 @@ def launch(
575577
_print_success(f"Docker image built: {result.tag}")
576578
_print_success("Ready to run locally")
577579
console.print("Starting server at http://localhost:8080")
580+
console.print("Starting OAuth2 3LO callback server at http://localhost:8081")
578581
console.print("[yellow]Press Ctrl+C to stop[/yellow]\n")
579582

580583
if result.runtime is None or result.port is None:
581584
_handle_error("Unable to launch locally")
582585

583586
try:
587+
oauth2_callback_endpoint = Thread(
588+
target=start_oauth2_callback_server,
589+
args=(
590+
config_path,
591+
agent,
592+
),
593+
name="OAuth2 3LO Callback Server",
594+
daemon=True,
595+
)
596+
oauth2_callback_endpoint.start()
584597
result.runtime.run_local(result.tag, result.port, result.env_vars)
585598
except KeyboardInterrupt:
586599
console.print("\n[yellow]Stopped[/yellow]")
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"""Bedrock AgentCore Identity operations."""
2+
3+
from .oauth2_callback_server import WORKLOAD_USER_ID, start_oauth2_callback_server
4+
5+
__all__ = ["start_oauth2_callback_server", "WORKLOAD_USER_ID"]
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
"""Provides a Starlette-based web server that handles OAuth2 3LO callbacks."""
2+
3+
from pathlib import Path
4+
5+
import uvicorn
6+
from bedrock_agentcore.services.identity import IdentityClient, UserIdIdentifier
7+
from starlette.applications import Starlette
8+
from starlette.requests import Request
9+
from starlette.responses import JSONResponse
10+
from starlette.routing import Route
11+
12+
from ...cli.common import console
13+
from ...utils.runtime.config import BedrockAgentCoreAgentSchema, load_config
14+
15+
OAUTH2_CALLBACK_SERVER_PORT = 8081
16+
OAUTH2_CALLBACK_ENDPOINT = "/oauth2/callback"
17+
WORKLOAD_USER_ID = "userId"
18+
19+
20+
def start_oauth2_callback_server(config_path: Path, agent_name: str, debug: bool = False):
21+
"""Starts a server to complete the OAuth2 3LO flow with AgentCore Identity."""
22+
callback_server = BedrockAgentCoreIdentity3loCallback(config_path=config_path, agent_name=agent_name, debug=debug)
23+
callback_server.run()
24+
25+
26+
class BedrockAgentCoreIdentity3loCallback(Starlette):
27+
"""Bedrock AgentCore application class that extends Starlette for OAuth2 3LO callback flow."""
28+
29+
def __init__(self, config_path: Path, agent_name: str, debug: bool = False):
30+
"""Initialize Bedrock AgentCore Identity callback server."""
31+
self.config_path = config_path
32+
self.agent_name = agent_name
33+
routes = [
34+
Route(OAUTH2_CALLBACK_ENDPOINT, self._handle_3lo_callback, methods=["GET"]),
35+
]
36+
super().__init__(routes=routes, debug=debug)
37+
38+
def run(self, **kwargs):
39+
"""Start the Bedrock AgentCore Identity OAuth2 callback server."""
40+
uvicorn_params = {
41+
"host": "127.0.0.1",
42+
"port": OAUTH2_CALLBACK_SERVER_PORT,
43+
"access_log": self.debug,
44+
"log_level": "info" if self.debug else "warning",
45+
}
46+
uvicorn_params.update(kwargs)
47+
48+
uvicorn.run(self, **uvicorn_params)
49+
50+
def _handle_3lo_callback(self, request: Request) -> JSONResponse:
51+
"""Handle OAuth2 3LO callbacks with AgentCore Identity."""
52+
session_id = request.query_params.get("session_id")
53+
if not session_id:
54+
console.print("Missing session_id in OAuth2 3LO callback")
55+
return JSONResponse(status_code=400, content={"message": "missing session_id query parameter"})
56+
57+
project_config = load_config(self.config_path)
58+
agent_config: BedrockAgentCoreAgentSchema = project_config.get_agent_config(self.agent_name)
59+
oauth2_config = agent_config.oauth_configuration
60+
61+
user_id = None
62+
if oauth2_config:
63+
user_id = oauth2_config.get(WORKLOAD_USER_ID)
64+
65+
if not user_id:
66+
console.print(f"Missing {WORKLOAD_USER_ID} in Agent OAuth2 Config")
67+
return JSONResponse(status_code=500, content={"message": "Internal Server Error"})
68+
69+
console.print(f"Handling 3LO callback for workload_user_id={user_id} | session_id={session_id}", soft_wrap=True)
70+
71+
region = agent_config.aws.region
72+
if not region:
73+
console.print("AWS Region not configured")
74+
return JSONResponse(status_code=500, content={"message": "Internal Server Error"})
75+
76+
identity_client = IdentityClient(region)
77+
identity_client.complete_resource_token_auth(
78+
session_uri=session_id, user_identifier=UserIdIdentifier(user_id=user_id)
79+
)
80+
81+
return JSONResponse(status_code=200, content={"message": "OAuth2 3LO flow completed successfully"})
82+
83+
@classmethod
84+
def get_oauth2_callback_endpoint(cls) -> str:
85+
"""Returns the url for the local OAuth2 callback server."""
86+
return f"http://localhost:{OAUTH2_CALLBACK_SERVER_PORT}{OAUTH2_CALLBACK_ENDPOINT}"

src/bedrock_agentcore_starter_toolkit/operations/runtime/invoke.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from bedrock_agentcore.services.identity import IdentityClient
99

10+
from ...operations.identity.oauth2_callback_server import WORKLOAD_USER_ID, BedrockAgentCoreIdentity3loCallback
1011
from ...services.runtime import BedrockAgentCoreClient, generate_session_id
1112
from ...utils.runtime.config import load_config, save_config
1213
from ...utils.runtime.schema import BedrockAgentCoreConfigSchema
@@ -121,9 +122,19 @@ def invoke_bedrock_agentcore(
121122
workload_name=workload_name, user_token=bearer_token, user_id=user_id
122123
)["workloadAccessToken"]
123124

125+
agent_config.oauth_configuration[WORKLOAD_USER_ID] = user_id # type: ignore : populated by _get_workload_name(...)
126+
save_config(project_config, config_path)
127+
128+
oauth2_callback_url = BedrockAgentCoreIdentity3loCallback.get_oauth2_callback_endpoint()
129+
_update_workload_identity_with_oauth2_callback_url(
130+
identity_client, workload_name=workload_name, oauth2_callback_url=oauth2_callback_url
131+
)
132+
124133
# TODO: store and read port config of local running container
125134
client = LocalBedrockAgentCoreClient("http://127.0.0.1:8080")
126-
response = client.invoke_endpoint(session_id, payload_str, workload_access_token, custom_headers)
135+
response = client.invoke_endpoint(
136+
session_id, payload_str, workload_access_token, oauth2_callback_url, custom_headers
137+
)
127138

128139
else:
129140
if not agent_arn:
@@ -163,6 +174,24 @@ def invoke_bedrock_agentcore(
163174
)
164175

165176

177+
def _update_workload_identity_with_oauth2_callback_url(
178+
identity_client: IdentityClient,
179+
workload_name: str,
180+
oauth2_callback_url: str,
181+
) -> None:
182+
workload_identity = identity_client.get_workload_identity(name=workload_name)
183+
allowed_resource_oauth_2_return_urls = workload_identity.get("allowedResourceOauth2ReturnUrls") or []
184+
if oauth2_callback_url in allowed_resource_oauth_2_return_urls:
185+
return
186+
187+
log.info("Updating workload %s with callback url %s", workload_name, oauth2_callback_url)
188+
189+
identity_client.update_workload_identity(
190+
name=workload_name,
191+
allowed_resource_oauth_2_return_urls=[*allowed_resource_oauth_2_return_urls, oauth2_callback_url],
192+
)
193+
194+
166195
def _get_workload_name(
167196
project_config: BedrockAgentCoreConfigSchema,
168197
project_config_path: Path,

src/bedrock_agentcore_starter_toolkit/services/runtime.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -576,17 +576,19 @@ def invoke_endpoint(
576576
session_id: str,
577577
payload: str,
578578
workload_access_token: str,
579+
oauth2_callback_url: str,
579580
custom_headers: Optional[dict] = None,
580581
):
581582
"""Invoke the endpoint with the given parameters."""
582-
from bedrock_agentcore.runtime.models import ACCESS_TOKEN_HEADER, SESSION_HEADER
583+
from bedrock_agentcore.runtime.models import ACCESS_TOKEN_HEADER, OAUTH2_CALLBACK_URL_HEADER, SESSION_HEADER
583584

584585
url = f"{self.endpoint}/invocations"
585586

586587
headers = {
587588
"Content-Type": "application/json",
588589
ACCESS_TOKEN_HEADER: workload_access_token,
589590
SESSION_HEADER: session_id,
591+
OAUTH2_CALLBACK_URL_HEADER: oauth2_callback_url,
590592
}
591593

592594
# Merge custom headers if provided
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
from unittest.mock import Mock, patch
2+
3+
from bedrock_agentcore.services.identity import UserIdIdentifier
4+
from starlette.testclient import TestClient
5+
6+
from bedrock_agentcore_starter_toolkit.operations.identity.oauth2_callback_server import (
7+
OAUTH2_CALLBACK_ENDPOINT,
8+
WORKLOAD_USER_ID,
9+
BedrockAgentCoreIdentity3loCallback,
10+
)
11+
from bedrock_agentcore_starter_toolkit.utils.runtime.config import save_config
12+
from bedrock_agentcore_starter_toolkit.utils.runtime.schema import (
13+
AWSConfig,
14+
BedrockAgentCoreAgentSchema,
15+
BedrockAgentCoreConfigSchema,
16+
NetworkConfiguration,
17+
ObservabilityConfig,
18+
)
19+
20+
21+
def create_test_config(tmp_path, *, agent_name="test-agent", user_id="test-user-id", region="us-west-2"):
22+
config_path = tmp_path / ".bedrock_agentcore.yaml"
23+
24+
agent_config = BedrockAgentCoreAgentSchema(
25+
name=agent_name,
26+
entrypoint="test_agent.py",
27+
container_runtime="docker",
28+
aws=AWSConfig(
29+
region=region,
30+
account="123456789012",
31+
execution_role=None,
32+
execution_role_auto_create=True,
33+
ecr_repository=None,
34+
ecr_auto_create=True,
35+
network_configuration=NetworkConfiguration(),
36+
observability=ObservabilityConfig(),
37+
),
38+
oauth_configuration={WORKLOAD_USER_ID: user_id} if user_id else {},
39+
)
40+
41+
project_config = BedrockAgentCoreConfigSchema(default_agent=agent_name, agents={agent_name: agent_config})
42+
save_config(project_config, config_path)
43+
44+
return config_path
45+
46+
47+
class TestBedrockAgentCoreIdentity3loCallback:
48+
def test_init(self, tmp_path):
49+
config_path = create_test_config(tmp_path)
50+
server = BedrockAgentCoreIdentity3loCallback(config_path=config_path, agent_name="test-agent")
51+
52+
assert server.config_path == config_path
53+
assert server.agent_name == "test-agent"
54+
assert len(server.routes) == 1
55+
assert server.routes[0].path == OAUTH2_CALLBACK_ENDPOINT
56+
57+
def test_get_callback_endpoint(self):
58+
endpoint = BedrockAgentCoreIdentity3loCallback.get_oauth2_callback_endpoint()
59+
assert endpoint == "http://localhost:8081/oauth2/callback"
60+
61+
def test_handle_3lo_callback_missing_session_id(self, tmp_path):
62+
config_path = create_test_config(tmp_path)
63+
server = BedrockAgentCoreIdentity3loCallback(config_path=config_path, agent_name="test-agent")
64+
client = TestClient(server)
65+
response = client.get(OAUTH2_CALLBACK_ENDPOINT)
66+
67+
assert response.status_code == 400
68+
assert response.json().get("message") == "missing session_id query parameter"
69+
70+
@patch("bedrock_agentcore_starter_toolkit.operations.identity.oauth2_callback_server.IdentityClient")
71+
def test_handle_3lo_callback_success(self, mock_identity_client, tmp_path):
72+
config_path = create_test_config(tmp_path)
73+
server = BedrockAgentCoreIdentity3loCallback(config_path=config_path, agent_name="test-agent")
74+
75+
mock_client_instance = Mock()
76+
mock_identity_client.return_value = mock_client_instance
77+
78+
client = TestClient(server)
79+
response = client.get(f"{OAUTH2_CALLBACK_ENDPOINT}?session_id=test-session-123")
80+
81+
assert response.status_code == 200
82+
assert response.json().get("message") == "OAuth2 3LO flow completed successfully"
83+
mock_identity_client.assert_called_once_with("us-west-2")
84+
mock_client_instance.complete_resource_token_auth.assert_called_once_with(
85+
session_uri="test-session-123", user_identifier=UserIdIdentifier(user_id="test-user-id")
86+
)
87+
88+
def test_handle_3lo_callback_missing_user_id(self, tmp_path):
89+
config_path = create_test_config(tmp_path, user_id="")
90+
server = BedrockAgentCoreIdentity3loCallback(config_path=config_path, agent_name="test-agent")
91+
client = TestClient(server)
92+
response = client.get(f"{OAUTH2_CALLBACK_ENDPOINT}?session_id=test-session-123")
93+
94+
assert response.status_code == 500
95+
assert response.json().get("message") == "Internal Server Error"
96+
97+
def test_handle_3lo_callback_missing_region(self, tmp_path):
98+
config_path = create_test_config(tmp_path, region="")
99+
server = BedrockAgentCoreIdentity3loCallback(config_path=config_path, agent_name="test-agent")
100+
client = TestClient(server)
101+
response = client.get(f"{OAUTH2_CALLBACK_ENDPOINT}?session_id=test-session-123")
102+
103+
assert response.status_code == 500
104+
assert response.json().get("message") == "Internal Server Error"

0 commit comments

Comments
 (0)