Skip to content

Commit 49e20a4

Browse files
committed
feat: add RedisEventStore for production SSE resumability
1 parent e8e6484 commit 49e20a4

7 files changed

Lines changed: 566 additions & 1 deletion

File tree

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ dependencies = [
4747
rich = ["rich>=13.9.4"]
4848
cli = ["typer>=0.16.0", "python-dotenv>=1.0.0"]
4949
ws = ["websockets>=15.0.1"]
50+
redis = ["redis[asyncio]>=4.2.0"]
5051

5152
[project.scripts]
5253
mcp = "mcp.cli:app [cli]"
@@ -91,6 +92,7 @@ dev = [
9192
"pillow>=12.0",
9293
"strict-no-cover",
9394
"logfire>=3.0.0",
95+
"fakeredis>=2.26.0",
9496
]
9597
docs = [
9698
"mkdocs>=1.6.1",

src/mcp/server/contrib/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
"""Optional production-grade add-ons for MCP servers.
2+
3+
WARNING: These modules require optional dependencies that are NOT installed by default.
4+
Install the relevant extra before importing:
5+
6+
pip install "mcp[redis]"
7+
8+
Then import directly from the submodule:
9+
10+
from mcp.server.contrib.event_stores import RedisEventStore
11+
"""
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"""EventStore implementations for production deployments."""
2+
3+
from mcp.server.contrib.event_stores.redis import RedisEventStore
4+
5+
__all__ = ["RedisEventStore"]
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
"""Redis-backed EventStore for MCP SSE stream resumability.
2+
3+
Requires the redis extra:
4+
pip install "mcp[redis]"
5+
6+
Quickstart:
7+
import redis.asyncio as aioredis
8+
from mcp.server.contrib.event_stores import RedisEventStore
9+
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
10+
11+
redis_client = aioredis.from_url("redis://localhost:6379")
12+
store = RedisEventStore(redis_client, ttl=3600)
13+
14+
session_manager = StreamableHTTPSessionManager(
15+
app=mcp_server,
16+
event_store=store,
17+
)
18+
"""
19+
20+
from __future__ import annotations
21+
22+
import logging
23+
from typing import Any
24+
25+
from mcp.server.streamable_http import (
26+
EventCallback,
27+
EventId,
28+
EventMessage,
29+
EventStore,
30+
StreamId,
31+
)
32+
from mcp.types import JSONRPCMessage, jsonrpc_message_adapter
33+
34+
logger = logging.getLogger(__name__)
35+
36+
37+
class RedisEventStore(EventStore):
38+
"""EventStore backed by Redis for production multi-process deployments.
39+
40+
Redis data layout:
41+
{prefix}counter — STRING, atomic INCR source for EventIds
42+
{prefix}event:{event_id} — HASH, fields: stream_id + payload
43+
{prefix}stream:{stream_id} — ZSET, members: event_ids, scores: int(event_id)
44+
45+
Args:
46+
redis: An already-connected redis.asyncio.Redis instance.
47+
key_prefix: Prefix for all Redis keys. Use different prefixes when
48+
multiple MCP servers share one Redis instance.
49+
Default: "mcp:".
50+
ttl: Seconds after which keys expire automatically.
51+
None means keys never expire — strongly discouraged in
52+
production. Recommended: at least 2× session_idle_timeout.
53+
"""
54+
55+
def __init__(
56+
self,
57+
redis: Any, # redis.asyncio.Redis at runtime
58+
*,
59+
key_prefix: str = "mcp:",
60+
ttl: int | None = None,
61+
) -> None:
62+
self._redis = redis
63+
self._prefix = key_prefix
64+
self._ttl = ttl
65+
66+
if ttl is None:
67+
logger.warning(
68+
"RedisEventStore created with ttl=None. "
69+
"Events will accumulate indefinitely in Redis. "
70+
"Set ttl= to a positive number of seconds "
71+
"(recommended: at least 2× your session_idle_timeout)."
72+
)
73+
74+
# Key helpers
75+
76+
def _counter_key(self) -> str:
77+
return f"{self._prefix}counter"
78+
79+
def _event_key(self, event_id: EventId) -> str:
80+
return f"{self._prefix}event:{event_id}"
81+
82+
def _stream_key(self, stream_id: StreamId) -> str:
83+
return f"{self._prefix}stream:{stream_id}"
84+
85+
# EventStore interface
86+
87+
async def store_event(
88+
self,
89+
stream_id: StreamId,
90+
message: JSONRPCMessage | None,
91+
) -> EventId:
92+
"""Store an event and return its unique, monotonically increasing ID."""
93+
# Atomic increment — safe under concurrent writes from multiple workers
94+
event_id_int: int = await self._redis.incr(self._counter_key())
95+
event_id: EventId = str(event_id_int)
96+
97+
# Serialise — empty string is the sentinel for priming events (no payload)
98+
if message is None:
99+
payload = ""
100+
else:
101+
payload = jsonrpc_message_adapter.dump_json(
102+
message,
103+
by_alias=True,
104+
exclude_none=True,
105+
).decode("utf-8")
106+
107+
# Store event metadata: which stream it belongs to + its payload
108+
await self._redis.hset(
109+
self._event_key(event_id),
110+
mapping={
111+
"stream_id": stream_id,
112+
"payload": payload,
113+
},
114+
)
115+
116+
# Register in the stream's sorted set — score = int(event_id) for range queries
117+
await self._redis.zadd(
118+
self._stream_key(stream_id),
119+
{event_id: event_id_int},
120+
)
121+
122+
# Refresh TTL on all touched keys (if configured)
123+
if self._ttl is not None:
124+
await self._redis.expire(self._event_key(event_id), self._ttl)
125+
await self._redis.expire(self._stream_key(stream_id), self._ttl)
126+
await self._redis.expire(self._counter_key(), self._ttl)
127+
128+
return event_id
129+
130+
async def replay_events_after(
131+
self,
132+
last_event_id: EventId,
133+
send_callback: EventCallback,
134+
) -> StreamId | None:
135+
"""Replay all events on the same stream that occurred after last_event_id."""
136+
# Look up which stream owns this event ID
137+
stream_id_raw: bytes | None = await self._redis.hget(self._event_key(last_event_id), "stream_id")
138+
139+
if stream_id_raw is None:
140+
# Unknown or expired event ID — return None, don't raise
141+
return None
142+
143+
stream_id: StreamId = stream_id_raw.decode("utf-8")
144+
145+
# Fetch all event IDs in this stream with id strictly greater than last_event_id
146+
last_int = int(last_event_id)
147+
raw_ids: list[bytes] = await self._redis.zrangebyscore(
148+
self._stream_key(stream_id),
149+
min=last_int + 1,
150+
max="+inf",
151+
)
152+
153+
for eid_bytes in raw_ids:
154+
eid: EventId = eid_bytes.decode("utf-8")
155+
156+
payload_raw: bytes | None = await self._redis.hget(self._event_key(eid), "payload")
157+
158+
if payload_raw is None:
159+
# Key expired between ZRANGEBYSCORE and HGET — skip silently
160+
logger.debug("Event %s payload missing during replay (expired?)", eid)
161+
continue
162+
163+
payload_str = payload_raw.decode("utf-8")
164+
165+
if not payload_str:
166+
# Empty string = priming event — never sent to clients
167+
continue
168+
169+
message = jsonrpc_message_adapter.validate_json(payload_str)
170+
await send_callback(EventMessage(message=message, event_id=eid))
171+
172+
return stream_id

tests/server/contrib/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# contrib tests package

0 commit comments

Comments
 (0)