Skip to content

Commit 590c6c3

Browse files
committed
fix: accept header accessor function in ServerTransportSecurityValidator
Replace Map<String, List<String>> parameter with Function<String, List<String>> in validateHeaders(), allowing callers to pass a header accessor instead of extracting all headers upfront. This is more efficient (only requested headers are looked up) and delegates case-insensitive header matching to the underlying request implementation (e.g. HttpServletRequest.getHeaders). - Update DefaultServerTransportSecurityValidator to use the accessor directly for Origin and Host headers - Update all three servlet transport providers to pass name -> Collections.list(request.getHeaders(name)) - Remove HttpServletRequestUtils (no longer needed) - Update unit tests to use accessor-based API Closes #870
1 parent accba74 commit 590c6c3

7 files changed

Lines changed: 121 additions & 161 deletions

mcp-core/src/main/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidator.java

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import java.util.ArrayList;
88
import java.util.List;
9-
import java.util.Map;
9+
import java.util.function.Function;
1010

1111
import io.modelcontextprotocol.util.Assert;
1212

@@ -47,27 +47,18 @@ private DefaultServerTransportSecurityValidator(List<String> allowedOrigins, Lis
4747
}
4848

4949
@Override
50-
public void validateHeaders(Map<String, List<String>> headers) throws ServerTransportSecurityException {
51-
boolean missingHost = true;
52-
for (Map.Entry<String, List<String>> entry : headers.entrySet()) {
53-
if (ORIGIN_HEADER.equalsIgnoreCase(entry.getKey())) {
54-
List<String> values = entry.getValue();
55-
if (values == null || values.isEmpty()) {
56-
throw new ServerTransportSecurityException(403, "Invalid Origin header");
57-
}
58-
validateOrigin(values.get(0));
59-
}
60-
else if (HOST_HEADER.equalsIgnoreCase(entry.getKey())) {
61-
missingHost = false;
62-
List<String> values = entry.getValue();
63-
if (values == null || values.isEmpty()) {
64-
throw new ServerTransportSecurityException(421, "Invalid Host header");
65-
}
66-
validateHost(values.get(0));
67-
}
50+
public void validateHeaders(Function<String, List<String>> headerAccessor) throws ServerTransportSecurityException {
51+
List<String> originValues = headerAccessor.apply(ORIGIN_HEADER);
52+
if (originValues != null && !originValues.isEmpty()) {
53+
validateOrigin(originValues.get(0));
6854
}
69-
if (!allowedHosts.isEmpty() && missingHost) {
70-
throw new ServerTransportSecurityException(421, "Invalid Host header");
55+
56+
if (!allowedHosts.isEmpty()) {
57+
List<String> hostValues = headerAccessor.apply(HOST_HEADER);
58+
if (hostValues == null || hostValues.isEmpty()) {
59+
throw new ServerTransportSecurityException(421, "Invalid Host header");
60+
}
61+
validateHost(hostValues.get(0));
7162
}
7263
}
7364

mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletRequestUtils.java

Lines changed: 0 additions & 40 deletions
This file was deleted.

mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import java.io.IOException;
99
import java.io.PrintWriter;
1010
import java.time.Duration;
11+
import java.util.Collections;
1112
import java.util.List;
1213
import java.util.Map;
1314
import java.util.UUID;
@@ -280,8 +281,7 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response)
280281
}
281282

282283
try {
283-
Map<String, List<String>> headers = HttpServletRequestUtils.extractHeaders(request);
284-
this.securityValidator.validateHeaders(headers);
284+
this.securityValidator.validateHeaders(name -> Collections.list(request.getHeaders(name)));
285285
}
286286
catch (ServerTransportSecurityException e) {
287287
response.sendError(e.getStatusCode(), e.getMessage());
@@ -353,8 +353,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
353353
}
354354

355355
try {
356-
Map<String, List<String>> headers = HttpServletRequestUtils.extractHeaders(request);
357-
this.securityValidator.validateHeaders(headers);
356+
this.securityValidator.validateHeaders(name -> Collections.list(request.getHeaders(name)));
358357
}
359358
catch (ServerTransportSecurityException e) {
360359
response.sendError(e.getStatusCode(), e.getMessage());

mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
import java.io.BufferedReader;
88
import java.io.IOException;
99
import java.io.PrintWriter;
10+
import java.util.Collections;
1011
import java.util.List;
11-
import java.util.Map;
1212

1313
import org.slf4j.Logger;
1414
import org.slf4j.LoggerFactory;
@@ -134,8 +134,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
134134
}
135135

136136
try {
137-
Map<String, List<String>> headers = HttpServletRequestUtils.extractHeaders(request);
138-
this.securityValidator.validateHeaders(headers);
137+
this.securityValidator.validateHeaders(name -> Collections.list(request.getHeaders(name)));
139138
}
140139
catch (ServerTransportSecurityException e) {
141140
response.sendError(e.getStatusCode(), e.getMessage());

mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import java.io.PrintWriter;
1010
import java.time.Duration;
1111
import java.util.ArrayList;
12+
import java.util.Collections;
1213
import java.util.List;
1314
import java.util.Map;
1415
import java.util.concurrent.ConcurrentHashMap;
@@ -271,8 +272,7 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response)
271272
}
272273

