1+ # pyright: reportMissingImports=false
2+ import os
3+ from urllib .parse import urljoin
4+ from dotenv import load_dotenv # type: ignore
5+ from typing import Any , cast
6+ import base64 , json , time
7+ from starlette .requests import Request # type: ignore
8+
9+ from mcp .server .fastmcp .server import Context
10+ from mcp .server .auth .proxy .server import build_proxy_server # noqa: E402
11+
12+ # Load environment variables from .env if present
13+ load_dotenv ()
14+
15+ # ----------------------------------------------------------------------------
16+ # Environment configuration
17+ # ----------------------------------------------------------------------------
18+ _upstream_base = os .getenv ("PROXY_UPSTREAM_BASE" , "https://auth.example.com/" )
19+ # Sanitize trailing slash
20+ if _upstream_base .endswith ("/" ):
21+ _upstream_base = _upstream_base [:- 1 ]
22+ UPSTREAM_BASE : str = _upstream_base
23+
24+ print ("[Proxy Config] UPSTREAM_BASE:" , UPSTREAM_BASE )
25+ print ("[Proxy Config] CLIENT_ID:" , os .getenv ("PROXY_CLIENT_ID" ) or os .getenv ("UPSTREAM_CLIENT_ID" ))
26+
27+ CLIENT_ID = os .getenv ("PROXY_CLIENT_ID" ) or os .getenv ("UPSTREAM_CLIENT_ID" ) or "demo-client-id"
28+ CLIENT_SECRET = os .getenv ("PROXY_CLIENT_SECRET" ) or os .getenv ("UPSTREAM_CLIENT_SECRET" ) # may be None
29+ DEFAULT_SCOPE = os .getenv ("PROXY_DEFAULT_SCOPE" , "openid profile email" )
30+ AUDIENCE = os .getenv ("PROXY_AUDIENCE" ) # optional
31+
32+ # ---------------------------------------------------------------------------
33+ # Resolve upstream endpoints – prefer explicit *_ENDPOINT variables (matches
34+ # naming used in fastmcp example) and fall back to BASE + path.
35+ # ---------------------------------------------------------------------------
36+
37+ UPSTREAM_AUTHORIZE = os .getenv ("UPSTREAM_AUTHORIZATION_ENDPOINT" ) or f"{ UPSTREAM_BASE } /authorize"
38+ UPSTREAM_TOKEN = os .getenv ("UPSTREAM_TOKEN_ENDPOINT" ) or f"{ UPSTREAM_BASE } /token"
39+ UPSTREAM_JWKS_URI = os .getenv ("UPSTREAM_JWKS_URI" )
40+ UPSTREAM_REVOCATION = os .getenv ("UPSTREAM_REVOCATION_ENDPOINT" ) or f"{ UPSTREAM_BASE } /revoke"
41+
42+ # Metadata URL (only used if we need to fetch from upstream)
43+ UPSTREAM_METADATA = f"{ UPSTREAM_BASE } /.well-known/oauth-authorization-server"
44+
45+ print ("[Proxy Config] UPSTREAM_AUTHORIZE:" , UPSTREAM_AUTHORIZE )
46+ print ("[Proxy Config] UPSTREAM_TOKEN:" , UPSTREAM_TOKEN )
47+
48+ # Server host/port
49+ PROXY_PORT = int (os .getenv ("PROXY_PORT" , "8000" ))
50+
51+ # ----------------------------------------------------------------------------
52+ # FastMCP server (now created via library helper)
53+ # ----------------------------------------------------------------------------
54+
55+ ISSUER_URL = os .getenv ("PROXY_ISSUER_URL" , "http://localhost:8000" )
56+
57+ # Create FastMCP instance using the reusable proxy builder
58+ mcp = build_proxy_server (port = PROXY_PORT , issuer_url = ISSUER_URL )
59+
60+ # ---------------------------------------------------------------------------
61+ # Minimal demo tool
62+ # ---------------------------------------------------------------------------
63+
64+ @mcp .tool ()
65+ def echo (message : str ) -> str :
66+ return f"Echo: { message } "
67+
68+
69+ @mcp .tool ()
70+ async def user_info (ctx : Context [Any , Any , Request ]) -> dict [str , Any ]:
71+ """
72+ Get information about the authenticated user.
73+
74+ This tool demonstrates accessing user information from the OAuth access token.
75+ The user must be authenticated via OAuth to access this tool.
76+
77+ Returns:
78+ Dictionary containing user information from the access token
79+ """
80+ from mcp .server .auth .middleware .auth_context import get_access_token
81+
82+ # Get the access token from the authentication context
83+ access_token = get_access_token ()
84+
85+ if not access_token :
86+ return {
87+ "error" : "No access token found - user not authenticated" ,
88+ "authenticated" : False
89+ }
90+
91+ # Attempt to decode the access token as JWT to extract useful user claims.
92+ # Many OAuth providers issue JWT access tokens (or ID tokens) that contain
93+ # the user's subject (sub) and preferred username. We parse the token
94+ # *without* signature verification – we only need the public claims for
95+ # display purposes. If the token is opaque or the decode fails, we simply
96+ # skip this step.
97+
98+ def _try_decode_jwt (token_str : str ) -> dict [str , Any ] | None : # noqa: D401
99+ """Best-effort JWT decode without verification.
100+
101+ Returns the payload dictionary if the token *looks* like a JWT and can
102+ be base64-decoded. If anything fails we return None.
103+ """
104+
105+ try :
106+ parts = token_str .split ("." )
107+ if len (parts ) != 3 :
108+ return None # Not a JWT
109+
110+ # JWT parts are URL-safe base64 without padding
111+ def _b64decode (segment : str ) -> bytes :
112+ padding = "=" * (- len (segment ) % 4 )
113+ return base64 .urlsafe_b64decode (segment + padding )
114+
115+ payload_bytes = _b64decode (parts [1 ])
116+ return json .loads (payload_bytes )
117+ except Exception : # noqa: BLE001
118+ return None
119+
120+ jwt_claims = _try_decode_jwt (access_token .token )
121+
122+ # Build response with token information plus any extracted claims
123+ response : dict [str , Any ] = {
124+ "authenticated" : True ,
125+ "client_id" : access_token .client_id ,
126+ "scopes" : access_token .scopes ,
127+ "token_type" : "Bearer" ,
128+ "expires_at" : access_token .expires_at ,
129+ "resource" : access_token .resource ,
130+ }
131+
132+ if jwt_claims :
133+ # Prefer the `userid` claim used in FastMCP examples; fall back to `sub` if absent.
134+ uid = jwt_claims .get ("userid" ) or jwt_claims .get ("sub" )
135+ if uid is not None :
136+ response ["userid" ] = uid # camelCase variant used in FastMCP reference
137+ response ["user_id" ] = uid # snake_case variant
138+ response ["username" ] = (
139+ jwt_claims .get ("preferred_username" )
140+ or jwt_claims .get ("nickname" )
141+ or jwt_claims .get ("name" )
142+ )
143+ response ["issuer" ] = jwt_claims .get ("iss" )
144+ response ["audience" ] = jwt_claims .get ("aud" )
145+ response ["issued_at" ] = jwt_claims .get ("iat" )
146+
147+ # Calculate expiration helpers
148+ if access_token .expires_at :
149+ response ["expires_at_iso" ] = time .strftime ('%Y-%m-%dT%H:%M:%S' , time .localtime (access_token .expires_at ))
150+ response ["expires_in_seconds" ] = max (0 , access_token .expires_at - int (time .time ()))
151+
152+ return response
153+
154+
155+ @mcp .tool ()
156+ async def test_endpoint (message : str = "Hello from proxy server!" ) -> dict [str , Any ]:
157+ """
158+ Test endpoint for debugging OAuth proxy functionality.
159+
160+ Args:
161+ message: Optional message to echo back
162+
163+ Returns:
164+ Test response with server information
165+ """
166+ return {
167+ "message" : message ,
168+ "server" : "Transparent OAuth Proxy Server" ,
169+ "status" : "active" ,
170+ "oauth_configured" : True
171+ }
172+
173+
174+ if __name__ == "__main__" :
175+ mcp .run (transport = "streamable-http" )
0 commit comments