Skip to content

Commit e072708

Browse files
committed
split resource server and auth server
Signed-off-by: Jesse Sanford <108698+jessesanford@users.noreply.github.com>
1 parent cb646af commit e072708

File tree

4 files changed

+382
-19
lines changed

4 files changed

+382
-19
lines changed

examples/servers/proxy-auth/tests/test_proxy_oauth_endpoints.py

Lines changed: 42 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,14 @@
1414
import httpx # type: ignore
1515
import pytest # type: ignore
1616

17+
# Import constants at the module level
18+
from proxy_auth.combo_server import (
19+
CLIENT_ID,
20+
UPSTREAM_AUTHORIZE,
21+
UPSTREAM_BASE,
22+
UPSTREAM_TOKEN,
23+
)
24+
1725

1826
@pytest.fixture
1927
def proxy_server(monkeypatch):
@@ -35,23 +43,39 @@ def proxy_server(monkeypatch):
3543
# Stub library-level fetch_upstream_metadata to avoid network I/O.
3644
from mcp.server.auth.proxy import routes as proxy_routes
3745

46+
<<<<<<< HEAD
47+
=======
48+
# Import the module and the combo_server instance
49+
from proxy_auth import combo_server
50+
51+
>>>>>>> fbb3cb4 (fix imports)
3852
async def _fake_metadata() -> dict[str, Any]: # noqa: D401
53+
# Access module-level constants directly
3954
return {
40-
"issuer": proxy_server_module.UPSTREAM_BASE,
41-
"authorization_endpoint": proxy_server_module.UPSTREAM_AUTHORIZE,
42-
"token_endpoint": proxy_server_module.UPSTREAM_TOKEN,
55+
"issuer": UPSTREAM_BASE,
56+
"authorization_endpoint": UPSTREAM_AUTHORIZE,
57+
"token_endpoint": UPSTREAM_TOKEN,
4358
"registration_endpoint": "/register",
4459
"jwks_uri": "",
4560
}
4661

62+
<<<<<<< HEAD
4763
monkeypatch.setattr(proxy_routes, "fetch_upstream_metadata", _fake_metadata, raising=True)
4864
return proxy_server_module
65+
=======
66+
monkeypatch.setattr(
67+
proxy_routes, "fetch_upstream_metadata", _fake_metadata, raising=True
68+
)
69+
70+
# Return the combo_server instance
71+
return combo_server
72+
>>>>>>> fbb3cb4 (fix imports)
4973

5074

5175
@pytest.fixture
5276
def app(proxy_server):
5377
"""Return the Starlette ASGI app for tests."""
54-
return proxy_server.mcp.streamable_http_app()
78+
return proxy_server.streamable_http_app()
5579

5680

5781
@pytest.fixture
@@ -78,24 +102,24 @@ async def test_metadata_endpoint(client):
78102

79103

80104
@pytest.mark.anyio
81-
async def test_registration_endpoint(client, proxy_server):
105+
async def test_registration_endpoint(client):
82106
payload = {"redirect_uris": ["https://client.example.com/callback"]}
83107
r = await client.post("/register", json=payload)
84108
assert r.status_code == 201
85109
body = r.json()
86-
assert body["client_id"] == proxy_server.CLIENT_ID
110+
assert body["client_id"] == CLIENT_ID
87111
assert body["redirect_uris"] == payload["redirect_uris"]
88112
# client_secret may be None, but the field should exist (masked or real)
89113
assert "client_secret" in body
90114

91115

92116
@pytest.mark.anyio
93-
async def test_authorize_redirect(client, proxy_server):
117+
async def test_authorize_redirect(client):
94118
params = {
95119
"response_type": "code",
96120
"state": "xyz",
97121
"redirect_uri": "https://client.example.com/callback",
98-
"client_id": proxy_server.CLIENT_ID,
122+
"client_id": CLIENT_ID,
99123
"code_challenge": "testchallenge",
100124
"code_challenge_method": "S256",
101125
}
@@ -105,18 +129,22 @@ async def test_authorize_redirect(client, proxy_server):
105129
location = r.headers["location"]
106130
parsed = urllib.parse.urlparse(location)
107131
assert parsed.scheme.startswith("http")
132+
<<<<<<< HEAD
108133
assert parsed.netloc == urllib.parse.urlparse(proxy_server.UPSTREAM_AUTHORIZE).netloc
134+
=======
135+
assert parsed.netloc == urllib.parse.urlparse(UPSTREAM_AUTHORIZE).netloc
136+
>>>>>>> fbb3cb4 (fix imports)
109137

110138
qs = urllib.parse.parse_qs(parsed.query)
111139
# Proxy should inject client_id & default scope
112-
assert qs["client_id"][0] == proxy_server.CLIENT_ID
140+
assert qs["client_id"][0] == CLIENT_ID
113141
assert "scope" in qs
114142
# Original params preserved
115143
assert qs["state"][0] == "xyz"
116144

117145

118146
@pytest.mark.anyio
119-
async def test_revoke_proxy(client, monkeypatch, proxy_server):
147+
async def test_revoke_proxy(client, monkeypatch):
120148
original_post = httpx.AsyncClient.post
121149

