|
| 1 | +from unittest.mock import AsyncMock, Mock |
| 2 | + |
| 3 | +import pytest |
| 4 | +from starlette.requests import Request |
| 5 | +from starlette.types import Scope |
| 6 | + |
| 7 | +import mcp.types as types |
| 8 | +from mcp.server.auth.middleware.auth_context import auth_context_var, get_access_token |
| 9 | +from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser |
| 10 | +from mcp.server.auth.provider import AccessToken |
| 11 | +from mcp.server.lowlevel.server import Server |
| 12 | +from mcp.server.session import ServerSession |
| 13 | +from mcp.shared.message import ServerMessageMetadata |
| 14 | +from mcp.shared.session import RequestResponder |
| 15 | + |
| 16 | + |
| 17 | +@pytest.mark.anyio |
| 18 | +async def test_handle_request_sets_auth_context_from_request() -> None: |
| 19 | + server = Server("test-server") |
| 20 | + |
| 21 | + @server.list_tools() |
| 22 | + async def handle_list_tools() -> list[types.Tool]: |
| 23 | + return [ |
| 24 | + types.Tool( |
| 25 | + name="echo_access_token", |
| 26 | + description="Return access token", |
| 27 | + inputSchema={"type": "object", "properties": {}}, |
| 28 | + ) |
| 29 | + ] |
| 30 | + |
| 31 | + @server.call_tool() |
| 32 | + async def handle_call_tool(name: str, arguments: dict[str, object]) -> list[types.TextContent]: |
| 33 | + assert name == "echo_access_token" |
| 34 | + access_token = get_access_token() |
| 35 | + token = access_token.token if access_token else "" |
| 36 | + return [types.TextContent(type="text", text=token)] |
| 37 | + |
| 38 | + access_token = AccessToken(token="test-token", client_id="client", scopes=["test"]) |
| 39 | + user = AuthenticatedUser(access_token) |
| 40 | + headers: list[tuple[bytes, bytes]] = [] |
| 41 | + scope: Scope = { |
| 42 | + "type": "http", |
| 43 | + "method": "POST", |
| 44 | + "path": "/mcp", |
| 45 | + "headers": headers, |
| 46 | + "user": user, |
| 47 | + } |
| 48 | + request = Request(scope) |
| 49 | + |
| 50 | + message = Mock(spec=RequestResponder) |
| 51 | + message.request_id = "req-1" |
| 52 | + message.request_meta = None |
| 53 | + message.message_metadata = ServerMessageMetadata(request_context=request) |
| 54 | + message.respond = AsyncMock() |
| 55 | + |
| 56 | + session = Mock(spec=ServerSession) |
| 57 | + session.client_params = None |
| 58 | + |
| 59 | + call_request = types.CallToolRequest(params=types.CallToolRequestParams(name="echo_access_token", arguments={})) |
| 60 | + |
| 61 | + await server._handle_request(message, call_request, session, {}, raise_exceptions=False) |
| 62 | + |
| 63 | + assert auth_context_var.get() is None |
| 64 | + assert message.respond.called |
| 65 | + response = message.respond.call_args.args[0] |
| 66 | + assert isinstance(response.root, types.CallToolResult) |
| 67 | + content = response.root.content[0] |
| 68 | + assert isinstance(content, types.TextContent) |
| 69 | + assert content.text == "test-token" |
0 commit comments