Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions src/Common/Polyfills/System/Text/EncodingExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,25 @@ public static int GetBytes(this Encoding encoding, ReadOnlySpan<char> chars, Spa
}
}
}

/// <summary>
/// Decodes all the bytes in the specified span into a string.
/// </summary>
public static string GetString(this Encoding encoding, ReadOnlySpan<byte> bytes)
{
if (bytes.IsEmpty)
{
return string.Empty;
}

unsafe
{
fixed (byte* bytesPtr = bytes)
{
return encoding.GetString(bytesPtr, bytes.Length);
}
}
}
}

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ namespace ModelContextProtocol.Client;
/// <summary>Provides the client side of a stdio-based session transport.</summary>
internal sealed class StdioClientSessionTransport(
StdioClientTransportOptions options, Process process, string endpointName, Queue<string> stderrRollingLog, ILoggerFactory? loggerFactory) :
StreamClientSessionTransport(process.StandardInput.BaseStream, process.StandardOutput.BaseStream, encoding: null, endpointName, loggerFactory)
StreamClientSessionTransport(process.StandardInput.BaseStream, process.StandardOutput.BaseStream, endpointName, loggerFactory)
{
private readonly StdioClientTransportOptions _options = options;
private readonly Process _process = process;
Expand Down
103 changes: 79 additions & 24 deletions src/ModelContextProtocol.Core/Client/StreamClientSessionTransport.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
using Microsoft.Extensions.Logging;
using ModelContextProtocol.Protocol;
using System.Buffers;
using System.IO.Pipelines;
using System.Text;
using System.Text.Json;

Expand All @@ -12,7 +14,7 @@ internal class StreamClientSessionTransport : TransportBase

internal static UTF8Encoding NoBomUtf8Encoding { get; } = new(encoderShouldEmitUTF8Identifier: false);

private readonly TextReader _serverOutput;
private readonly PipeReader _serverOutputPipe;
private readonly Stream _serverInputStream;
private readonly SemaphoreSlim _sendLock = new(1, 1);
private CancellationTokenSource? _shutdownCts = new();
Expand All @@ -27,9 +29,6 @@ internal class StreamClientSessionTransport : TransportBase
/// <param name="serverOutput">
/// The server's output stream. Messages read from this stream will be received from the server.
/// </param>
/// <param name="encoding">
/// The encoding used for reading and writing messages from the input and output streams. Defaults to UTF-8 without BOM if null.
/// </param>
/// <param name="endpointName">
/// A name that identifies this transport endpoint in logs.
/// </param>
Expand All @@ -40,18 +39,14 @@ internal class StreamClientSessionTransport : TransportBase
/// This constructor starts a background task to read messages from the server output stream.
/// The transport will be marked as connected once initialized.
/// </remarks>
public StreamClientSessionTransport(Stream serverInput, Stream serverOutput, Encoding? encoding, string endpointName, ILoggerFactory? loggerFactory)
public StreamClientSessionTransport(Stream serverInput, Stream serverOutput, string endpointName, ILoggerFactory? loggerFactory)
: base(endpointName, loggerFactory)
{
Throw.IfNull(serverInput);
Throw.IfNull(serverOutput);

_serverInputStream = serverInput;
#if NET
_serverOutput = new StreamReader(serverOutput, encoding ?? NoBomUtf8Encoding);
#else
_serverOutput = new CancellableStreamReader(serverOutput, encoding ?? NoBomUtf8Encoding);
#endif
_serverOutputPipe = PipeReader.Create(serverOutput);

SetConnected();

Expand Down Expand Up @@ -105,20 +100,41 @@ private async Task ReadMessagesAsync(CancellationToken cancellationToken)

while (true)
{
if (await _serverOutput.ReadLineAsync(cancellationToken).ConfigureAwait(false) is not string line)
{
LogTransportEndOfStream(Name);
break;
}
ReadResult result = await _serverOutputPipe.ReadAsync(cancellationToken).ConfigureAwait(false);
ReadOnlySequence<byte> buffer = result.Buffer;

if (string.IsNullOrWhiteSpace(line))
SequencePosition? position;
while ((position = buffer.PositionOf((byte)'\n')) != null)
{
continue;
ReadOnlySequence<byte> line = buffer.Slice(0, position.Value);

// Trim trailing \r for Windows-style CRLF line endings.
if (EndsWithCarriageReturn(line))
{
line = line.Slice(0, line.Length - 1);
}

if (!line.IsEmpty)
{
if (Logger.IsEnabled(LogLevel.Trace))
{
LogTransportReceivedMessageSensitive(Name, GetString(line));
}

await ProcessLineAsync(line, cancellationToken).ConfigureAwait(false);
}

// Advance past the '\n'.
buffer = buffer.Slice(buffer.GetPosition(1, position.Value));
}

LogTransportReceivedMessageSensitive(Name, line);
_serverOutputPipe.AdvanceTo(buffer.Start, buffer.End);

await ProcessMessageAsync(line, cancellationToken).ConfigureAwait(false);
if (result.IsCompleted)
{
LogTransportEndOfStream(Name);
break;
}
}
}
catch (OperationCanceledException)
Expand All @@ -137,25 +153,38 @@ private async Task ReadMessagesAsync(CancellationToken cancellationToken)
}
}