122150
async def _mock_post(self, url, data=None, timeout=10, **kwargs): # noqa: D401
@@ -133,7 +161,7 @@ async def _mock_post(self, url, data=None, timeout=10, **kwargs): # noqa: D401
133161

134162

135163
@pytest.mark.anyio
136-
async def test_token_passthrough(client, monkeypatch, proxy_server):
164+
async def test_token_passthrough(client, monkeypatch):
137165
"""Ensure /token is proxied unchanged and response is returned verbatim."""
138166

139167
# Capture outgoing POSTs made by ProxyTokenHandler
@@ -142,7 +170,7 @@ async def test_token_passthrough(client, monkeypatch, proxy_server):
142170
original_post = httpx.AsyncClient.post
143171

144172
async def _mock_post(self, url, *args, **kwargs): # noqa: D401
145-
if str(url).startswith(proxy_server.UPSTREAM_TOKEN):
173+
if str(url).startswith(UPSTREAM_TOKEN):
146174
# Record exactly what was sent upstream
147175
captured["url"] = str(url)
148176
captured["data"] = kwargs.get("data")
@@ -164,7 +192,7 @@ async def _mock_post(self, url, *args, **kwargs): # noqa: D401
164192
form = {
165193
"grant_type": "authorization_code",
166194
"code": "dummy-code",
167-
"client_id": proxy_server.CLIENT_ID,
195+
"client_id": CLIENT_ID,
168196
}
169197
r = await client.post("/token", data=form)
170198

@@ -207,7 +235,7 @@ def _fake_get_access_token(): # noqa: D401
207235

208236
monkeypatch.setattr(auth_context, "get_access_token", _fake_get_access_token, raising=True)
209237

210-
result = await proxy_server.mcp.call_tool("user_info", {})
238+
result = await proxy_server.call_tool("user_info", {})
211239

