1+ # pyright: reportMissingImports=false
2+ # pytest test suite for proxy_auth/combo_server.py
3+ # These tests spin up the FastMCP Starlette application in-process and
4+ # exercise the custom HTTP routes as well as the `user_info` tool.
5+
6+ from __future__ import annotations
7+
8+ import base64
9+ import json
10+ import urllib .parse
11+ from collections .abc import AsyncGenerator
12+ from typing import Any
13+
14+ import httpx # type: ignore
15+ import pytest # type: ignore
16+
17+
18+ @pytest .fixture
19+ def proxy_server (monkeypatch ):
20+ """Import the proxy OAuth demo server with safe environment + stubs."""
21+ import os
22+
23+ # Avoid real outbound calls by pretending the upstream endpoints were
24+ # supplied explicitly via env vars – this makes `fetch_upstream_metadata`
25+ # construct metadata locally instead of performing an HTTP GET.
26+ os .environ .setdefault ("UPSTREAM_AUTHORIZATION_ENDPOINT" , "https://upstream.example.com/authorize" )
27+ os .environ .setdefault ("UPSTREAM_TOKEN_ENDPOINT" , "https://upstream.example.com/token" )
28+ os .environ .setdefault ("UPSTREAM_JWKS_URI" , "https://upstream.example.com/jwks" )
29+ os .environ .setdefault ("UPSTREAM_CLIENT_ID" , "client123" )
30+ os .environ .setdefault ("UPSTREAM_CLIENT_SECRET" , "secret123" )
31+
32+ # Deferred import so the env vars above are in effect.
33+ from proxy_auth import combo_server as proxy_server_module
34+
35+ # Stub library-level fetch_upstream_metadata to avoid network I/O.
36+ from mcp .server .auth .proxy import routes as proxy_routes
37+
38+ async def _fake_metadata () -> dict [str , Any ]: # noqa: D401
39+ return {
40+ "issuer" : proxy_server_module .UPSTREAM_BASE ,
41+ "authorization_endpoint" : proxy_server_module .UPSTREAM_AUTHORIZE ,
42+ "token_endpoint" : proxy_server_module .UPSTREAM_TOKEN ,
43+ "registration_endpoint" : "/register" ,
44+ "jwks_uri" : "" ,
45+ }
46+
47+ monkeypatch .setattr (proxy_routes , "fetch_upstream_metadata" , _fake_metadata , raising = True )
48+ return proxy_server_module
49+
50+
51+ @pytest .fixture
52+ def app (proxy_server ):
53+ """Return the Starlette ASGI app for tests."""
54+ return proxy_server .mcp .streamable_http_app ()
55+
56+
57+ @pytest .fixture
58+ async def client (app ) -> AsyncGenerator [httpx .AsyncClient , None ]:
59+ """Async HTTP client bound to the in-memory ASGI application."""
60+ async with httpx .AsyncClient (transport = httpx .ASGITransport (app = app ), base_url = "http://testserver" ) as c :
61+ yield c
62+
63+
64+ # ---------------------------------------------------------------------------
65+ # HTTP endpoint tests
66+ # ---------------------------------------------------------------------------
67+
68+
69+ @pytest .mark .anyio
70+ async def test_metadata_endpoint (client ):
71+ r = await client .get ("/.well-known/oauth-authorization-server" )
72+ assert r .status_code == 200
73+ data = r .json ()
74+ assert "issuer" in data
75+ assert data ["authorization_endpoint" ].endswith ("/authorize" )
76+ assert data ["token_endpoint" ].endswith ("/token" )
77+ assert data ["registration_endpoint" ].endswith ("/register" )
78+
79+
80+ @pytest .mark .anyio
81+ async def test_registration_endpoint (client , proxy_server ):
82+ payload = {"redirect_uris" : ["https://client.example.com/callback" ]}
83+ r = await client .post ("/register" , json = payload )
84+ assert r .status_code == 201
85+ body = r .json ()
86+ assert body ["client_id" ] == proxy_server .CLIENT_ID
87+ assert body ["redirect_uris" ] == payload ["redirect_uris" ]
88+ # client_secret may be None, but the field should exist (masked or real)
89+ assert "client_secret" in body
90+
91+
92+ @pytest .mark .anyio
93+ async def test_authorize_redirect (client , proxy_server ):
94+ params = {
95+ "response_type" : "code" ,
96+ "state" : "xyz" ,
97+ "redirect_uri" : "https://client.example.com/callback" ,
98+ "client_id" : proxy_server .CLIENT_ID ,
99+ "code_challenge" : "testchallenge" ,
100+ "code_challenge_method" : "S256" ,
101+ }
102+ r = await client .get ("/authorize" , params = params , follow_redirects = False )
103+ assert r .status_code in {302 , 307 }
104+
105+ location = r .headers ["location" ]
106+ parsed = urllib .parse .urlparse (location )
107+ assert parsed .scheme .startswith ("http" )
108+ assert parsed .netloc == urllib .parse .urlparse (proxy_server .UPSTREAM_AUTHORIZE ).netloc
109+
110+ qs = urllib .parse .parse_qs (parsed .query )
111+ # Proxy should inject client_id & default scope
112+ assert qs ["client_id" ][0 ] == proxy_server .CLIENT_ID
113+ assert "scope" in qs
114+ # Original params preserved
115+ assert qs ["state" ][0 ] == "xyz"
116+
117+
118+ @pytest .mark .anyio
119+ async def test_revoke_proxy (client , monkeypatch , proxy_server ):
120+ original_post = httpx .AsyncClient .post
121+
122+ async def _mock_post (self , url , data = None , timeout = 10 , ** kwargs ): # noqa: D401
123+ if url .endswith ("/revoke" ):
124+ return httpx .Response (200 , json = {"revoked" : True })
125+ # For the test client's own request to /revoke, delegate to original implementation
126+ return await original_post (self , url , data = data , timeout = timeout , ** kwargs )
127+
128+ monkeypatch .setattr (httpx .AsyncClient , "post" , _mock_post , raising = True )
129+
130+ r = await client .post ("/revoke" , data = {"token" : "dummy" })
131+ assert r .status_code == 200
132+ assert r .json () == {"revoked" : True }
133+
134+
135+ @pytest .mark .anyio
136+ async def test_token_passthrough (client , monkeypatch , proxy_server ):
137+ """Ensure /token is proxied unchanged and response is returned verbatim."""
138+
139+ # Capture outgoing POSTs made by ProxyTokenHandler
140+ captured : dict [str , Any ] = {}
141+
142+ original_post = httpx .AsyncClient .post
143+
144+ async def _mock_post (self , url , * args , ** kwargs ): # noqa: D401
145+ if str (url ).startswith (proxy_server .UPSTREAM_TOKEN ):
146+ # Record exactly what was sent upstream
147+ captured ["url" ] = str (url )
148+ captured ["data" ] = kwargs .get ("data" )
149+ # Return a dummy upstream response
150+ return httpx .Response (
151+ 200 ,
152+ json = {
153+ "access_token" : "xyz" ,
154+ "token_type" : "bearer" ,
155+ "expires_in" : 3600 ,
156+ },
157+ )
158+ # Delegate any other POSTs to the real implementation
159+ return await original_post (self , url , * args , ** kwargs )
160+
161+ monkeypatch .setattr (httpx .AsyncClient , "post" , _mock_post , raising = True )
162+
163+ # ---------------- Act ----------------
164+ form = {
165+ "grant_type" : "authorization_code" ,
166+ "code" : "dummy-code" ,
167+ "client_id" : proxy_server .CLIENT_ID ,
168+ }
169+ r = await client .post ("/token" , data = form )
170+
171+ # ---------------- Assert -------------
172+ assert r .status_code == 200
173+ assert r .json ()["access_token" ] == "xyz"
174+
175+ # Verify the request payload was forwarded without modification
176+ assert captured ["data" ] == form
177+
178+
179+ # ---------------------------------------------------------------------------
180+ # Tool invocation – user_info
181+ # ---------------------------------------------------------------------------
182+
183+
184+ @pytest .mark .anyio
185+ async def test_user_info_tool (monkeypatch , proxy_server ):
186+ """Call the `user_info` tool directly with a mocked access token."""
187+ # Craft a dummy JWT with useful claims (header/payload/signature parts)
188+ payload = (
189+ base64 .urlsafe_b64encode (
190+ json .dumps (
191+ {
192+ "sub" : "test-user" ,
193+ "preferred_username" : "tester" ,
194+ }
195+ ).encode ()
196+ )
197+ .decode ()
198+ .rstrip ("=" )
199+ )
200+ dummy_token = f"header.{ payload } .signature"
201+
202+ from mcp .server .auth .middleware import auth_context
203+ from mcp .server .auth .provider import AccessToken # local import to avoid cycles
204+
205+ def _fake_get_access_token (): # noqa: D401
206+ return AccessToken (token = dummy_token , client_id = "client123" , scopes = ["openid" ], expires_at = None )
207+
208+ monkeypatch .setattr (auth_context , "get_access_token" , _fake_get_access_token , raising = True )
209+
210+ result = await proxy_server .mcp .call_tool ("user_info" , {})
211+
212+ # call_tool returns (content_blocks, raw_result)
213+ if isinstance (result , tuple ):
214+ _ , raw = result
215+ else :
216+ raw = result # fallback
217+
218+ assert raw ["authenticated" ] is True
219+ assert ("userid" in raw and raw ["userid" ] == "test-user" ) or ("user_id" in raw and raw ["user_id" ] == "test-user" )
220+ assert raw ["username" ] == "tester"
0 commit comments