Skip to content

Commit b39061f

Browse files
committed
Added "coverage" test
1 parent ae8d1c1 commit b39061f

File tree

1 file changed

+69
-0
lines changed

1 file changed

+69
-0
lines changed
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
from unittest.mock import AsyncMock, Mock
2+
3+
import pytest
4+
from starlette.requests import Request
5+
from starlette.types import Scope
6+
7+
import mcp.types as types
8+
from mcp.server.auth.middleware.auth_context import auth_context_var, get_access_token
9+
from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser
10+
from mcp.server.auth.provider import AccessToken
11+
from mcp.server.lowlevel.server import Server
12+
from mcp.server.session import ServerSession
13+
from mcp.shared.message import ServerMessageMetadata
14+
from mcp.shared.session import RequestResponder
15+
16+
17+
@pytest.mark.anyio
18+
async def test_handle_request_sets_auth_context_from_request() -> None:
19+
server = Server("test-server")
20+
21+
@server.list_tools()
22+
async def handle_list_tools() -> list[types.Tool]:
23+
return [
24+
types.Tool(
25+
name="echo_access_token",
26+
description="Return access token",
27+
inputSchema={"type": "object", "properties": {}},
28+
)
29+
]
30+
31+
@server.call_tool()
32+
async def handle_call_tool(name: str, arguments: dict[str, object]) -> list[types.TextContent]:
33+
assert name == "echo_access_token"
34+
access_token = get_access_token()
35+
token = access_token.token if access_token else ""
36+
return [types.TextContent(type="text", text=token)]
37+
38+
access_token = AccessToken(token="test-token", client_id="client", scopes=["test"])
39+
user = AuthenticatedUser(access_token)
40+
headers: list[tuple[bytes, bytes]] = []
41+
scope: Scope = {
42+
"type": "http",
43+
"method": "POST",
44+
"path": "/mcp",
45+
"headers": headers,
46+
"user": user,
47+
}
48+
request = Request(scope)
49+
50+
message = Mock(spec=RequestResponder)
51+
message.request_id = "req-1"
52+
message.request_meta = None
53+
message.message_metadata = ServerMessageMetadata(request_context=request)
54+
message.respond = AsyncMock()
55+
56+
session = Mock(spec=ServerSession)
57+
session.client_params = None
58+
59+
call_request = types.CallToolRequest(params=types.CallToolRequestParams(name="echo_access_token", arguments={}))
60+
61+
await server._handle_request(message, call_request, session, {}, raise_exceptions=False)
62+
63+
assert auth_context_var.get() is None
64+
assert message.respond.called
65+
response = message.respond.call_args.args[0]
66+
assert isinstance(response.root, types.CallToolResult)
67+
content = response.root.content[0]
68+
assert isinstance(content, types.TextContent)
69+
assert content.text == "test-token"

0 commit comments

Comments
 (0)