-
Notifications
You must be signed in to change notification settings - Fork 598
Add DistributedCacheEventStreamStore
#1136
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: mbuck/resumability-redelivery
Are you sure you want to change the base?
Changes from all commits
c1e510f
3655703
6160ab9
e01c4f4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,59 @@ | ||
| // Licensed to the .NET Foundation under one or more agreements. | ||
| // The .NET Foundation licenses this file to you under the MIT license. | ||
|
|
||
| // This is a shared source file included in both ModelContextProtocol.Core and the test project. | ||
| // Do not reference symbols internal to the core project, as they won't be available in tests. | ||
|
|
||
| using System.Text; | ||
|
|
||
| namespace ModelContextProtocol.Server; | ||
|
|
||
| /// <summary> | ||
| /// Provides methods for formatting and parsing event IDs used by <see cref="DistributedCacheEventStreamStore"/>. | ||
| /// </summary> | ||
| /// <remarks> | ||
| /// Event IDs are formatted as "{base64(sessionId)}:{base64(streamId)}:{sequence}". | ||
| /// </remarks> | ||
| internal static class DistributedCacheEventIdFormatter | ||
| { | ||
| private const char Separator = ':'; | ||
|
|
||
| /// <summary> | ||
| /// Formats session ID, stream ID, and sequence number into an event ID string. | ||
| /// </summary> | ||
| public static string Format(string sessionId, string streamId, long sequence) | ||
| { | ||
| // Base64-encode session and stream IDs so the event ID can be parsed | ||
| // even if the original IDs contain the ':' separator character | ||
| var sessionBase64 = Convert.ToBase64String(Encoding.UTF8.GetBytes(sessionId)); | ||
| var streamBase64 = Convert.ToBase64String(Encoding.UTF8.GetBytes(streamId)); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Separate from this PR, we should really add Base64 overloads that handle this without the intermediate byte[]. I will follow up. |
||
| return $"{sessionBase64}{Separator}{streamBase64}{Separator}{sequence}"; | ||
| } | ||
|
|
||
| /// <summary> | ||
| /// Attempts to parse an event ID into its component parts. | ||
| /// </summary> | ||
| public static bool TryParse(string eventId, out string sessionId, out string streamId, out long sequence) | ||
| { | ||
| sessionId = string.Empty; | ||
| streamId = string.Empty; | ||
| sequence = 0; | ||
|
|
||
| var parts = eventId.Split(Separator); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. On Core you could use the span-based Split, which would avoid the string[] and also avoid needing to materialize strings for parts[0]/[1]/[2]. |
||
| if (parts.Length != 3) | ||
| { | ||
| return false; | ||
| } | ||
|
|
||
| try | ||
| { | ||
| sessionId = Encoding.UTF8.GetString(Convert.FromBase64String(parts[0])); | ||
| streamId = Encoding.UTF8.GetString(Convert.FromBase64String(parts[1])); | ||
| return long.TryParse(parts[2], out sequence); | ||
| } | ||
| catch | ||
| { | ||
| return false; | ||
| } | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,300 @@ | ||
| using Microsoft.Extensions.Caching.Distributed; | ||
| using ModelContextProtocol.Protocol; | ||
| using System.Net.ServerSentEvents; | ||
| using System.Runtime.CompilerServices; | ||
| using System.Text.Json; | ||
|
|
||
| namespace ModelContextProtocol.Server; | ||
|
|
||
| /// <summary> | ||
| /// An <see cref="ISseEventStreamStore"/> implementation backed by <see cref="IDistributedCache"/>. | ||
| /// </summary> | ||
| /// <remarks> | ||
| /// <para> | ||
| /// This implementation stores SSE events in a distributed cache, enabling resumability across | ||
| /// multiple server instances. Event IDs are encoded with session, stream, and sequence information | ||
| /// to allow efficient retrieval of events after a given point. | ||
| /// </para> | ||
| /// <para> | ||
| /// The writer maintains in-memory state for sequence number generation, as there is guaranteed | ||
| /// to be only one writer per stream. Readers may be created from separate processes. | ||
| /// </para> | ||
| /// </remarks> | ||
| public sealed class DistributedCacheEventStreamStore : ISseEventStreamStore | ||
| { | ||
| private readonly IDistributedCache _cache; | ||
| private readonly DistributedCacheEventStreamStoreOptions _options; | ||
|
|
||
| /// <summary> | ||
| /// Initializes a new instance of the <see cref="DistributedCacheEventStreamStore"/> class. | ||
| /// </summary> | ||
| /// <param name="cache">The distributed cache to use for storage.</param> | ||
| /// <param name="options">Optional configuration options for the store.</param> | ||
| public DistributedCacheEventStreamStore(IDistributedCache cache, DistributedCacheEventStreamStoreOptions? options = null) | ||
| { | ||
| Throw.IfNull(cache); | ||
| _cache = cache; | ||
| _options = options ?? new(); | ||
| } | ||
|
|
||
| /// <inheritdoc /> | ||
| public ValueTask<ISseEventStreamWriter> CreateStreamAsync(SseEventStreamOptions options, CancellationToken cancellationToken = default) | ||
| { | ||
| Throw.IfNull(options); | ||
| var writer = new DistributedCacheEventStreamWriter(_cache, options.SessionId, options.StreamId, options.Mode, _options); | ||
| return new ValueTask<ISseEventStreamWriter>(writer); | ||
| } | ||
|
|
||
| /// <inheritdoc /> | ||
| public async ValueTask<ISseEventStreamReader?> GetStreamReaderAsync(string lastEventId, CancellationToken cancellationToken = default) | ||
| { | ||
| Throw.IfNull(lastEventId); | ||
|
|
||
| // Parse the event ID to get session, stream, and sequence information | ||
| if (!DistributedCacheEventIdFormatter.TryParse(lastEventId, out var sessionId, out var streamId, out var sequence)) | ||
| { | ||
| return null; | ||
| } | ||
|
|
||
| // Check if the stream exists by looking for its metadata | ||
| var metadataKey = CacheKeys.StreamMetadata(sessionId, streamId); | ||
| var metadataBytes = await _cache.GetAsync(metadataKey, cancellationToken).ConfigureAwait(false); | ||
| if (metadataBytes is null) | ||
| { | ||
| return null; | ||
| } | ||
|
|
||
| var metadata = JsonSerializer.Deserialize(metadataBytes, McpJsonUtilities.JsonContext.Default.StreamMetadata); | ||
| if (metadata is null) | ||
| { | ||
| return null; | ||
| } | ||
|
|
||
| var startSequence = sequence + 1; | ||
| return new DistributedCacheEventStreamReader(_cache, sessionId, streamId, startSequence, metadata, _options); | ||
| } | ||
|
|
||
| /// <summary> | ||
| /// Provides methods for generating cache keys. | ||
| /// </summary> | ||
| internal static class CacheKeys | ||
| { | ||
| private const string Prefix = "mcp:sse:"; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we wanted to ship a new version that breaks the caching format and thus should ignore existing entries, how would we version it? change this prefix to include a version number? |
||
|
|
||
| public static string StreamMetadata(string sessionId, string streamId) => | ||
| $"{Prefix}meta:{sessionId}:{streamId}"; | ||
|
|
||
| public static string Event(string eventId) => | ||
| $"{Prefix}event:{eventId}"; | ||
|
|
||
| public static string StreamEventCount(string sessionId, string streamId) => | ||
| $"{Prefix}count:{sessionId}:{streamId}"; | ||
| } | ||
|
|
||
| /// <summary> | ||
| /// Metadata about a stream stored in the cache. | ||
| /// </summary> | ||
| internal sealed class StreamMetadata | ||
| { | ||
| public SseEventStreamMode Mode { get; set; } | ||
| public bool IsCompleted { get; set; } | ||
| public long LastSequence { get; set; } | ||
| } | ||
|
|
||
| /// <summary> | ||
| /// Serialized representation of an SSE event stored in the cache. | ||
| /// </summary> | ||
| internal sealed class StoredEvent | ||
| { | ||
| public string? EventType { get; set; } | ||
| public string? EventId { get; set; } | ||
| public JsonRpcMessage? Data { get; set; } | ||
| } | ||
|
|
||
| private sealed class DistributedCacheEventStreamWriter : ISseEventStreamWriter | ||
| { | ||
| private readonly IDistributedCache _cache; | ||
| private readonly string _sessionId; | ||
| private readonly string _streamId; | ||
| private SseEventStreamMode _mode; | ||
| private readonly DistributedCacheEventStreamStoreOptions _options; | ||
| private long _sequence; | ||
| private bool _disposed; | ||
|
|
||
| public DistributedCacheEventStreamWriter( | ||
| IDistributedCache cache, | ||
| string sessionId, | ||
| string streamId, | ||
| SseEventStreamMode mode, | ||
| DistributedCacheEventStreamStoreOptions options) | ||
| { | ||
| _cache = cache; | ||
| _sessionId = sessionId; | ||
| _streamId = streamId; | ||
| _mode = mode; | ||
| _options = options; | ||
| } | ||
|
|
||
| public async ValueTask SetModeAsync(SseEventStreamMode mode, CancellationToken cancellationToken = default) | ||
| { | ||
| _mode = mode; | ||
| await UpdateMetadataAsync(cancellationToken).ConfigureAwait(false); | ||
| } | ||
|
|
||
| public async ValueTask<SseItem<JsonRpcMessage?>> WriteEventAsync(SseItem<JsonRpcMessage?> sseItem, CancellationToken cancellationToken = default) | ||
| { | ||
| // Skip if already has an event ID | ||
| if (sseItem.EventId is not null) | ||
| { | ||
| return sseItem; | ||
| } | ||
|
|
||
| // Generate a new sequence number and event ID | ||
| var sequence = Interlocked.Increment(ref _sequence); | ||
| var eventId = DistributedCacheEventIdFormatter.Format(_sessionId, _streamId, sequence); | ||
| var newItem = sseItem with { EventId = eventId }; | ||
|
|
||
| // Store the event in the cache | ||
| var storedEvent = new StoredEvent | ||
| { | ||
| EventType = newItem.EventType, | ||
| EventId = eventId, | ||
| Data = newItem.Data, | ||
| }; | ||
|
|
||
| var eventBytes = JsonSerializer.SerializeToUtf8Bytes(storedEvent, McpJsonUtilities.JsonContext.Default.StoredEvent); | ||
| var eventKey = CacheKeys.Event(eventId); | ||
|
|
||
| await _cache.SetAsync(eventKey, eventBytes, new DistributedCacheEntryOptions | ||
| { | ||
| SlidingExpiration = _options.EventSlidingExpiration, | ||
| AbsoluteExpirationRelativeToNow = _options.EventAbsoluteExpiration, | ||
| }, cancellationToken).ConfigureAwait(false); | ||
|
|
||
| // Update metadata with the latest sequence | ||
| await UpdateMetadataAsync(cancellationToken).ConfigureAwait(false); | ||
|
|
||
| return newItem; | ||
| } | ||
|
|
||
| private async ValueTask UpdateMetadataAsync(CancellationToken cancellationToken) | ||
| { | ||
| var metadata = new StreamMetadata | ||
| { | ||
| Mode = _mode, | ||
| IsCompleted = _disposed, | ||
| LastSequence = Interlocked.Read(ref _sequence), | ||
| }; | ||
|
|
||
| var metadataBytes = JsonSerializer.SerializeToUtf8Bytes(metadata, McpJsonUtilities.JsonContext.Default.StreamMetadata); | ||
| var metadataKey = CacheKeys.StreamMetadata(_sessionId, _streamId); | ||
|
|
||
| await _cache.SetAsync(metadataKey, metadataBytes, new DistributedCacheEntryOptions | ||
| { | ||
| SlidingExpiration = _options.MetadataSlidingExpiration, | ||
| AbsoluteExpirationRelativeToNow = _options.MetadataAbsoluteExpiration, | ||
| }, cancellationToken).ConfigureAwait(false); | ||
| } | ||
|
|
||
| public async ValueTask DisposeAsync() | ||
| { | ||
| if (_disposed) | ||
| { | ||
| return; | ||
| } | ||
|
|
||
| _disposed = true; | ||
|
|
||
| // Mark the stream as completed in the metadata | ||
| await UpdateMetadataAsync(CancellationToken.None).ConfigureAwait(false); | ||
| } | ||
| } | ||
|
|
||
| private sealed class DistributedCacheEventStreamReader : ISseEventStreamReader | ||
| { | ||
| private readonly IDistributedCache _cache; | ||
| private readonly long _startSequence; | ||
| private readonly StreamMetadata _initialMetadata; | ||
| private readonly DistributedCacheEventStreamStoreOptions _options; | ||
|
|
||
| public DistributedCacheEventStreamReader( | ||
| IDistributedCache cache, | ||
| string sessionId, | ||
| string streamId, | ||
| long startSequence, | ||
| StreamMetadata initialMetadata, | ||
| DistributedCacheEventStreamStoreOptions options) | ||
| { | ||
| _cache = cache; | ||
| SessionId = sessionId; | ||
| StreamId = streamId; | ||
| _startSequence = startSequence; | ||
| _initialMetadata = initialMetadata; | ||
| _options = options; | ||
| } | ||
|
|
||
| public string SessionId { get; } | ||
| public string StreamId { get; } | ||
|
|
||
| public async IAsyncEnumerable<SseItem<JsonRpcMessage?>> ReadEventsAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) | ||
| { | ||
| // Start from the sequence after the last received event | ||
| var currentSequence = _startSequence; | ||
|
|
||
| // Use the initial metadata passed to the constructor for the first read. | ||
| var lastSequence = _initialMetadata.LastSequence; | ||
| var isCompleted = _initialMetadata.IsCompleted; | ||
| var mode = _initialMetadata.Mode; | ||
|
|
||
| while (!cancellationToken.IsCancellationRequested) | ||
| { | ||
| // Read all available events from currentSequence + 1 to lastSequence | ||
| for (; currentSequence <= lastSequence; currentSequence++) | ||
| { | ||
| cancellationToken.ThrowIfCancellationRequested(); | ||
|
|
||
| var eventId = DistributedCacheEventIdFormatter.Format(SessionId, StreamId, currentSequence); | ||
| var eventKey = CacheKeys.Event(eventId); | ||
| var eventBytes = await _cache.GetAsync(eventKey, cancellationToken).ConfigureAwait(false) | ||
| ?? throw new McpException($"SSE event with ID '{eventId}' was not found in the cache. The event may have expired."); | ||
|
|
||
| var storedEvent = JsonSerializer.Deserialize(eventBytes, McpJsonUtilities.JsonContext.Default.StoredEvent); | ||
| if (storedEvent is not null) | ||
| { | ||
| yield return new SseItem<JsonRpcMessage?>(storedEvent.Data, storedEvent.EventType) | ||
| { | ||
| EventId = storedEvent.EventId, | ||
| }; | ||
| } | ||
| } | ||
|
|
||
| // If in polling mode, stop after returning currently available events | ||
| if (mode == SseEventStreamMode.Polling) | ||
| { | ||
| yield break; | ||
| } | ||
|
|
||
| // If the stream is completed and we've read all events, stop | ||
| if (isCompleted) | ||
| { | ||
| yield break; | ||
| } | ||
|
|
||
| // Wait before polling again for new events | ||
| await Task.Delay(_options.PollingInterval, cancellationToken).ConfigureAwait(false); | ||
|
|
||
| // Refresh metadata to get the latest sequence and completion status | ||
| var metadataKey = CacheKeys.StreamMetadata(SessionId, StreamId); | ||
| var metadataBytes = await _cache.GetAsync(metadataKey, cancellationToken).ConfigureAwait(false) | ||
| ?? throw new McpException($"Stream metadata for session '{SessionId}' and stream '{StreamId}' was not found in the cache. The metadata may have expired."); | ||
|
|
||
| var currentMetadata = JsonSerializer.Deserialize(metadataBytes, McpJsonUtilities.JsonContext.Default.StreamMetadata) | ||
| ?? throw new McpException($"Stream metadata for session '{SessionId}' and stream '{StreamId}' could not be deserialized."); | ||
|
|
||
| lastSequence = currentMetadata.LastSequence; | ||
| isCompleted = currentMetadata.IsCompleted; | ||
| mode = currentMetadata.Mode; | ||
| } | ||
| } | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is ModelContextProtocol.Core the right assembly for this, or should it instead live in ModelContextProtocol or ModelContextProtocol.AspNetCore?