Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 142 additions & 0 deletions src/mcp/server/checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
from __future__ import annotations

import abc
import time
from dataclasses import dataclass
from typing import Protocol, runtime_checkable

from mcp.server.session import ServerSession
from mcp.types import (
CheckpointCreateParams,
CheckpointCreateResult,
CheckpointValidateParams,
CheckpointValidateResult,
CheckpointResumeParams,
CheckpointResumeResult,
CheckpointDeleteParams,
CheckpointDeleteResult,
)


@runtime_checkable
class CheckpointBackend(Protocol):
"""Backend that actually stores and restores state behind handles."""

async def create_checkpoint(
self,
session: ServerSession,
params: CheckpointCreateParams,
) -> CheckpointCreateResult: ...

async def validate_checkpoint(
self,
session: ServerSession,
params: CheckpointValidateParams,
) -> CheckpointValidateResult: ...

async def resume_checkpoint(
self,
session: ServerSession,
params: CheckpointResumeParams,
) -> CheckpointResumeResult: ...

async def delete_checkpoint(
self,
session: ServerSession,
params: CheckpointDeleteParams,
) -> CheckpointDeleteResult: ...


@dataclass
class InMemoryHandleEntry:
value: object
digest: str
expires_at: float


class InMemoryCheckpointBackend(CheckpointBackend):
"""Simple in-memory backend you can use for tests/POC.

This is intentionally generic; concrete servers (data, browser, etc.)
decide *what* `value` is and how to interpret it.
"""

def __init__(self, ttl_seconds: int = 1800) -> None:
self._ttl = ttl_seconds
self._handles: dict[str, InMemoryHandleEntry] = {}

def _now(self) -> float:
return time.time()

async def create_checkpoint(
self,
session: ServerSession,
params: CheckpointCreateParams,
) -> CheckpointCreateResult:
# session.fastmcp or session.server can expose some "current state"
# For now you can override this backend in your server and implement
# your own snapshot logic.
raise NotImplementedError(
"Subclass InMemoryCheckpointBackend and override create_checkpoint "
"to capture concrete state (e.g. data tables, browser session)."
)

async def validate_checkpoint(
self,
session: ServerSession,
params: CheckpointValidateParams,
) -> CheckpointValidateResult:
entry = self._handles.get(params.handle)
if not entry:
return CheckpointValidateResult(
valid=False,
remainingTtlSeconds=0,
digestMatch=False,
)

now = self._now()
if now >= entry.expires_at:
return CheckpointValidateResult(
valid=False,
remainingTtlSeconds=0,
digestMatch=params.expectedDigest == entry.digest,
)

remaining = int(entry.expires_at - now)
return CheckpointValidateResult(
valid=True,
remainingTtlSeconds=remaining,
digestMatch=(
params.expectedDigest is None
or params.expectedDigest == entry.digest
),
)

async def resume_checkpoint(
self,
session: ServerSession,
params: CheckpointResumeParams,
) -> CheckpointResumeResult:
entry = self._handles.get(params.handle)
if not entry:
# You’ll map this to HANDLE_NOT_FOUND at JSON-RPC level
return CheckpointResumeResult(resumed=False, handle=params.handle)

if self._now() >= entry.expires_at:
# Map to EXPIRED
return CheckpointResumeResult(resumed=False, handle=params.handle)

# Subclasses should take `entry.value` and rehydrate into session state.
raise NotImplementedError(
"Subclass InMemoryCheckpointBackend.resume_checkpoint to rehydrate "
"concrete session state from stored value."
)

