Skip to content

Commit 83c796c

Browse files
committed
refactor: add deprecation path for validateHeaders API
Instead of a breaking change, introduce a deprecation path: - Keep validateHeaders(Map) as deprecated default method that bridges to the new validateHeaders(Function) via case-insensitive lookup - New implementations override validateHeaders(Function) for efficiency - Transport providers continue calling validateHeaders(Map) for backward compatibility with existing custom implementations - Restore HttpServletRequestUtils for header extraction in transports - Add tests for deprecated Map-based API and interface bridge behavior
1 parent 590c6c3 commit 83c796c

6 files changed

Lines changed: 204 additions & 14 deletions

File tree

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*
2+
* Copyright 2026-2026 the original author or authors.
3+
*/
4+
5+
package io.modelcontextprotocol.server.transport;
6+
7+
import java.util.Collections;
8+
import java.util.Enumeration;
9+
import java.util.HashMap;
10+
import java.util.List;
11+
import java.util.Map;
12+
13+
import jakarta.servlet.http.HttpServletRequest;
14+
15+
/**
16+
* Utility methods for working with {@link HttpServletRequest}. For internal use only.
17+
*
18+
* @author Daniel Garnier-Moiroux
19+
*/
20+
final class HttpServletRequestUtils {
21+
22+
private HttpServletRequestUtils() {
23+
}
24+
25+
/**
26+
* Extracts all headers from the HTTP request into a map.
27+
* @param request The HTTP servlet request
28+
* @return A map of header names to their values
29+
*/
30+
static Map<String, List<String>> extractHeaders(HttpServletRequest request) {
31+
Map<String, List<String>> headers = new HashMap<>();
32+
Enumeration<String> names = request.getHeaderNames();
33+
while (names.hasMoreElements()) {
34+
String name = names.nextElement();
35+
headers.put(name, Collections.list(request.getHeaders(name)));
36+
}
37+
return headers;
38+
}
39+
40+
}

mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import java.io.IOException;
99
import java.io.PrintWriter;
1010
import java.time.Duration;
11-
import java.util.Collections;
1211
import java.util.List;
1312
import java.util.Map;
1413
import java.util.UUID;
@@ -281,7 +280,8 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response)
281280
}
282281

283282
try {
284-
this.securityValidator.validateHeaders(name -> Collections.list(request.getHeaders(name)));
283+
Map<String, List<String>> headers = HttpServletRequestUtils.extractHeaders(request);
284+
this.securityValidator.validateHeaders(headers);
285285
}
286286
catch (ServerTransportSecurityException e) {
287287
response.sendError(e.getStatusCode(), e.getMessage());
@@ -353,7 +353,8 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
353353
}
354354

355355
try {
356-
this.securityValidator.validateHeaders(name -> Collections.list(request.getHeaders(name)));
356+
Map<String, List<String>> headers = HttpServletRequestUtils.extractHeaders(request);
357+
this.securityValidator.validateHeaders(headers);
357358
}
358359
catch (ServerTransportSecurityException e) {
359360
response.sendError(e.getStatusCode(), e.getMessage());

mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
import java.io.BufferedReader;
88
import java.io.IOException;
99
import java.io.PrintWriter;
10-
import java.util.Collections;
1110
import java.util.List;
11+
import java.util.Map;
1212

1313
import org.slf4j.Logger;
1414
import org.slf4j.LoggerFactory;
@@ -134,7 +134,8 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
134134
}
135135

136136
try {
137-
this.securityValidator.validateHeaders(name -> Collections.list(request.getHeaders(name)));
137+
Map<String, List<String>> headers = HttpServletRequestUtils.extractHeaders(request);
138+
this.securityValidator.validateHeaders(headers);
138139
}
139140
catch (ServerTransportSecurityException e) {
140141
response.sendError(e.getStatusCode(), e.getMessage());

mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import java.io.PrintWriter;
1010
import java.time.Duration;
1111
import java.util.ArrayList;
12-
import java.util.Collections;
1312
import java.util.List;
1413
import java.util.Map;
1514
import java.util.concurrent.ConcurrentHashMap;
@@ -272,7 +271,8 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response)
272271
}
273272

274273
try {
275-
this.securityValidator.validateHeaders(name -> Collections.list(request.getHeaders(name)));
274+
Map<String, List<String>> headers = HttpServletRequestUtils.extractHeaders(request);
275+
this.securityValidator.validateHeaders(headers);
276276
}
277277
catch (ServerTransportSecurityException e) {
278278
response.sendError(e.getStatusCode(), e.getMessage());
@@ -407,7 +407,8 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
407407
}
408408

