Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 160 additions & 30 deletions crates/rust-mcp-transport/src/mcp_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ use tokio::{
};

const CHANNEL_CAPACITY: usize = 36;
// Maximum size (in bytes) of a single newline-delimited incoming message.
// A peer cannot force unbounded buffering: longer lines are dropped.
const MAX_LINE_LENGTH: usize = 4 * 1024 * 1024;

pub struct MCPStream {}

Expand Down Expand Up @@ -107,41 +110,47 @@ impl MCPStream {
X: Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
{
tokio::spawn(async move {
let mut lines_stream = BufReader::new(readable).lines();
let mut reader = BufReader::new(readable);

loop {
tokio::select! {
_ = cancellation_token.cancelled() =>
{
break;
_ = cancellation_token.cancelled() => {
break;
},

line = lines_stream.next_line() =>{
match line {
Ok(Some(line)) => {
tracing::trace!("raw payload: {}",line);

// deserialize and send it to the stream
let message: X = match serde_json::from_str(&line){
Ok(mcp_message) => mcp_message,
Err(_) => {
// continue if malformed message is received
continue;
},
};

tx.send(message).await.map_err(GenericSendError::new)?;
}
Ok(None) => {
// EOF reached, exit loop
break;
}
Err(e) => {
// Handle error in reading from readable_std
return Err(TransportError::ProcessError(format!(
"Error reading from readable_std: {e}"
)));
}
result = read_capped_line(&mut reader, MAX_LINE_LENGTH) => {
match result {
Ok(LineRead::Eof) => {
// EOF reached, exit loop
break;
}
Ok(LineRead::TooLong) => {
// Drop the oversized message and keep the stream alive.
tracing::warn!(
"dropping incoming message exceeding {MAX_LINE_LENGTH} bytes"
);
continue;
}
Ok(LineRead::Line(line)) => {
tracing::trace!("raw payload: {}", line);

// deserialize and send it to the stream
let message: X = match serde_json::from_str(&line) {
Ok(mcp_message) => mcp_message,
Err(_) => {
// continue if malformed message is received
continue;
}
};

tx.send(message).await.map_err(GenericSendError::new)?;
}
Err(e) => {
// Handle error in reading from readable_std
return Err(TransportError::ProcessError(format!(
"Error reading from readable_std: {e}"
)));
}
}
}
}
Expand All @@ -150,3 +159,124 @@ impl MCPStream {
})
}
}

/// Outcome of reading a single newline-delimited line with a size cap.
enum LineRead {
/// A complete line (newline stripped) within the size cap.
Line(String),
/// The line exceeded the cap and was discarded up to the next newline.
TooLong,
/// The underlying reader reached end-of-file.
Eof,
}

/// Reads a single newline-delimited line, buffering at most `max` bytes.
///
/// Unlike `AsyncBufReadExt::lines`, a peer cannot force unbounded buffering: a
/// line longer than `max` is discarded (consumed up to the next newline) and
/// reported as [`LineRead::TooLong`] so the caller can drop it and continue.
async fn read_capped_line<R>(reader: &mut R, max: usize) -> std::io::Result<LineRead>
where
R: tokio::io::AsyncBufRead + Unpin,
{
let mut buf: Vec<u8> = Vec::new();

loop {
let chunk = reader.fill_buf().await?;

if chunk.is_empty() {
if buf.is_empty() {
return Ok(LineRead::Eof);
}
// EOF without a trailing newline: emit the final partial line.
return Ok(LineRead::Line(line_to_string(buf)));
}

if let Some(pos) = chunk.iter().position(|&b| b == b'\n') {
let consumed = pos + 1;
if buf.len() + pos > max {
reader.consume(consumed);
return Ok(LineRead::TooLong);
}
buf.extend_from_slice(&chunk[..pos]);
reader.consume(consumed);
return Ok(LineRead::Line(line_to_string(buf)));
}

let len = chunk.len();
if buf.len() + len > max {
reader.consume(len);
discard_to_newline(&mut *reader).await?;
return Ok(LineRead::TooLong);
}
buf.extend_from_slice(chunk);
reader.consume(len);
}
}

/// Consumes bytes until (and including) the next newline, or EOF.
async fn discard_to_newline<R>(reader: &mut R) -> std::io::Result<()>
where
R: tokio::io::AsyncBufRead + Unpin,
{
loop {
let chunk = reader.fill_buf().await?;
if chunk.is_empty() {
return Ok(());
}
if let Some(pos) = chunk.iter().position(|&b| b == b'\n') {
reader.consume(pos + 1);
return Ok(());
}
let len = chunk.len();
reader.consume(len);
}
}

/// Converts line bytes to a `String`, stripping a trailing carriage return.
fn line_to_string(mut buf: Vec<u8>) -> String {
if buf.last() == Some(&b'\r') {
buf.pop();
}
String::from_utf8_lossy(&buf).into_owned()
}

#[cfg(test)]
mod tests {
use super::*;
use tokio::io::BufReader;

async fn collect_lines(data: &[u8], max: usize) -> Vec<Result<String, &'static str>> {
let mut reader = BufReader::new(data);
let mut out = Vec::new();
loop {
match read_capped_line(&mut reader, max).await.unwrap() {
LineRead::Eof => break,
LineRead::TooLong => out.push(Err("too-long")),
LineRead::Line(line) => out.push(Ok(line)),
}
}
out
}

#[tokio::test]
async fn reads_newline_delimited_lines() {
let out = collect_lines(b"hello\r\nworld\n", 1024).await;
assert_eq!(out, vec![Ok("hello".to_string()), Ok("world".to_string())]);
}

#[tokio::test]
async fn emits_final_line_without_trailing_newline() {
let out = collect_lines(b"tail", 1024).await;
assert_eq!(out, vec![Ok("tail".to_string())]);
}

#[tokio::test]
async fn drops_oversized_line_and_resyncs() {
let mut data = vec![b'a'; 100];
data.push(b'\n');
data.extend_from_slice(b"ok\n");
let out = collect_lines(&data, 10).await;
assert_eq!(out, vec![Err("too-long"), Ok("ok".to_string())]);
}
}
Loading