Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import org.springframework.util.PropertyPlaceholderHelper;
import org.springframework.util.PropertyPlaceholderHelper.PlaceholderResolver;
import org.springframework.util.StringUtils;
import java.util.function.Predicate;

/**
* A {@link HandlerMethodReturnValueHandler} for sending to destinations specified in a
Expand Down Expand Up @@ -73,6 +74,8 @@ public class SendToMethodReturnValueHandler implements HandlerMethodReturnValueH

private @Nullable MessageHeaderInitializer headerInitializer;

private @Nullable Predicate<String> headerFilter;


public SendToMethodReturnValueHandler(SimpMessageSendingOperations messagingTemplate, boolean annotationRequired) {
Assert.notNull(messagingTemplate, "'messagingTemplate' must not be null");
Expand Down Expand Up @@ -133,6 +136,27 @@ public void setHeaderInitializer(@Nullable MessageHeaderInitializer headerInitia
return this.headerInitializer;
}

/**
* Add a filter to determine which headers from the input message should be propagated to the output message.
* Multiple filters are combined with logical OR.
* <p>If not set, no input headers are propagated (default behavior).</p>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of such a warning, which requires further investigation on its own, how about applying header propagation first, and then setting standard headers (like session id) second, so there is no possibility for overwriting.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense. I’ve removed the warning and updated the order so propagated headers are applied first, followed by standard headers, to avoid any possibility of overwriting.

*/
public void addHeaderFilter(Predicate<String> filter) {
Assert.notNull(filter, "Filter predicate must not be null");
if (this.headerFilter == null) {
this.headerFilter = filter;
} else {
this.headerFilter = this.headerFilter.or(filter);
}
}

/**
* Return the configured header filter.
*/
public @Nullable Predicate<String> getHeaderFilter() {
return this.headerFilter;
}


@Override
public boolean supportsReturnType(MethodParameter returnType) {
Expand Down Expand Up @@ -171,11 +195,11 @@ public void handleReturnValue(@Nullable Object returnValue, MethodParameter retu
destination = destinationHelper.expandTemplateVars(destination);
if (broadcast) {
this.messagingTemplate.convertAndSendToUser(
user, destination, returnValue, createHeaders(null, returnType));
user, destination, returnValue, createHeaders(null, returnType, message));
}
else {
this.messagingTemplate.convertAndSendToUser(
user, destination, returnValue, createHeaders(sessionId, returnType));
user, destination, returnValue, createHeaders(sessionId, returnType, message));
}
}
}
Expand All @@ -185,7 +209,7 @@ public void handleReturnValue(@Nullable Object returnValue, MethodParameter retu
String[] destinations = getTargetDestinations(sendTo, message, this.defaultDestinationPrefix);
for (String destination : destinations) {
destination = destinationHelper.expandTemplateVars(destination);
this.messagingTemplate.convertAndSend(destination, returnValue, createHeaders(sessionId, returnType));
this.messagingTemplate.convertAndSend(destination, returnValue, createHeaders(sessionId, returnType, message));
}
}
}
Expand Down Expand Up @@ -234,11 +258,22 @@ protected String[] getTargetDestinations(@Nullable Annotation annotation, Messag
new String[] {defaultPrefix + destination} : new String[] {defaultPrefix + '/' + destination});
}

