1+ # pyright: reportMissingImports=false
2+ import os
3+ import logging
4+ from dotenv import load_dotenv # type: ignore
5+ from typing import Any , cast
6+ import base64 , json , time
7+ from starlette .requests import Request # type: ignore
8+
9+ from mcp .server .fastmcp .server import Context
10+ from mcp .server .auth .proxy .server import build_proxy_server # noqa: E402
11+ from mcp .server .auth .providers .transparent_proxy import ProxySettings # type: ignore
12+
13+ # Load environment variables from .env if present
14+ load_dotenv ()
15+
16+ # Configure logging after .env so LOG_LEVEL can come from environment
17+ LOG_LEVEL = os .getenv ("LOG_LEVEL" , "INFO" ).upper ()
18+
19+ logging .basicConfig (
20+ level = LOG_LEVEL ,
21+ format = "%(asctime)s [%(levelname)s] %(name)s: %(message)s" ,
22+ datefmt = "%Y-%m-%d %H:%M:%S" ,
23+ )
24+
25+ # Dedicated logger for this server module
26+ logger = logging .getLogger ("proxy_oauth.server" )
27+
28+ # Suppress noisy INFO messages from the FastMCP low-level server unless we are
29+ # explicitly running in DEBUG mode. These logs (e.g. "Processing request of type
30+ # ListToolsRequest") are helpful for debugging but clutter normal output.
31+
32+ _mcp_lowlevel_logger = logging .getLogger ("mcp.server.lowlevel.server" )
33+ if LOG_LEVEL == "DEBUG" :
34+ # In full debug mode, allow the library to emit its detailed logs
35+ _mcp_lowlevel_logger .setLevel (logging .DEBUG )
36+ else :
37+ # Otherwise, only warnings and above
38+ _mcp_lowlevel_logger .setLevel (logging .WARNING )
39+
40+ # ----------------------------------------------------------------------------
41+ # Environment configuration
42+ # ----------------------------------------------------------------------------
43+ # Load and validate settings from the environment (uses .env automatically)
44+ settings = ProxySettings .load ()
45+
46+ # Upstream endpoints (fully-qualified URLs)
47+ UPSTREAM_AUTHORIZE : str = str (settings .upstream_authorize )
48+ UPSTREAM_TOKEN : str = str (settings .upstream_token )
49+ UPSTREAM_JWKS_URI = settings .jwks_uri
50+ # Derive base URL from the authorize endpoint for convenience / tests
51+ UPSTREAM_BASE : str = UPSTREAM_AUTHORIZE .rsplit ("/" , 1 )[0 ]
52+
53+ # Client credentials & defaults
54+ CLIENT_ID : str = settings .client_id or "demo-client-id"
55+ CLIENT_SECRET = settings .client_secret
56+ DEFAULT_SCOPE : str = settings .default_scope
57+
58+ # Optional audience passthrough (not part of ProxySettings yet)
59+ AUDIENCE = os .getenv ("PROXY_AUDIENCE" )
60+
61+ # Metadata URL (only used if we need to fetch from upstream)
62+ UPSTREAM_METADATA = f"{ UPSTREAM_BASE } /.well-known/oauth-authorization-server"
63+
64+ # ---------------------------------------------------------------------------
65+ # Logging helpers
66+ # ---------------------------------------------------------------------------
67+
68+ def _mask_secret (secret : str | None ) -> str | None : # noqa: D401
69+ """Return a masked version of the given secret.
70+
71+ The first and last four characters are preserved (if available) and the
72+ middle section is replaced by asterisks. If the secret is shorter than
73+ eight characters, the entire value is replaced by ``*``.
74+ """
75+
76+ if not secret :
77+ return None
78+
79+ if len (secret ) <= 8 :
80+ return "*" * len (secret )
81+
82+ return f"{ secret [:4 ]} { '*' * (len (secret ) - 8 )} { secret [- 4 :]} "
83+
84+ # Consolidated configuration (with sensitive data redacted)
85+ _masked_settings = settings .model_dump (exclude_none = True ).copy ()
86+
87+ if "client_secret" in _masked_settings :
88+ _masked_settings ["client_secret" ] = _mask_secret (_masked_settings ["client_secret" ])
89+
90+ # Log configuration at *debug* level only so it can be enabled when needed
91+ logger .debug ("[Proxy Config] %s" , _masked_settings )
92+
93+ # Server host/port
94+ PROXY_PORT = int (os .getenv ("PROXY_PORT" , "8000" ))
95+
96+ # ----------------------------------------------------------------------------
97+ # FastMCP server (now created via library helper)
98+ # ----------------------------------------------------------------------------
99+
100+ ISSUER_URL = os .getenv ("PROXY_ISSUER_URL" , "http://localhost:8000" )
101+
102+ # Create FastMCP instance using the reusable proxy builder
103+ mcp = build_proxy_server (port = PROXY_PORT , issuer_url = ISSUER_URL )
104+
105+ # ---------------------------------------------------------------------------
106+ # Minimal demo tool
107+ # ---------------------------------------------------------------------------
108+
109+ @mcp .tool ()
110+ def echo (message : str ) -> str :
111+ return f"Echo: { message } "
112+
113+
114+ @mcp .tool ()
115+ async def user_info (ctx : Context [Any , Any , Request ]) -> dict [str , Any ]:
116+ """
117+ Get information about the authenticated user.
118+
119+ This tool demonstrates accessing user information from the OAuth access token.
120+ The user must be authenticated via OAuth to access this tool.
121+
122+ Returns:
123+ Dictionary containing user information from the access token
124+ """
125+ from mcp .server .auth .middleware .auth_context import get_access_token
126+
127+ # Get the access token from the authentication context
128+ access_token = get_access_token ()
129+
130+ if not access_token :
131+ return {
132+ "error" : "No access token found - user not authenticated" ,
133+ "authenticated" : False
134+ }
135+
136+ # Attempt to decode the access token as JWT to extract useful user claims.
137+ # Many OAuth providers issue JWT access tokens (or ID tokens) that contain
138+ # the user's subject (sub) and preferred username. We parse the token
139+ # *without* signature verification – we only need the public claims for
140+ # display purposes. If the token is opaque or the decode fails, we simply
141+ # skip this step.
142+
143+ def _try_decode_jwt (token_str : str ) -> dict [str , Any ] | None : # noqa: D401
144+ """Best-effort JWT decode without verification.
145+
146+ Returns the payload dictionary if the token *looks* like a JWT and can
147+ be base64-decoded. If anything fails we return None.
148+ """
149+
150+ try :
151+ parts = token_str .split ("." )
152+ if len (parts ) != 3 :
153+ return None # Not a JWT
154+
155+ # JWT parts are URL-safe base64 without padding
156+ def _b64decode (segment : str ) -> bytes :
157+ padding = "=" * (- len (segment ) % 4 )
158+ return base64 .urlsafe_b64decode (segment + padding )
159+
160+ payload_bytes = _b64decode (parts [1 ])
161+ return json .loads (payload_bytes )
162+ except Exception : # noqa: BLE001
163+ return None
164+
165+ jwt_claims = _try_decode_jwt (access_token .token )
166+
167+ # Build response with token information plus any extracted claims
168+ response : dict [str , Any ] = {
169+ "authenticated" : True ,
170+ "client_id" : access_token .client_id ,
171+ "scopes" : access_token .scopes ,
172+ "token_type" : "Bearer" ,
173+ "expires_at" : access_token .expires_at ,
174+ "resource" : access_token .resource ,
175+ }
176+
177+ if jwt_claims :
178+ # Prefer the `userid` claim used in FastMCP examples; fall back to `sub` if absent.
179+ uid = jwt_claims .get ("userid" ) or jwt_claims .get ("sub" )
180+ if uid is not None :
181+ response ["userid" ] = uid # camelCase variant used in FastMCP reference
182+ response ["user_id" ] = uid # snake_case variant
183+ response ["username" ] = (
184+ jwt_claims .get ("preferred_username" )
185+ or jwt_claims .get ("nickname" )
186+ or jwt_claims .get ("name" )
187+ )
188+ response ["issuer" ] = jwt_claims .get ("iss" )
189+ response ["audience" ] = jwt_claims .get ("aud" )
190+ response ["issued_at" ] = jwt_claims .get ("iat" )
191+
192+ # Calculate expiration helpers
193+ if access_token .expires_at :
194+ response ["expires_at_iso" ] = time .strftime ('%Y-%m-%dT%H:%M:%S' , time .localtime (access_token .expires_at ))
195+ response ["expires_in_seconds" ] = max (0 , access_token .expires_at - int (time .time ()))
196+
197+ return response
198+
199+
200+ @mcp .tool ()
201+ async def test_endpoint (message : str = "Hello from proxy server!" ) -> dict [str , Any ]:
202+ """
203+ Test endpoint for debugging OAuth proxy functionality.
204+
205+ Args:
206+ message: Optional message to echo back
207+
208+ Returns:
209+ Test response with server information
210+ """
211+ return {
212+ "message" : message ,
213+ "server" : "Transparent OAuth Proxy Server" ,
214+ "status" : "active" ,
215+ "oauth_configured" : True
216+ }
217+
218+
219+ if __name__ == "__main__" :
220+ mcp .run (transport = "streamable-http" )
0 commit comments