async def delete_checkpoint(
self,
session: ServerSession,
params: CheckpointDeleteParams,
) -> CheckpointDeleteResult:
deleted = params.handle in self._handles
self._handles.pop(params.handle, None)
return CheckpointDeleteResult(deleted=deleted)
4 changes: 4 additions & 0 deletions src/mcp/server/fastmcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
from mcp.server.streamable_http import EventStore
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
from mcp.server.transport_security import TransportSecuritySettings
from mcp.server.checkpoint import CheckpointBackend
from mcp.shared.context import LifespanContextT, RequestContext, RequestT
from mcp.types import Annotations, AnyFunction, ContentBlock, GetPromptResult, Icon, ToolAnnotations
from mcp.types import Prompt as MCPPrompt
Expand Down Expand Up @@ -173,6 +174,7 @@ def __init__( # noqa: PLR0913
lifespan: (Callable[[FastMCP[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]] | None) = None,
auth: AuthSettings | None = None,
transport_security: TransportSecuritySettings | None = None,
checkpoint_backend: CheckpointBackend | None = None,
):
# Auto-enable DNS rebinding protection for localhost (IPv4 and IPv6)
if transport_security is None and host in ("127.0.0.1", "localhost", "::1"):
Expand Down Expand Up @@ -202,6 +204,7 @@ def __init__( # noqa: PLR0913
transport_security=transport_security,
)

self._checkpoint_backend = checkpoint_backend
self._mcp_server = MCPServer(
name=name or "FastMCP",
instructions=instructions,
Expand All @@ -210,6 +213,7 @@ def __init__( # noqa: PLR0913
# TODO(Marcelo): It seems there's a type mismatch between the lifespan type from an FastMCP and Server.
# We need to create a Lifespan type that is a generic on the server type, like Starlette does.
lifespan=(lifespan_wrapper(self, self.settings.lifespan) if self.settings.lifespan else default_lifespan), # type: ignore
checkpoint_backend=self._checkpoint_backend,
)
self._tool_manager = ToolManager(tools=tools, warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools)
self._resource_manager = ResourceManager(warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources)
Expand Down
7 changes: 7 additions & 0 deletions src/mcp/server/lowlevel/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ async def main():
from mcp.shared.message import ServerMessageMetadata, SessionMessage
from mcp.shared.session import RequestResponder
from mcp.shared.tool_name_validation import validate_and_warn_tool_name
from mcp.server.checkpoint import CheckpointBackend

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -146,6 +147,9 @@ def __init__(
[Server[LifespanResultT, RequestT]],
AbstractAsyncContextManager[LifespanResultT],
] = lifespan,
*,
stateless: bool = False,
checkpoint_backend: CheckpointBackend | None = None,
):
self.name = name
self.version = version
Expand All @@ -159,6 +163,8 @@ def __init__(
self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {}
self._tool_cache: dict[str, types.Tool] = {}
self._experimental_handlers: ExperimentalHandlers | None = None
self._stateless = stateless
self._checkpoint_backend = checkpoint_backend
logger.debug("Initializing server %r", name)

def create_initialization_options(
Expand Down Expand Up @@ -650,6 +656,7 @@ async def run(
write_stream,
initialization_options,
stateless=stateless,
checkpoint_backend=self._checkpoint_backend,
)
)

Expand Down
14 changes: 11 additions & 3 deletions src/mcp/server/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]:
"""

from enum import Enum
from typing import Any, TypeVar, overload
from typing import Any, TypeVar, overload, TYPE_CHECKING

import anyio
import anyio.lowlevel
Expand All @@ -57,7 +57,8 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]:
RequestResponder,
)
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS

if TYPE_CHECKING:
from mcp.server.checkpoint import CheckpointBackend

class InitializationState(Enum):
NotInitialized = 1
Expand Down Expand Up @@ -91,6 +92,7 @@ def __init__(
write_stream: MemoryObjectSendStream[SessionMessage],
init_options: InitializationOptions,
stateless: bool = False,
checkpoint_backend: "CheckpointBackend | None" = None,
) -> None:
super().__init__(read_stream, write_stream, types.ClientRequest, types.ClientNotification)
self._initialization_state = (
Expand All @@ -102,6 +104,7 @@ def __init__(
ServerRequestResponder
](0)
self._exit_stack.push_async_callback(lambda: self._incoming_message_stream_reader.aclose())
self._checkpoint_backend = checkpoint_backend

@property
def client_params(self) -> types.InitializeRequestParams | None:
Expand All @@ -116,6 +119,11 @@ def experimental(self) -> ExperimentalServerSessionFeatures:
if self._experimental_features is None:
self._experimental_features = ExperimentalServerSessionFeatures(self)
return self._experimental_features

@property
def checkpoint_backend(self) -> "CheckpointBackend | None":
"""Optional checkpoint backend attached to this session."""
return self._checkpoint_backend

def check_client_capability(self, capability: types.ClientCapabilities) -> bool: # pragma: no cover
"""Check if the client supports a specific capability."""
Expand Down Expand Up @@ -688,4 +696,4 @@ async def _handle_incoming(self, req: ServerRequestResponder) -> None:
def incoming_messages(
self,
) -> MemoryObjectReceiveStream[ServerRequestResponder]:
return self._incoming_message_stream_reader
return self._incoming_message_stream_reader
64 changes: 64 additions & 0 deletions src/mcp/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1996,3 +1996,67 @@ class ServerNotification(RootModel[ServerNotificationType]):

class ServerResult(RootModel[ServerResultType]):
pass


# --- Checkpoint protocol extensions -----------------------------------------

class CheckpointHandle(BaseModel):
"""Opaque checkpoint handle returned by servers."""
handle: str
digest: str
ttlSeconds: int


class CheckpointCreateParams(BaseModel):
"""Params for checkpoint/create.

For v1 you can keep this empty – the server infers the session
from transport/session context – but we define it for forward compat.
"""
# Optional: allow tools to tag a logical name
label: str | None = None


class CheckpointCreateResult(BaseModel):
"""Result of checkpoint/create."""
handle: str
digest: str
ttlSeconds: int


class CheckpointValidateParams(BaseModel):
"""Params for checkpoint/validate."""
handle: str
expectedDigest: str | None = None


class CheckpointValidateResult(BaseModel):
"""Result of checkpoint/validate."""
valid: bool
remainingTtlSeconds: int
digestMatch: bool


class CheckpointResumeParams(BaseModel):
"""Params for checkpoint/resume."""
handle: str


class CheckpointResumeResult(BaseModel):
"""Result of checkpoint/resume.

You can expand this later if you want to
surface metadata to the client.
"""
resumed: bool
handle: str


class CheckpointDeleteParams(BaseModel):
"""Params for checkpoint/delete."""
handle: str


class CheckpointDeleteResult(BaseModel):
"""Result of checkpoint/delete."""
deleted: bool
Loading