409409
try {
410-
this.securityValidator.validateHeaders(name -> Collections.list(request.getHeaders(name)));
410+
Map<String, List<String>> headers = HttpServletRequestUtils.extractHeaders(request);
411+
this.securityValidator.validateHeaders(headers);
411412
}
412413
catch (ServerTransportSecurityException e) {
413414
response.sendError(e.getStatusCode(), e.getMessage());
@@ -587,7 +588,8 @@ protected void doDelete(HttpServletRequest request, HttpServletResponse response
587588
}
588589

589590
try {
590-
this.securityValidator.validateHeaders(name -> Collections.list(request.getHeaders(name)));
591+
Map<String, List<String>> headers = HttpServletRequestUtils.extractHeaders(request);
592+
this.securityValidator.validateHeaders(headers);
591593
}
592594
catch (ServerTransportSecurityException e) {
593595
response.sendError(e.getStatusCode(), e.getMessage());

mcp-core/src/main/java/io/modelcontextprotocol/server/transport/ServerTransportSecurityValidator.java

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,33 +5,69 @@
55
package io.modelcontextprotocol.server.transport;
66

77
import java.util.List;
8+
import java.util.Map;
89
import java.util.function.Function;
910

1011
/**
1112
* Interface for validating HTTP requests in server transports. Implementations can
1213
* validate Origin headers, Host headers, or any other security-related headers according
1314
* to the MCP specification.
1415
*
16+
* <p>
17+
* New implementations should override {@link #validateHeaders(Function)
18+
* validateHeaders(Function)} for more efficient, case-insensitive header access. The
19+
* older {@link #validateHeaders(Map) validateHeaders(Map)} is deprecated and will be
20+
* removed in a future major version.
21+
*
1522
* @author Daniel Garnier-Moiroux
1623
* @see DefaultServerTransportSecurityValidator
1724
* @see ServerTransportSecurityException
1825
*/
19-
@FunctionalInterface
2026
public interface ServerTransportSecurityValidator {
2127

2228
/**
2329
* A no-op validator that accepts all requests without validation.
2430
*/
25-
ServerTransportSecurityValidator NOOP = headerAccessor -> {
31+
ServerTransportSecurityValidator NOOP = new ServerTransportSecurityValidator() {
2632
};
2733

2834
/**
2935
* Validates the HTTP headers from an incoming request.
36+
*
37+
* <p>
38+
* The default implementation converts the map into a case-insensitive header accessor
39+
* and delegates to {@link #validateHeaders(Function)}.
40+
* @param headers A map of header names to their values (multi-valued headers
41+
* supported)
42+
* @throws ServerTransportSecurityException if validation fails
43+
* @deprecated Use {@link #validateHeaders(Function)} instead for more efficient,
44+
* case-insensitive header access. This method will be removed in a future major
45+
* version.
46+
*/
47+
@Deprecated
48+
default void validateHeaders(Map<String, List<String>> headers) throws ServerTransportSecurityException {
49+
validateHeaders(name -> headers.entrySet()
50+
.stream()
51+
.filter(e -> e.getKey().equalsIgnoreCase(name))
52+
.map(Map.Entry::getValue)
53+
.findFirst()
54+
.orElse(List.of()));
55+
}
56+
57+
/**
58+
* Validates the HTTP headers from an incoming request using a header accessor
59+
* function.
60+
*
61+
* <p>
62+
* New implementations should override this method. Header name lookup through the
63+
* accessor should be case-insensitive (e.g., when backed by
64+
* {@code HttpServletRequest.getHeaders}).
3065
* @param headerAccessor A function that returns the list of values for a given header
31-
* name, or an empty list if the header is not present. Header name lookup should be
32-
* case-insensitive.
66+
* name, or an empty list if the header is not present.
3367
* @throws ServerTransportSecurityException if validation fails
3468
*/
35-
void validateHeaders(Function<String, List<String>> headerAccessor) throws ServerTransportSecurityException;
69+
default void validateHeaders(Function<String, List<String>> headerAccessor)
70+
throws ServerTransportSecurityException {
71+
}
3672

3773
}

mcp-core/src/test/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidatorTests.java

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,116 @@ void originValidHostMissing() {
404404

405405
}
406406

407+
@Nested
408+
class DeprecatedMapBasedApi {
409+
410+
@Test
411+
void originValidation() {
412+
Map<String, List<String>> headers = new HashMap<>();
413+
headers.put("Origin", List.of("http://localhost:8080"));
414+
415+
assertThatCode(() -> validator.validateHeaders(headers)).doesNotThrowAnyException();
416+
}
417+
418+
@Test
419+
void originRejected() {
420+
Map<String, List<String>> headers = new HashMap<>();
421+
headers.put("Origin", List.of("http://malicious.example.com"));
422+
423+
assertThatThrownBy(() -> validator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN);
424+
}
425+
426+
@Test
427+
void caseInsensitiveHeaderLookup() {
428+
Map<String, List<String>> headers = new HashMap<>();
429+
headers.put("origin", List.of("http://localhost:8080"));
430+
431+
assertThatCode(() -> validator.validateHeaders(headers)).doesNotThrowAnyException();
432+
}
433+
434+
@Test
435+
void hostValidation() {
436+
DefaultServerTransportSecurityValidator hostValidator = DefaultServerTransportSecurityValidator.builder()
437+
.allowedHost("localhost:8080")
438+
.build();
439+
440+
Map<String, List<String>> headers = new HashMap<>();
441+
headers.put("Host", List.of("localhost:8080"));
442+
443+
assertThatCode(() -> hostValidator.validateHeaders(headers)).doesNotThrowAnyException();
444+
}
445+
446+
@Test
447+
void hostRejected() {
448+
DefaultServerTransportSecurityValidator hostValidator = DefaultServerTransportSecurityValidator.builder()
449+
.allowedHost("localhost:8080")
450+
.build();
451+
452+
Map<String, List<String>> headers = new HashMap<>();
453+
headers.put("Host", List.of("malicious.com:8080"));
454+
455+
assertThatThrownBy(() -> hostValidator.validateHeaders(headers)).isEqualTo(INVALID_HOST);
456+
}
457+
458+
@Test
459+
void emptyHeaders() {
460+
assertThatCode(() -> validator.validateHeaders(new HashMap<>())).doesNotThrowAnyException();
461+
}
462+
463+
@Test
464+
void combinedOriginAndHost() {
465+
DefaultServerTransportSecurityValidator combinedValidator = DefaultServerTransportSecurityValidator
466+
.builder()
467+
.allowedOrigin("http://localhost:*")
468+
.allowedHost("localhost:*")
469+
.build();
470+
471+
Map<String, List<String>> headers = new HashMap<>();
472+
headers.put("Origin", List.of("http://localhost:8080"));
473+
headers.put("Host", List.of("localhost:8080"));
474+
475+
assertThatCode(() -> combinedValidator.validateHeaders(headers)).doesNotThrowAnyException();
476+
}
477+
478+
}
479+
480+
@Nested
481+
class InterfaceDefaultBridge {
482+
483+
@Test
484+
void noopAcceptsAll() {
485+
assertThatCode(() -> ServerTransportSecurityValidator.NOOP.validateHeaders(emptyAccessor()))
486+
.doesNotThrowAnyException();
487+
assertThatCode(() -> ServerTransportSecurityValidator.NOOP.validateHeaders(new HashMap<>()))
488+
.doesNotThrowAnyException();
489+
}
490+
491+
@Test
492+
void mapDefaultBridgesToFunctionOverride() {
493+
// A validator that only overrides the Function method should still work
494+
// when called via the deprecated Map method
495+
ServerTransportSecurityValidator functionOnlyValidator = new ServerTransportSecurityValidator() {
496+
@Override
497+
public void validateHeaders(Function<String, List<String>> headerAccessor)
498+
throws ServerTransportSecurityException {
499+
List<String> origins = headerAccessor.apply("Origin");
500+
if (origins != null && !origins.isEmpty() && origins.get(0).contains("evil")) {
501+
throw new ServerTransportSecurityException(403, "Invalid Origin header");
502+
}
503+
}
504+
};
505+
506+
Map<String, List<String>> goodHeaders = new HashMap<>();
507+
goodHeaders.put("Origin", List.of("http://good.example.com"));
508+
assertThatCode(() -> functionOnlyValidator.validateHeaders(goodHeaders)).doesNotThrowAnyException();
509+
510+
Map<String, List<String>> evilHeaders = new HashMap<>();
511+
evilHeaders.put("Origin", List.of("http://evil.example.com"));
512+
assertThatThrownBy(() -> functionOnlyValidator.validateHeaders(evilHeaders)).isEqualTo(INVALID_ORIGIN);
513+
}
514+
515+
}
516+
407517
private static Function<String, List<String>> emptyAccessor() {
408518
return name -> List.of();
409519
}

0 commit comments

Comments
 (0)