Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.function.Function;

import javax.net.ssl.SSLException;
Expand All @@ -43,7 +45,6 @@
import io.grpc.netty.NettyChannelBuilder;
import io.grpc.protobuf.ProtoUtils;
import io.grpc.stub.ClientCalls;
import io.netty.buffer.PooledByteBufAllocator;
import org.jspecify.annotations.Nullable;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
Expand All @@ -66,7 +67,6 @@
import org.springframework.core.io.Resource;
import org.springframework.core.io.ResourceLoader;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.NettyDataBufferFactory;
import org.springframework.http.codec.json.JacksonJsonDecoder;
import org.springframework.http.server.reactive.ServerHttpResponseDecorator;
import org.springframework.web.server.ServerWebExchange;
Expand All @@ -89,6 +89,8 @@ public class JsonToGrpcGatewayFilterFactory

private final ResourceLoader resourceLoader;

private final ConcurrentMap<String, ManagedChannel> managedChannelCache = new ConcurrentHashMap<>();

public JsonToGrpcGatewayFilterFactory(GrpcSslConfigurer grpcSslConfigurer, ResourceLoader resourceLoader) {
super(Config.class);
this.grpcSslConfigurer = grpcSslConfigurer;
Expand All @@ -102,10 +104,12 @@ public List<String> shortcutFieldOrder() {

@Override
public GatewayFilter apply(Config config) {
GrpcCallContext callContext = new GrpcCallContext(config);

GatewayFilter filter = new GatewayFilter() {
@Override
public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
GRPCResponseDecorator modifiedResponse = new GRPCResponseDecorator(exchange, config);
GRPCResponseDecorator modifiedResponse = new GRPCResponseDecorator(exchange, callContext);

ServerWebExchangeUtils.setAlreadyRouted(exchange);
return modifiedResponse.writeWith(exchange.getRequest().getBody())
Expand All @@ -122,6 +126,19 @@ public String toString() {
return new OrderedGatewayFilter(filter, order);
}

private ManagedChannel createChannelChannel(String host, int port) {
String key = host + ":" + port;
return managedChannelCache.computeIfAbsent(key, k -> {
NettyChannelBuilder builder = NettyChannelBuilder.forAddress(host, port);
try {
return grpcSslConfigurer.configureSsl(builder);
}
catch (SSLException e) {
throw new RuntimeException(e);
}
});
}

public static class Config {

private @Nullable String protoDescriptor;
Expand Down Expand Up @@ -159,68 +176,55 @@ public Config setMethod(String method) {

}

class GRPCResponseDecorator extends ServerHttpResponseDecorator {
class GrpcCallContext {

private final ServerWebExchange exchange;
final Descriptors.Descriptor descriptor;

private final Descriptors.Descriptor descriptor;
final MethodDescriptor<DynamicMessage, DynamicMessage> methodDescriptor;

private final ObjectReader objectReader;
final ObjectMapper objectMapper;

private final ClientCall<DynamicMessage, DynamicMessage> clientCall;
final ObjectReader objectReader;

private final ObjectNode objectNode;
final ObjectNode objectNode;

GRPCResponseDecorator(ServerWebExchange exchange, Config config) {
super(exchange.getResponse());
this.exchange = exchange;
try {
Descriptors.MethodDescriptor methodDescriptor = getMethodDescriptor(config);
Descriptors.ServiceDescriptor serviceDescriptor = methodDescriptor.getService();
Descriptors.Descriptor outputType = methodDescriptor.getOutputType();
this.descriptor = methodDescriptor.getInputType();
final JsonFormat.Parser jsonToProtoParser;

clientCall = createClientCallForType(config, serviceDescriptor, outputType);
final JsonFormat.Printer protoToJsonPrinter;

ObjectMapper objectMapper = JsonMapper.builder()
.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false)
.build();
GrpcCallContext(Config config) {
try {
Descriptors.MethodDescriptor protoMethodDescriptor = getMethodDescriptor(config);
Descriptors.ServiceDescriptor serviceDescriptor = protoMethodDescriptor.getService();
Descriptors.Descriptor outputType = protoMethodDescriptor.getOutputType();
this.descriptor = protoMethodDescriptor.getInputType();

MethodDescriptor.Marshaller<DynamicMessage> marshaller = ProtoUtils
.marshaller(DynamicMessage.newBuilder(outputType).build());

methodDescriptor = MethodDescriptor
.<DynamicMessage, DynamicMessage>newBuilder()
.setType(MethodDescriptor.MethodType.UNKNOWN)
.setFullMethodName(
MethodDescriptor.generateFullMethodName(serviceDescriptor.getFullName(), config.getMethod()))
.setRequestMarshaller(marshaller)
.setResponseMarshaller(marshaller)
.build();

jsonToProtoParser = JsonFormat.parser();
protoToJsonPrinter = JsonFormat.printer().omittingInsignificantWhitespace();

objectMapper = JsonMapper.builder()
.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false)
.build();
objectReader = objectMapper.readerFor(JsonNode.class);
objectNode = objectMapper.createObjectNode();

}
catch (IOException | Descriptors.DescriptorValidationException e) {
throw new RuntimeException(e);
}
}

@Override
public Mono<Void> writeWith(Publisher<? extends DataBuffer> body) {
exchange.getResponse().getHeaders().set("Content-Type", "application/json");

return getDelegate().writeWith(deserializeJSONRequest().map(callGRPCServer())
.map(serialiseGRPCResponse())
.map(wrapGRPCResponse())
.cast(DataBuffer.class)
.last());
}

private ClientCall<DynamicMessage, DynamicMessage> createClientCallForType(Config config,
Descriptors.ServiceDescriptor serviceDescriptor, Descriptors.Descriptor outputType) {
MethodDescriptor.Marshaller<DynamicMessage> marshaller = ProtoUtils
.marshaller(DynamicMessage.newBuilder(outputType).build());
MethodDescriptor<DynamicMessage, DynamicMessage> methodDescriptor = MethodDescriptor
.<DynamicMessage, DynamicMessage>newBuilder()
.setType(MethodDescriptor.MethodType.UNKNOWN)
.setFullMethodName(
MethodDescriptor.generateFullMethodName(serviceDescriptor.getFullName(), config.getMethod()))
.setRequestMarshaller(marshaller)
.setResponseMarshaller(marshaller)
.build();
Channel channel = createChannel();
return channel.newCall(methodDescriptor, CallOptions.DEFAULT);
}

private Descriptors.MethodDescriptor getMethodDescriptor(Config config)
throws IOException, Descriptors.DescriptorValidationException {
Objects.requireNonNull(config.getProtoDescriptor(), "Proto Descriptor must not be null");
Expand Down Expand Up @@ -271,6 +275,36 @@ private FileDescriptor[] dependencies(FileDescriptorSet input, ProtocolStringLis
return null;
}

}

class GRPCResponseDecorator extends ServerHttpResponseDecorator {

private final ServerWebExchange exchange;

private final GrpcCallContext ctx;

GRPCResponseDecorator(ServerWebExchange exchange, GrpcCallContext ctx) {
super(exchange.getResponse());
this.exchange = exchange;
this.ctx = ctx;
}

@Override
public Mono<Void> writeWith(Publisher<? extends DataBuffer> body) {
exchange.getResponse().getHeaders().set("Content-Type", "application/json");

return getDelegate().writeWith(deserializeJSONRequest().map(callGRPCServer())
.map(serialiseGRPCResponse())
.map(wrapGRPCResponse())
.cast(DataBuffer.class)
.last());
}

private ClientCall<DynamicMessage, DynamicMessage> createClientCallForType(MethodDescriptor<DynamicMessage, DynamicMessage> methodDescriptor) {
Channel channel = createChannel();
return channel.newCall(methodDescriptor, CallOptions.DEFAULT);
}

private ManagedChannel createChannel() {
Route route = (Route) exchange.getAttributes().get(ServerWebExchangeUtils.GATEWAY_ROUTE_ATTR);
URI requestURI = Objects.requireNonNull(route, "Route not found in exchange attributes").getUri();
Expand All @@ -280,8 +314,10 @@ private ManagedChannel createChannel() {
private Function<JsonNode, DynamicMessage> callGRPCServer() {
return jsonRequest -> {
try {
DynamicMessage.Builder builder = DynamicMessage.newBuilder(descriptor);
JsonFormat.parser().merge(jsonRequest.toString(), builder);
ClientCall<DynamicMessage, DynamicMessage> clientCall = createClientCallForType(ctx.methodDescriptor);

DynamicMessage.Builder builder = DynamicMessage.newBuilder(ctx.descriptor);
ctx.jsonToProtoParser.merge(jsonRequest.toString(), builder);
return ClientCalls.blockingUnaryCall(clientCall, builder.build());
}
catch (IOException e) {
Expand All @@ -293,8 +329,8 @@ private Function<JsonNode, DynamicMessage> callGRPCServer() {
private Function<DynamicMessage, Object> serialiseGRPCResponse() {
return gRPCResponse -> {
try {
return objectReader
.readValue(JsonFormat.printer().omittingInsignificantWhitespace().print(gRPCResponse));
return ctx.objectReader
.readValue(ctx.protoToJsonPrinter.print(gRPCResponse));
}
catch (IOException e) {
throw new RuntimeException(e);
Expand All @@ -305,29 +341,18 @@ private Function<DynamicMessage, Object> serialiseGRPCResponse() {
private Flux<JsonNode> deserializeJSONRequest() {
return exchange.getRequest().getBody().mapNotNull(dataBufferBody -> {
if (dataBufferBody.capacity() == 0) {
return objectNode;
return ctx.objectNode;
}
ResolvableType targetType = ResolvableType.forType(JsonNode.class);
return new JacksonJsonDecoder().decode(dataBufferBody, targetType, null, null);
}).cast(JsonNode.class);
}

private Function<Object, DataBuffer> wrapGRPCResponse() {
return jsonResponse -> new NettyDataBufferFactory(new PooledByteBufAllocator())
return jsonResponse -> exchange.getResponse().bufferFactory()
.wrap(Objects.requireNonNull(new ObjectMapper().writeValueAsBytes(jsonResponse)));
}

// We are creating this on every call, should optimize?
private ManagedChannel createChannelChannel(String host, int port) {
NettyChannelBuilder nettyChannelBuilder = NettyChannelBuilder.forAddress(host, port);
try {
return grpcSslConfigurer.configureSsl(nettyChannelBuilder);
}
catch (SSLException e) {
throw new RuntimeException(e);
}
}

}

}