Skip to content

Commit 3bf76f0

Browse files
committed
## Title: **Add Automatic Context Summarization to ClientSession**
1 parent 0b1b52b commit 3bf76f0

File tree

4 files changed

+432
-0
lines changed

4 files changed

+432
-0
lines changed

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ dependencies = [
3333
"uvicorn>=0.23.1; sys_platform != 'emscripten'",
3434
"jsonschema>=4.20.0",
3535
"pywin32>=310; sys_platform == 'win32'",
36+
"tiktoken>=0.9.0",
3637
]
3738

3839
[project.optional-dependencies]
@@ -59,6 +60,7 @@ dev = [
5960
"pytest-pretty>=1.2.0",
6061
"inline-snapshot>=0.23.0",
6162
"dirty-equals>=0.9.0",
63+
"pytest-asyncio>=1.1.0",
6264
]
6365
docs = [
6466
"mkdocs>=1.6.1",
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
from datetime import timedelta
2+
from typing import Any
3+
4+
import tiktoken
5+
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
6+
7+
from mcp.client.session import ClientSession
8+
from mcp.shared.context import RequestContext
9+
from mcp.shared.message import SessionMessage
10+
from mcp.types import CreateMessageRequestParams, CreateMessageResult, SamplingMessage, TextContent
11+
12+
DEFAULT_MAX_TOKENS = 4000
13+
DEFAULT_SUMMARIZE_THRESHOLD = 0.8
14+
DEFAULT_SUMMARY_PROMPT = "Summarize the following conversation succinctly, preserving key facts:\n\n"
15+
16+
17+
class ClientSessionSummarizing(ClientSession):
18+
def __init__(
19+
self,
20+
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
21+
write_stream: MemoryObjectSendStream[SessionMessage],
22+
read_timeout_seconds: timedelta | None = None,
23+
sampling_callback: Any | None = None,
24+
elicitation_callback: Any | None = None,
25+
list_roots_callback: Any | None = None,
26+
logging_callback: Any | None = None,
27+
message_handler: Any | None = None,
28+
client_info: Any | None = None,
29+
max_tokens: int | None = None,
30+
summarize_threshold: float | None = None,
31+
summary_prompt: str | None = None,
32+
) -> None:
33+
super().__init__(
34+
read_stream=read_stream,
35+
write_stream=write_stream,
36+
read_timeout_seconds=read_timeout_seconds,
37+
sampling_callback=sampling_callback,
38+
elicitation_callback=elicitation_callback,
39+
list_roots_callback=list_roots_callback,
40+
logging_callback=logging_callback,
41+
message_handler=message_handler,
42+
client_info=client_info,
43+
)
44+
self.history: list[SamplingMessage] = []
45+
self.max_tokens = max_tokens or DEFAULT_MAX_TOKENS
46+
self.summarize_threshold = summarize_threshold or DEFAULT_SUMMARIZE_THRESHOLD
47+
self.summary_prompt = summary_prompt or DEFAULT_SUMMARY_PROMPT
48+
# Override the sampling callback to include our summarization logic
49+
self._sampling_callback = self._summarizing_sampling_callback
50+
51+
async def _summarizing_sampling_callback(
52+
self,
53+
context: RequestContext["ClientSession", Any],
54+
params: CreateMessageRequestParams,
55+
) -> CreateMessageResult:
56+
"""Custom sampling callback that includes summarization logic."""
57+
# Add messages to history
58+
self.history.extend(params.messages)
59+
60+
# Check if we need to summarize
61+
if self.token_count() > self.max_tokens * self.summarize_threshold:
62+
await self.summarize_context()
63+
64+
# For now, return a simple response
65+
# In a real implementation, you might want to call an LLM service here
66+
return CreateMessageResult(
67+
role="assistant",
68+
content=TextContent(type="text", text="Message processed with summarization"),
69+
model="summarizing-model",
70+
stopReason="endTurn",
71+
)
72+
73+
def token_count(self) -> int:
74+
"""Calculate token count for all messages in history."""
75+
tokenizer = tiktoken.get_encoding("cl100k_base")
76+
total_tokens = 0
77+
78+
for message in self.history:
79+
if isinstance(message.content, TextContent):
80+
total_tokens += len(tokenizer.encode(message.content.text))
81+
elif isinstance(message.content, str):
82+
total_tokens += len(tokenizer.encode(message.content))
83+
84+
return total_tokens
85+
86+
async def summarize_context(self) -> None:
87+
"""Summarize the conversation history and replace it with a summary."""
88+
if not self.history:
89+
return
90+
91+
# Create a summary prompt from all messages
92+
summary_text = self.summary_prompt
93+
for message in self.history:
94+
if isinstance(message.content, TextContent):
95+
summary_text += f"{message.role}: {message.content.text}\n"
96+
elif isinstance(message.content, str):
97+
summary_text += f"{message.role}: {message.content}\n"
98+
99+
# Create a summary message
100+
summary_message = SamplingMessage(role="assistant", content=TextContent(type="text", text=summary_text))
101+
102+
# Replace history with summary
103+
self.history = [summary_message]
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
from typing import Any
2+
3+
import anyio
4+
import pytest
5+
6+
from mcp.client.client_session_summarizing import (
7+
DEFAULT_MAX_TOKENS,
8+
DEFAULT_SUMMARIZE_THRESHOLD,
9+
DEFAULT_SUMMARY_PROMPT,
10+
ClientSessionSummarizing,
11+
)
12+
from mcp.shared.context import RequestContext
13+
from mcp.types import (
14+
CreateMessageRequestParams,
15+
CreateMessageResult,
16+
SamplingMessage,
17+
TextContent,
18+
)
19+
20+
21+
@pytest.mark.asyncio
22+
async def test_summarizing_session():
23+
send_stream, receive_stream = anyio.create_memory_object_stream(10)
24+
try:
25+
session = ClientSessionSummarizing(
26+
read_stream=receive_stream,
27+
write_stream=send_stream,
28+
)
29+
30+
# Create real messages instead of simple strings
31+
messages = [SamplingMessage(role="user", content=TextContent(type="text", text="Hello")) for _ in range(3500)]
32+
session.history = messages # Simulate approaching token limit
33+
34+
assert session.token_count() > session.max_tokens * session.summarize_threshold
35+
36+
# Test that summarization works
37+
await session.summarize_context()
38+
39+
# After summarization, history should contain only one message
40+
assert len(session.history) == 1
41+
assert isinstance(session.history[0], SamplingMessage)
42+
assert session.history[0].role == "assistant"
43+
44+
finally:
45+
await send_stream.aclose()
46+
await receive_stream.aclose()
47+
48+
49+
@pytest.mark.asyncio
50+
async def test_sampling_callback():
51+
"""Test sampling callback with ClientSessionSummarizing"""
52+
send_stream, receive_stream = anyio.create_memory_object_stream(10)
53+
try:
54+
session = ClientSessionSummarizing(
55+
read_stream=receive_stream,
56+
write_stream=send_stream,
57+
)
58+
59+
# Create request parameters
60+
params = CreateMessageRequestParams(
61+
messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Hello world"))], maxTokens=100
62+
)
63+
64+
# Create simple context for testing
65+
context: Any = RequestContext(session=session, request_id=1, meta=None, lifespan_context=None)
66+
67+
# Call sampling callback
68+
result = await session._summarizing_sampling_callback(context, params)
69+
70+
# Verify the result is correct
71+
assert isinstance(result, CreateMessageResult)
72+
assert result.role == "assistant"
73+
assert isinstance(result.content, TextContent)
74+
assert "Message processed with summarization" in result.content.text
75+
76+
# Verify message was added to history
77+
assert len(session.history) == 1
78+
assert session.history[0].role == "user"
79+
assert isinstance(session.history[0].content, TextContent)
80+
assert session.history[0].content.text == "Hello world"
81+
82+
finally:
83+
await send_stream.aclose()
84+
await receive_stream.aclose()
85+
86+
87+
@pytest.mark.asyncio
88+
async def test_custom_summary_prompt():
89+
"""Test that user can define custom prompt"""
90+
send_stream, receive_stream = anyio.create_memory_object_stream(10)
91+
try:
92+
custom_prompt = "Custom summary prompt:\n\n"
93+
session = ClientSessionSummarizing(
94+
read_stream=receive_stream,
95+
write_stream=send_stream,
96+
summary_prompt=custom_prompt,
97+
)
98+
99+
# Verify user can define custom prompt
100+
assert session.summary_prompt == custom_prompt
101+
assert session.summary_prompt != DEFAULT_SUMMARY_PROMPT
102+
103+
# Test that summarization uses custom prompt
104+
session.history = [SamplingMessage(role="user", content=TextContent(type="text", text="Test message"))]
105+
106+
await session.summarize_context()
107+
108+
# Verify summary contains custom prompt
109+
assert len(session.history) == 1
110+
summary_content = session.history[0].content
111+
assert isinstance(summary_content, TextContent)
112+
assert custom_prompt in summary_content.text
113+
114+
finally:
115+
await send_stream.aclose()
116+
await receive_stream.aclose()
117+
118+
119+
@pytest.mark.asyncio
120+
async def test_default_summary_prompt():
121+
"""Test that user gets default prompt if not specified"""
122+
send_stream, receive_stream = anyio.create_memory_object_stream(10)
123+
try:
124+
session = ClientSessionSummarizing(
125+
read_stream=receive_stream,
126+
write_stream=send_stream,
127+
)
128+
129+
# Verify user gets default prompt
130+
assert session.summary_prompt == DEFAULT_SUMMARY_PROMPT
131+
132+
finally:
133+
await send_stream.aclose()
134+
await receive_stream.aclose()
135+
136+
137+
@pytest.mark.asyncio
138+
async def test_custom_max_tokens():
139+
"""Test that user can define custom max tokens"""
140+
send_stream, receive_stream = anyio.create_memory_object_stream(10)
141+
try:
142+
custom_max_tokens = 2000
143+
session = ClientSessionSummarizing(
144+
read_stream=receive_stream,
145+
write_stream=send_stream,
146+
max_tokens=custom_max_tokens,
147+
)
148+
149+
# Verify user can define custom max tokens
150+
assert session.max_tokens == custom_max_tokens
151+
assert session.max_tokens != DEFAULT_MAX_TOKENS
152+
153+
finally:
154+
await send_stream.aclose()
155+
await receive_stream.aclose()
156+
157+
158+
@pytest.mark.asyncio
159+
async def test_custom_summarize_threshold():
160+
"""Test that user can define custom summarize threshold"""
161+
send_stream, receive_stream = anyio.create_memory_object_stream(10)
162+
try:
163+
custom_threshold = 0.5
164+
session = ClientSessionSummarizing(
165+
read_stream=receive_stream,
166+
write_stream=send_stream,
167+
summarize_threshold=custom_threshold,
168+
)
169+
170+
# Verify user can define custom threshold
171+
assert session.summarize_threshold == custom_threshold
172+
assert session.summarize_threshold != DEFAULT_SUMMARIZE_THRESHOLD
173+
174+
finally:
175+
await send_stream.aclose()
176+
await receive_stream.aclose()
177+
178+
179+
@pytest.mark.asyncio
180+
async def test_default_parameters():
181+
"""Test that user gets default parameters if not specified"""
182+
send_stream, receive_stream = anyio.create_memory_object_stream(10)
183+
try:
184+
session = ClientSessionSummarizing(
185+
read_stream=receive_stream,
186+
write_stream=send_stream,
187+
)
188+
189+
# Verify user gets default parameters
190+
assert session.max_tokens == DEFAULT_MAX_TOKENS
191+
assert session.summarize_threshold == DEFAULT_SUMMARIZE_THRESHOLD
192+
assert session.summary_prompt == DEFAULT_SUMMARY_PROMPT
193+
194+
finally:
195+
await send_stream.aclose()
196+
await receive_stream.aclose()

0 commit comments

Comments
 (0)