private async Task ProcessMessageAsync(string line, CancellationToken cancellationToken)
private async Task ProcessLineAsync(ReadOnlySequence<byte> line, CancellationToken cancellationToken)
{
try
{
var message = (JsonRpcMessage?)JsonSerializer.Deserialize(line.AsSpan().Trim(), McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonRpcMessage)));
if (message != null)
JsonRpcMessage? message;
if (line.IsSingleSegment)
{
message = JsonSerializer.Deserialize(line.First.Span, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonRpcMessage))) as JsonRpcMessage;
}
else
{
var reader = new Utf8JsonReader(line, isFinalBlock: true, state: default);
message = JsonSerializer.Deserialize(ref reader, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonRpcMessage))) as JsonRpcMessage;
}

if (message is not null)
{
await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false);
}
else
{
LogTransportMessageParseUnexpectedTypeSensitive(Name, line);
if (Logger.IsEnabled(LogLevel.Trace))
{
LogTransportMessageParseUnexpectedTypeSensitive(Name, GetString(line));
}
}
}
catch (JsonException ex)
{
if (Logger.IsEnabled(LogLevel.Trace))
{
LogTransportMessageParseFailedSensitive(Name, line, ex);
LogTransportMessageParseFailedSensitive(Name, GetString(line), ex);
}
else
{
Expand All @@ -164,6 +193,32 @@ private async Task ProcessMessageAsync(string line, CancellationToken cancellati
}
}

private static string GetString(in ReadOnlySequence<byte> sequence) =>
sequence.IsSingleSegment
? Encoding.UTF8.GetString(sequence.First.Span)
: Encoding.UTF8.GetString(sequence.ToArray());

private static bool EndsWithCarriageReturn(in ReadOnlySequence<byte> sequence)
{
if (sequence.IsSingleSegment)
{
ReadOnlySpan<byte> span = sequence.First.Span;
return span.Length > 0 && span[span.Length - 1] == (byte)'\r';
}

// Multi-segment: find the last non-empty segment to check its last byte.
ReadOnlyMemory<byte> last = default;
foreach (ReadOnlyMemory<byte> segment in sequence)
{
if (!segment.IsEmpty)
{
last = segment;
}
}

return !last.IsEmpty && last.Span[last.Length - 1] == (byte)'\r';
}

protected virtual async ValueTask CleanupAsync(Exception? error = null, CancellationToken cancellationToken = default)
{
LogTransportShuttingDown(Name);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ public Task<ITransport> ConnectAsync(CancellationToken cancellationToken = defau
return Task.FromResult<ITransport>(new StreamClientSessionTransport(
_serverInput,
_serverOutput,
encoding: null,
"Client (stream)",
_loggerFactory));
}
Expand Down
Loading