Skip to content

Commit 34b2ab7

Browse files
committed
fix: retain streamable HTTP session on response write failure
1 parent c09ee67 commit 34b2ab7

2 files changed

Lines changed: 132 additions & 1 deletion

File tree

mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -755,7 +755,6 @@ public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message, String messageId
755755
}
756756
catch (Exception e) {
757757
logger.error("Failed to send message to session {}: {}", this.sessionId, e.getMessage());
758-
HttpServletStreamableServerTransportProvider.this.sessions.remove(this.sessionId);
759758
this.asyncContext.complete();
760759
}
761760
finally {
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
/*
2+
* Copyright 2024-2026 the original author or authors.
3+
*/
4+
5+
package io.modelcontextprotocol.server;
6+
7+
import java.io.IOException;
8+
import java.io.PrintWriter;
9+
import java.io.StringWriter;
10+
import java.nio.charset.StandardCharsets;
11+
import java.util.List;
12+
import java.util.Map;
13+
14+
import io.modelcontextprotocol.server.transport.HttpServletStreamableServerTransportProvider;
15+
import io.modelcontextprotocol.spec.HttpHeaders;
16+
import io.modelcontextprotocol.spec.McpSchema;
17+
import io.modelcontextprotocol.spec.ProtocolVersions;
18+
import org.junit.jupiter.api.Test;
19+
import org.junit.jupiter.api.Timeout;
20+
21+
import org.springframework.mock.web.MockHttpServletRequest;
22+
import org.springframework.mock.web.MockHttpServletResponse;
23+
24+
import static io.modelcontextprotocol.server.transport.HttpServletStreamableServerTransportProvider.APPLICATION_JSON;
25+
import static io.modelcontextprotocol.server.transport.HttpServletStreamableServerTransportProvider.TEXT_EVENT_STREAM;
26+
import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER;
27+
import static org.assertj.core.api.Assertions.assertThat;
28+
29+
@Timeout(15)
30+
class HttpServletStreamableSessionFailureTests {
31+
32+
private static final String MCP_ENDPOINT = "/mcp";
33+
34+
@Test
35+
void postStreamWriteFailureShouldNotRemoveSession() throws Exception {
36+
HttpServletStreamableServerTransportProvider transport = HttpServletStreamableServerTransportProvider.builder()
37+
.mcpEndpoint(MCP_ENDPOINT)
38+
.build();
39+
40+
var tool = McpSchema.Tool.builder("test-tool").description("Test tool").build();
41+
var toolSpecification = McpServerFeatures.SyncToolSpecification.builder()
42+
.tool(tool)
43+
.callHandler((exchange, request) -> McpSchema.CallToolResult.builder()
44+
.content(List.of(McpSchema.TextContent.builder("tool response").build()))
45+
.isError(false)
46+
.build())
47+
.build();
48+
var server = McpServer.sync(transport)
49+
.serverInfo("test-server", "1.0.0")
50+
.capabilities(McpSchema.ServerCapabilities.builder().tools(false).build())
51+
.tools(toolSpecification)
52+
.build();
53+
54+
try {
55+
MockHttpServletResponse initializeResponse = new MockHttpServletResponse();
56+
transport.service(postRequest(initializeRequest(), null), initializeResponse);
57+
58+
String sessionId = initializeResponse.getHeader(HttpHeaders.MCP_SESSION_ID);
59+
assertThat(sessionId).isNotBlank();
60+
61+
CheckErrorResponse failedWriteResponse = new CheckErrorResponse();
62+
transport.service(postRequest(toolCallRequest("first-call"), sessionId), failedWriteResponse);
63+
64+
assertThat(failedWriteResponse.getWrittenContent()).contains("event: message");
65+
66+
MockHttpServletResponse subsequentResponse = new MockHttpServletResponse();
67+
transport.service(postRequest(toolCallRequest("second-call"), sessionId), subsequentResponse);
68+
69+
assertThat(subsequentResponse.getStatus()).isNotEqualTo(404);
70+
assertThat(subsequentResponse.getContentAsString()).doesNotContain("Session not found");
71+
}
72+
finally {
73+
server.close();
74+
transport.closeGracefully().block();
75+
}
76+
}
77+
78+
private static MockHttpServletRequest postRequest(McpSchema.JSONRPCMessage message, String sessionId)
79+
throws IOException {
80+
MockHttpServletRequest request = new MockHttpServletRequest("POST", MCP_ENDPOINT);
81+
byte[] content = JSON_MAPPER.writeValueAsBytes(message);
82+
request.setContent(content);
83+
request.setCharacterEncoding(StandardCharsets.UTF_8.name());
84+
request.addHeader("Accept", APPLICATION_JSON + ", " + TEXT_EVENT_STREAM);
85+
request.addHeader("Content-Type", APPLICATION_JSON);
86+
request.addHeader("Content-Length", Integer.toString(content.length));
87+
request.addHeader(HttpHeaders.PROTOCOL_VERSION, ProtocolVersions.MCP_2025_11_25);
88+
request.setAsyncSupported(true);
89+
if (sessionId != null) {
90+
request.addHeader(HttpHeaders.MCP_SESSION_ID, sessionId);
91+
}
92+
return request;
93+
}
94+
95+
private static McpSchema.JSONRPCRequest initializeRequest() {
96+
var clientInfo = McpSchema.Implementation.builder("test-client", "1.0.0").build();
97+
var initializeRequest = McpSchema.InitializeRequest
98+
.builder(ProtocolVersions.MCP_2025_11_25, McpSchema.ClientCapabilities.builder().build(), clientInfo)
99+
.build();
100+
return new McpSchema.JSONRPCRequest(McpSchema.METHOD_INITIALIZE, "init", initializeRequest);
101+
}
102+
103+
private static McpSchema.JSONRPCRequest toolCallRequest(String id) {
104+
var callToolRequest = McpSchema.CallToolRequest.builder("test-tool").arguments(Map.of()).build();
105+
return new McpSchema.JSONRPCRequest(McpSchema.METHOD_TOOLS_CALL, id, callToolRequest);
106+
}
107+
108+
private static final class CheckErrorResponse extends MockHttpServletResponse {
109+
110+
private final StringWriter content = new StringWriter();
111+
112+
private final PrintWriter writer = new PrintWriter(this.content) {
113+
@Override
114+
public boolean checkError() {
115+
super.checkError();
116+
return true;
117+
}
118+
};
119+
120+
@Override
121+
public PrintWriter getWriter() {
122+
return this.writer;
123+
}
124+
125+
String getWrittenContent() {
126+
this.writer.flush();
127+
return this.content.toString();
128+
}
129+
130+
}
131+
132+
}

0 commit comments

Comments
 (0)