private MessageHeaders createHeaders(@Nullable String sessionId, MethodParameter returnType) {
private MessageHeaders createHeaders(@Nullable String sessionId, MethodParameter returnType, @Nullable Message<?> inputMessage) {
SimpMessageHeaderAccessor headerAccessor = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE);
if (getHeaderInitializer() != null) {
getHeaderInitializer().initHeaders(headerAccessor);
}

if (inputMessage != null && headerFilter != null) {
Map<String, Object> inputHeaders = inputMessage.getHeaders();
for (Map.Entry<String, Object> entry : inputHeaders.entrySet()) {
String name = entry.getKey();
if (headerFilter.test(name)) {
headerAccessor.setHeader(name, entry.getValue());
}
}
}

if (sessionId != null) {
headerAccessor.setSessionId(sessionId);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@

package org.springframework.messaging.simp.annotation.support;

import java.util.Map;
import java.util.function.Predicate;

import org.apache.commons.logging.Log;
import org.jspecify.annotations.Nullable;

Expand Down Expand Up @@ -65,6 +68,8 @@ public class SubscriptionMethodReturnValueHandler implements HandlerMethodReturn

private @Nullable MessageHeaderInitializer headerInitializer;

private @Nullable Predicate<String> headerFilter;


/**
* Construct a new SubscriptionMethodReturnValueHandler.
Expand Down Expand Up @@ -93,6 +98,27 @@ public void setHeaderInitializer(@Nullable MessageHeaderInitializer headerInitia
return this.headerInitializer;
}

/**
* Add a filter to determine which headers from the input message should be propagated to the output message.
* Multiple filters are combined with logical OR.
* <p>If not set, no input headers are propagated (default behavior).</p>
*/
public void addHeaderFilter(Predicate<String> filter) {
Assert.notNull(filter, "Filter predicate must not be null");
if (this.headerFilter == null) {
this.headerFilter = filter;
} else {
this.headerFilter = this.headerFilter.or(filter);
}
}

/**
* Return the configured header filter.
*/
public @Nullable Predicate<String> getHeaderFilter() {
return this.headerFilter;
}


@Override
public boolean supportsReturnType(MethodParameter returnType) {
Expand Down Expand Up @@ -126,15 +152,26 @@ public void handleReturnValue(@Nullable Object returnValue, MethodParameter retu
if (logger.isDebugEnabled()) {
logger.debug("Reply to @SubscribeMapping: " + returnValue);
}
MessageHeaders headersToSend = createHeaders(sessionId, subscriptionId, returnType);
MessageHeaders headersToSend = createHeaders(sessionId, subscriptionId, returnType, message);
this.messagingTemplate.convertAndSend(destination, returnValue, headersToSend);
}

private MessageHeaders createHeaders(@Nullable String sessionId, String subscriptionId, MethodParameter returnType) {
private MessageHeaders createHeaders(@Nullable String sessionId, String subscriptionId, MethodParameter returnType, @Nullable Message<?> inputMessage) {
SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE);
if (getHeaderInitializer() != null) {
getHeaderInitializer().initHeaders(accessor);
}

if (inputMessage != null && headerFilter != null) {
Map<String, Object> inputHeaders = inputMessage.getHeaders();
for (Map.Entry<String, Object> entry : inputHeaders.entrySet()) {
String name = entry.getKey();
if (headerFilter.test(name)) {
accessor.setHeader(name, entry.getValue());
}
}
}

if (sessionId != null) {
accessor.setSessionId(sessionId);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -293,9 +293,60 @@ public void sendToUserWithSendToOverride() throws Exception {
assertResponse(parameter, sessionId, 1, "/dest4");
}

@Test
void sendToWithHeaderFilterSinglePredicate() throws Exception {
given(this.messageChannel.send(any(Message.class))).willReturn(true);

String sessionId = "sess1";
String customHeaderName = "x-custom-header";
String customHeaderValue = "custom-value";
Message<?> inputMessage = createMessage(sessionId, "sub1", null, null, null);
inputMessage = MessageBuilder.fromMessage(inputMessage)
.setHeader(customHeaderName, customHeaderValue)
.build();

SendToMethodReturnValueHandler handler = new SendToMethodReturnValueHandler(new SimpMessagingTemplate(this.messageChannel), true);
handler.addHeaderFilter(name -> name.equals(customHeaderName));

handler.handleReturnValue(PAYLOAD, this.sendToReturnType, inputMessage);

verify(this.messageChannel, times(2)).send(this.messageCaptor.capture());
for (Message<?> sent : this.messageCaptor.getAllValues()) {
MessageHeaders headers = sent.getHeaders();
assertThat(headers.get(customHeaderName)).isEqualTo(customHeaderValue);
}
}

@Test
void sendToWithHeaderFilterMultiplePredicates() throws Exception {
given(this.messageChannel.send(any(Message.class))).willReturn(true);

String sessionId = "sess1";
String headerA = "x-header-a";
String headerB = "x-header-b";
Message<?> inputMessage = createMessage(sessionId, "sub1", null, null, null);
inputMessage = MessageBuilder.fromMessage(inputMessage)
.setHeader(headerA, "A-value")
.setHeader(headerB, "B-value")
.build();

SendToMethodReturnValueHandler handler = new SendToMethodReturnValueHandler(new SimpMessagingTemplate(this.messageChannel), true);
handler.addHeaderFilter(name -> name.equals(headerA));
handler.addHeaderFilter(name -> name.equals(headerB));

handler.handleReturnValue(PAYLOAD, this.sendToReturnType, inputMessage);

verify(this.messageChannel, times(2)).send(this.messageCaptor.capture());
for (Message<?> sent : this.messageCaptor.getAllValues()) {
MessageHeaders headers = sent.getHeaders();
assertThat(headers.get(headerA)).isEqualTo("A-value");
assertThat(headers.get(headerB)).isEqualTo("B-value");
}
}


private void assertResponse(MethodParameter methodParameter, String sessionId,
int index, String destination) {
int index, String destination) {

SimpMessageHeaderAccessor accessor = getCapturedAccessor(index);
assertThat(accessor.getSessionId()).isEqualTo(sessionId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,65 @@ void testJsonView() throws Exception {
assertThat(new String((byte[]) message.getPayload(), StandardCharsets.UTF_8)).isEqualTo("{\"withView1\":\"with\"}");
}

@Test
void testHeaderFilterSinglePredicate() throws Exception {
String sessionId = "sess1";
String subscriptionId = "subs1";
String destination = "/dest";
String customHeaderName = "x-custom-header";
String customHeaderValue = "custom-value";
Message<?> inputMessage = MessageBuilder.withPayload(PAYLOAD)
.setHeader(SimpMessageHeaderAccessor.SESSION_ID_HEADER, sessionId)
.setHeader(SimpMessageHeaderAccessor.SUBSCRIPTION_ID_HEADER, subscriptionId)
.setHeader(SimpMessageHeaderAccessor.DESTINATION_HEADER, destination)
.setHeader(customHeaderName, customHeaderValue)
.build();

MessageSendingOperations messagingTemplate = mock();
SubscriptionMethodReturnValueHandler handler = new SubscriptionMethodReturnValueHandler(messagingTemplate);

handler.addHeaderFilter(name -> name.equals(customHeaderName));

handler.handleReturnValue(PAYLOAD, this.subscribeEventReturnType, inputMessage);

ArgumentCaptor<MessageHeaders> captor = ArgumentCaptor.forClass(MessageHeaders.class);
verify(messagingTemplate).convertAndSend(eq(destination), eq(PAYLOAD), captor.capture());

MessageHeaders sentHeaders = captor.getValue();
assertThat(sentHeaders.get(customHeaderName)).isEqualTo(customHeaderValue);
}

@Test
void testHeaderFilterMultiplePredicates() throws Exception {
String sessionId = "sess1";
String subscriptionId = "subs1";
String destination = "/dest";
String headerA = "x-header-a";
String headerB = "x-header-b";
Message<?> inputMessage = MessageBuilder.withPayload(PAYLOAD)
.setHeader(SimpMessageHeaderAccessor.SESSION_ID_HEADER, sessionId)
.setHeader(SimpMessageHeaderAccessor.SUBSCRIPTION_ID_HEADER, subscriptionId)
.setHeader(SimpMessageHeaderAccessor.DESTINATION_HEADER, destination)
.setHeader(headerA, "A-value")
.setHeader(headerB, "B-value")
.build();

MessageSendingOperations messagingTemplate = mock();
SubscriptionMethodReturnValueHandler handler = new SubscriptionMethodReturnValueHandler(messagingTemplate);

handler.addHeaderFilter(name -> name.equals(headerA));
handler.addHeaderFilter(name -> name.equals(headerB));

handler.handleReturnValue(PAYLOAD, this.subscribeEventReturnType, inputMessage);

ArgumentCaptor<MessageHeaders> captor = ArgumentCaptor.forClass(MessageHeaders.class);
verify(messagingTemplate).convertAndSend(eq(destination), eq(PAYLOAD), captor.capture());

MessageHeaders sentHeaders = captor.getValue();
assertThat(sentHeaders.get(headerA)).isEqualTo("A-value");
assertThat(sentHeaders.get(headerB)).isEqualTo("B-value");
}


private Message<?> createInputMessage(String sessId, String subsId, String dest, Principal principal) {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create();
Expand Down