273274
try {
274-
Map<String, List<String>> headers = HttpServletRequestUtils.extractHeaders(request);
275-
this.securityValidator.validateHeaders(headers);
275+
this.securityValidator.validateHeaders(name -> Collections.list(request.getHeaders(name)));
276276
}
277277
catch (ServerTransportSecurityException e) {
278278
response.sendError(e.getStatusCode(), e.getMessage());
@@ -407,8 +407,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
407407
}
408408

409409
try {
410-
Map<String, List<String>> headers = HttpServletRequestUtils.extractHeaders(request);
411-
this.securityValidator.validateHeaders(headers);
410+
this.securityValidator.validateHeaders(name -> Collections.list(request.getHeaders(name)));
412411
}
413412
catch (ServerTransportSecurityException e) {
414413
response.sendError(e.getStatusCode(), e.getMessage());
@@ -588,8 +587,7 @@ protected void doDelete(HttpServletRequest request, HttpServletResponse response
588587
}
589588

590589
try {
591-
Map<String, List<String>> headers = HttpServletRequestUtils.extractHeaders(request);
592-
this.securityValidator.validateHeaders(headers);
590+
this.securityValidator.validateHeaders(name -> Collections.list(request.getHeaders(name)));
593591
}
594592
catch (ServerTransportSecurityException e) {
595593
response.sendError(e.getStatusCode(), e.getMessage());

mcp-core/src/main/java/io/modelcontextprotocol/server/transport/ServerTransportSecurityValidator.java

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
package io.modelcontextprotocol.server.transport;
66

77
import java.util.List;
8-
import java.util.Map;
8+
import java.util.function.Function;
99

1010
/**
1111
* Interface for validating HTTP requests in server transports. Implementations can
@@ -22,15 +22,16 @@ public interface ServerTransportSecurityValidator {
2222
/**
2323
* A no-op validator that accepts all requests without validation.
2424
*/
25-
ServerTransportSecurityValidator NOOP = headers -> {
25+
ServerTransportSecurityValidator NOOP = headerAccessor -> {
2626
};
2727

2828
/**
2929
* Validates the HTTP headers from an incoming request.
30-
* @param headers A map of header names to their values (multi-valued headers
31-
* supported)
30+
* @param headerAccessor A function that returns the list of values for a given header
31+
* name, or an empty list if the header is not present. Header name lookup should be
32+
* case-insensitive.
3233
* @throws ServerTransportSecurityException if validation fails
3334
*/
34-
void validateHeaders(Map<String, List<String>> headers) throws ServerTransportSecurityException;
35+
void validateHeaders(Function<String, List<String>> headerAccessor) throws ServerTransportSecurityException;
3536

3637
}

0 commit comments

Comments
 (0)