Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 71 additions & 101 deletions mcp/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -435,16 +435,14 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
// Reset the body so that it can be read later.
req.Body = io.NopCloser(bytes.NewBuffer(body))

msgs, _, err := readBatch(body)
msg, err := jsonrpc2.DecodeMessage(body)
if err == nil {
for _, msg := range msgs {
if req, ok := msg.(*jsonrpc.Request); ok {
switch req.Method {
case methodInitialize:
hasInitialize = true
case notificationInitialized:
hasInitialized = true
}
if req, ok := msg.(*jsonrpc.Request); ok {
switch req.Method {
case methodInitialize:
hasInitialize = true
case notificationInitialized:
hasInitialized = true
}
}
}
Expand Down Expand Up @@ -726,8 +724,6 @@ type stream struct {
// collected here until the stream is complete, at which point they are
// flushed as a single JSON response. Note that the non-nilness of this field
// is significant, as it signals the expected content type.
//
// Note: if we remove support for batching, this could just be a bool.
pendingJSONMessages []json.RawMessage

// w is the HTTP response writer for this stream. A non-nil w indicates
Expand All @@ -752,9 +748,6 @@ type stream struct {
// requests is the set of unanswered incoming requests for the stream.
//
// Requests are removed when their response has been received.
// In practice, there is only one request, but in the 2025-03-26 version of
// the spec and earlier there was a concept of batching, in which POST
// payloads could hold multiple requests or responses.
requests map[jsonrpc.ID]struct{}
}

Expand Down Expand Up @@ -1113,10 +1106,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques
http.Error(w, "POST requires a non-empty body", http.StatusBadRequest)
return
}
// TODO(#674): once we've documented the support matrix for 2025-03-26 and
// earlier, drop support for matching entirely; that will simplify this
// logic.
incoming, isBatch, err := readBatch(body)
incoming, err := jsonrpc2.DecodeMessage(body)
if err != nil {
http.Error(w, fmt.Sprintf("malformed payload: %v", err), http.StatusBadRequest)
return
Expand All @@ -1127,81 +1117,65 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques
protocolVersion = protocolVersion20250326
}

if isBatch && protocolVersion >= protocolVersion20250618 {
http.Error(w, fmt.Sprintf("JSON-RPC batching is not supported in %s and later (request version: %s)", protocolVersion20250618, protocolVersion), http.StatusBadRequest)
return
}

// TODO(rfindley): no tests fail if we reject batch JSON requests entirely.
// We need to test this with older protocol versions.
// if isBatch && c.jsonResponse {
// http.Error(w, "server does not support batch requests", http.StatusBadRequest)
// return
// }

calls := make(map[jsonrpc.ID]struct{})
tokenInfo := auth.TokenInfoFromContext(req.Context())
isInitialize := false
var initializeProtocolVersion string
for _, msg := range incoming {
if jreq, ok := msg.(*jsonrpc.Request); ok {
// Preemptively check that this is a valid request, so that we can fail
// the HTTP request. If we didn't do this, a request with a bad method or
// missing ID could be silently swallowed.
if _, err := checkRequest(jreq, serverMethodInfos); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
if jreq, ok := incoming.(*jsonrpc.Request); ok {
// Preemptively check that this is a valid request, so that we can fail
// the HTTP request. If we didn't do this, a request with a bad method or
// missing ID could be silently swallowed.
if _, err := checkRequest(jreq, serverMethodInfos); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if jreq.Method == methodInitialize {
isInitialize = true
// Extract the protocol version from InitializeParams.
var params InitializeParams
if err := internaljson.Unmarshal(jreq.Params, &params); err == nil {
initializeProtocolVersion = params.ProtocolVersion
}
if jreq.Method == methodInitialize {
isInitialize = true
// Extract the protocol version from InitializeParams.
var params InitializeParams
if err := internaljson.Unmarshal(jreq.Params, &params); err == nil {
initializeProtocolVersion = params.ProtocolVersion
}
// Include metadata for all requests (including notifications).
jreq.Extra = &RequestExtra{
TokenInfo: tokenInfo,
Header: req.Header,
}
if jreq.IsCall() {
calls[jreq.ID] = struct{}{}
// See the doc for CloseSSEStream: allow the request handler to
// explicitly close the ongoing stream.
jreq.Extra.(*RequestExtra).CloseSSEStream = func(args CloseSSEStreamArgs) {
c.mu.Lock()
streamID, ok := c.requestStreams[jreq.ID]
var stream *stream
if ok {
stream = c.streams[streamID]
}
}
// Include metadata for all requests (including notifications).
jreq.Extra = &RequestExtra{
TokenInfo: tokenInfo,
Header: req.Header,
}
if jreq.IsCall() {
calls[jreq.ID] = struct{}{}
// See the doc for CloseSSEStream: allow the request handler to
// explicitly close the ongoing stream.
jreq.Extra.(*RequestExtra).CloseSSEStream = func(args CloseSSEStreamArgs) {
c.mu.Lock()
streamID, ok := c.requestStreams[jreq.ID]
var stream *stream
if ok {
stream = c.streams[streamID]
}
c.mu.Unlock()
c.mu.Unlock()

if stream != nil {
stream.close(args.RetryAfter)
}
if stream != nil {
stream.close(args.RetryAfter)
}
}
}
}

// Validate MCP standard headers (Mcp-Method, Mcp-Name)
if !isBatch && len(incoming) == 1 {
if err := validateMcpHeaders(req.Header, incoming[0]); err != nil {
resp := &jsonrpc.Response{
Error: jsonrpc2.NewError(CodeHeaderMismatch, err.Error()),
}
if jreq, ok := incoming[0].(*jsonrpc.Request); ok {
resp.ID = jreq.ID
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
if data, err := jsonrpc2.EncodeMessage(resp); err == nil {
w.Write(data)
}
return
if err := validateMcpHeaders(req.Header, incoming); err != nil {
resp := &jsonrpc.Response{
Error: jsonrpc2.NewError(CodeHeaderMismatch, err.Error()),
}
if jreq, ok := incoming.(*jsonrpc.Request); ok {
resp.ID = jreq.ID
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
if data, err := jsonrpc2.EncodeMessage(resp); err == nil {
w.Write(data)
}
return
}

// The prime and close events were added in protocol version 2025-11-25 (SEP-1699).
Expand All @@ -1220,15 +1194,13 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques
//
// [§2.1.4]: https://modelcontextprotocol.io/specification/2025-11-25/basic/transports#sending-messages-to-the-server
if len(calls) == 0 {
for _, msg := range incoming {
select {
case c.incoming <- msg:
case <-c.done:
// The session is closing. Since we haven't yet written any data to the
// response, we can signal to the client that the session is gone.
http.Error(w, "session is closing", http.StatusNotFound)
return
}
select {
case c.incoming <- incoming:
case <-c.done:
// The session is closing. Since we haven't yet written any data to the
// response, we can signal to the client that the session is gone.
http.Error(w, "session is closing", http.StatusNotFound)
return
}
w.WriteHeader(http.StatusAccepted)
return
Expand Down Expand Up @@ -1303,19 +1275,17 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques
c.mu.Unlock()

// Publish incoming messages.
for _, msg := range incoming {
select {
case c.incoming <- msg:
// Note: don't select on req.Context().Done() here, since we've already
// received the requests and may have already published a response message
// or notification. The client could resume the stream.
//
// In fact, this send could be in a separate goroutine.
case <-c.done:
// Session closed: we don't know if any data has been written, so it's
// too late to write a status code here.
return
}
select {
case c.incoming <- incoming:
// Note: don't select on req.Context().Done() here, since we've already
// received the requests and may have already published a response message
// or notification. The client could resume the stream.
//
// In fact, this send could be in a separate goroutine.
case <-c.done:
// Session closed: we don't know if any data has been written, so it's
// too late to write a status code here.
return
}

c.hangResponse(req.Context(), done)
Expand Down
43 changes: 0 additions & 43 deletions mcp/streamable_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -966,49 +966,6 @@ func TestStreamableServerTransport(t *testing.T) {
},
wantSessions: 1,
},
{
name: "batch rejected on 2025-06-18",
requests: []streamableRequest{
initialize,
initialized,
{
method: "POST",
// Explicitly set the protocol version header
headers: http.Header{"MCP-Protocol-Version": {"2025-06-18"}},
// Two messages => batch. Expect reject.
messages: []jsonrpc.Message{
req(101, "tools/call", &CallToolParams{Name: "tool"}),
req(102, "tools/call", &CallToolParams{Name: "tool"}),
},
wantStatusCode: http.StatusBadRequest,
wantBodyContaining: "batch",
},
},
wantSessions: 1,
},
{
name: "batch accepted on 2025-03-26",
requests: []streamableRequest{
initialize,
initialized,
{
method: "POST",
headers: http.Header{"MCP-Protocol-Version": {"2025-03-26"}},
// Two messages => batch. Expect OK with two responses in order.
messages: []jsonrpc.Message{
// Note: only include one request here, because responses are not
// necessarily sorted.
req(201, "tools/call", &CallToolParams{Name: "tool"}),
req(0, "notifications/roots/list_changed", &RootsListChangedParams{}),
},
wantStatusCode: http.StatusOK,
wantMessages: []jsonrpc.Message{
resp(201, &CallToolResult{Content: []Content{}}, nil),
},
},
},
wantSessions: 1,
},
{
name: "tool notification",
tool: func(t *testing.T, ctx context.Context, req *CallToolRequest) {
Expand Down
Loading