212240
# call_tool returns (content_blocks, raw_result)
213241
if isinstance(result, tuple):
Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
# pyright: reportMissingImports=false
2+
import logging
3+
import os
4+
import time
5+
6+
from dotenv import load_dotenv # type: ignore
7+
from mcp.server.auth.provider import AccessToken, OAuthToken
8+
from mcp.server.auth.providers.transparent_proxy import (
9+
ProxySettings, # type: ignore
10+
TransparentOAuthProxyProvider,
11+
ProxyTokenHandler,
12+
)
13+
from mcp.server.auth.routes import cors_middleware, create_auth_routes
14+
from mcp.server.auth.settings import ClientRegistrationOptions
15+
from pydantic import AnyHttpUrl
16+
from starlette.applications import Starlette
17+
from starlette.requests import Request # type: ignore
18+
from starlette.responses import JSONResponse, Response
19+
from starlette.routing import Route
20+
from uvicorn import Config, Server
21+
22+
# Load environment variables from .env if present
23+
load_dotenv()
24+
25+
# Configure logging after .env so LOG_LEVEL can come from environment
26+
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
27+
28+
logging.basicConfig(
29+
level=LOG_LEVEL,
30+
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
31+
datefmt="%Y-%m-%d %H:%M:%S",
32+
)
33+
34+
# Dedicated logger for this server module
35+
logger = logging.getLogger("proxy_oauth.auth_server")
36+
37+
# Suppress noisy INFO messages from the FastMCP low-level server unless we are
38+
# explicitly running in DEBUG mode. These logs (e.g. "Processing request of type
39+
# ListToolsRequest") are helpful for debugging but clutter normal output.
40+
41+
_mcp_lowlevel_logger = logging.getLogger("mcp.server.lowlevel.server")
42+
if LOG_LEVEL == "DEBUG":
43+
# In full debug mode, allow the library to emit its detailed logs
44+
_mcp_lowlevel_logger.setLevel(logging.DEBUG)
45+
else:
46+
# Otherwise, only warnings and above
47+
_mcp_lowlevel_logger.setLevel(logging.WARNING)
48+
49+
# ----------------------------------------------------------------------------
50+
# Environment configuration
51+
# ----------------------------------------------------------------------------
52+
# Load and validate settings from the environment (uses .env automatically)
53+
settings = ProxySettings.load()
54+
55+
# Upstream endpoints (fully-qualified URLs)
56+
UPSTREAM_AUTHORIZE: str = str(settings.upstream_authorize)
57+
UPSTREAM_TOKEN: str = str(settings.upstream_token)
58+
UPSTREAM_JWKS_URI = settings.jwks_uri
59+
# Derive base URL from the authorize endpoint for convenience / tests
60+
UPSTREAM_BASE: str = UPSTREAM_AUTHORIZE.rsplit("/", 1)[0]
61+
62+
# Client credentials & defaults
63+
CLIENT_ID: str = settings.client_id or "demo-client-id"
64+
CLIENT_SECRET = settings.client_secret
65+
DEFAULT_SCOPE: str = settings.default_scope
66+
67+
# Optional audience passthrough (not part of ProxySettings yet)
68+
AUDIENCE = os.getenv("PROXY_AUDIENCE")
69+
70+
# Metadata URL (only used if we need to fetch from upstream)
71+
UPSTREAM_METADATA = f"{UPSTREAM_BASE}/.well-known/oauth-authorization-server"
72+
73+
# ---------------------------------------------------------------------------
74+
# Logging helpers
75+
# ---------------------------------------------------------------------------
76+
77+
78+
def _mask_secret(secret: str | None) -> str | None: # noqa: D401
79+
"""Return a masked version of the given secret.
80+
81+
The first and last four characters are preserved (if available) and the
82+
middle section is replaced by asterisks. If the secret is shorter than
83+
eight characters, the entire value is replaced by ``*``.
84+
"""
85+
86+
if not secret:
87+
return None
88+
89+
if len(secret) <= 8:
90+
return "*" * len(secret)
91+
92+
return f"{secret[:4]}{'*' * (len(secret) - 8)}{secret[-4:]}"
93+
94+
95+
# Consolidated configuration (with sensitive data redacted)
96+
_masked_settings = settings.model_dump(exclude_none=True).copy()
97+
98+
if "client_secret" in _masked_settings:
99+
_masked_settings["client_secret"] = _mask_secret(_masked_settings["client_secret"])
100+
101+
# Log configuration at *debug* level only so it can be enabled when needed
102+
logger.debug("[Auth Proxy Config] %s", _masked_settings)
103+
104+
# Server host/port
105+
AUTH_SERVER_PORT = int(os.getenv("AUTH_SERVER_PORT", "9000"))
106+
AUTH_SERVER_HOST = os.getenv("AUTH_SERVER_HOST", "localhost")
107+
AUTH_SERVER_URL = os.getenv(
108+
"AUTH_SERVER_URL", f"http://{AUTH_SERVER_HOST}:{AUTH_SERVER_PORT}"
109+
)
110+
111+
# ----------------------------------------------------------------------------
112+
# Auth Server
113+
# ----------------------------------------------------------------------------
114+
115+
# Create auth provider
116+
oauth_provider = TransparentOAuthProxyProvider(settings=settings)
117+
118+
# Enable client registration
119+
client_registration_options = ClientRegistrationOptions(
120+
enabled=True,
121+
valid_scopes=["openid"],
122+
default_scopes=["openid"],
123+
)
124+
125+
# Create auth routes
126+
routes = create_auth_routes(
127+
provider=oauth_provider,
128+
issuer_url=AnyHttpUrl(AUTH_SERVER_URL),
129+
service_documentation_url=None,
130+
client_registration_options=client_registration_options,
131+
revocation_options=None,
132+
)
133+
134+
# Add token endpoint handler
135+
# We need to replace any existing token endpoint route
136+
routes = [r for r in routes if not (hasattr(r, "path") and r.path == "/token")]
137+
138+
# Create token handler and add it to routes
139+
proxy_token_handler = ProxyTokenHandler(oauth_provider)
140+
routes.append(Route("/token", endpoint=proxy_token_handler.handle, methods=["POST"]))
141+
142+
# Add token introspection endpoint for Resource Servers
143+
async def introspect_handler(request: Request) -> Response:
144+
"""
145+
Token introspection endpoint for Resource Servers.
146+
147+
Resource Servers call this endpoint to validate tokens without
148+
needing direct access to token storage.
149+
"""
150+
form = await request.form()
151+
token = form.get("token")
152+
if not token or not isinstance(token, str):
153+
return JSONResponse({"active": False}, status_code=400)
154+
155+
# For the transparent proxy, we don't actually validate tokens
156+
# Just create a dummy AccessToken like the provider does
157+
access_token = AccessToken(
158+
token=token, client_id=str(CLIENT_ID), scopes=[DEFAULT_SCOPE], expires_at=None
159+
)
160+
161+
return JSONResponse(
162+
{
163+
"active": True,
164+
"client_id": access_token.client_id,
165+
"scope": " ".join(access_token.scopes),
166+
"exp": access_token.expires_at,
167+
"iat": int(time.time()),
168+
"token_type": "Bearer",
169+
"aud": access_token.resource, # RFC 8707 audience claim
170+
}
171+
)
172+
173+
174+
routes.append(
175+
Route(
176+
"/introspect",
177+
endpoint=cors_middleware(introspect_handler, ["POST", "OPTIONS"]),
178+
methods=["POST", "OPTIONS"],
179+
)
180+
)
181+
182+
# Create Starlette app with routes
183+
auth_app = Starlette(routes=routes)
184+
185+
186+
async def run_server():
187+
"""Run the Authorization Server."""
188+
config = Config(
189+
auth_app,
190+
host=AUTH_SERVER_HOST,
191+
port=AUTH_SERVER_PORT,
192+
log_level="info",
193+
)
194+
server = Server(config)
195+
196+
logger.info(f"🚀 MCP Authorization Server running on {AUTH_SERVER_URL}")
197+
198+
await server.serve()
199+
200+
201+
if __name__ == "__main__":
202+
import asyncio
203+
204+
asyncio.run(run_server())

0 commit comments

Comments
 (0)