Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import javasabr.rlib.network.impl.DefaultBufferAllocator;
import javasabr.rlib.network.impl.DefaultConnection;
import javasabr.rlib.network.impl.StringDataConnection;
import javasabr.rlib.network.impl.StringDataMtlsServerConnection;
import javasabr.rlib.network.impl.StringDataSslConnection;
import javasabr.rlib.network.packet.impl.DefaultReadableNetworkPacket;
import javasabr.rlib.network.packet.registry.ReadableNetworkPacketRegistry;
Expand Down Expand Up @@ -140,7 +141,11 @@ public static ClientNetwork<StringDataSslConnection> stringDataSslClientNetwork(
SSLContext sslContext) {
return clientNetwork(
networkConfig,
(network, channel) -> new StringDataSslConnection(network, channel, bufferAllocator, sslContext, true));
(network, channel) -> {
StringDataSslConnection connection = new StringDataSslConnection(network, channel, bufferAllocator, sslContext, true);
connection.beginHandshake();
return connection;
});
}

/**
Expand Down Expand Up @@ -196,7 +201,11 @@ public static ServerNetwork<StringDataSslConnection> stringDataSslServerNetwork(
SSLContext sslContext) {
return serverNetwork(
networkConfig,
(network, channel) -> new StringDataSslConnection(network, channel, bufferAllocator, sslContext, false));
(network, channel) -> {
StringDataSslConnection connection = new StringDataSslConnection(network, channel, bufferAllocator, sslContext, false);
connection.beginHandshake();
return connection;
});
}

/**
Expand Down Expand Up @@ -231,4 +240,26 @@ public static ServerNetwork<DefaultConnection> defaultServerNetwork(
networkConfig,
(network, channel) -> new DefaultConnection(network, channel, bufferAllocator, packetRegistry));
}

/**
* Create string packet based asynchronous Mutual TLS server network.
*
* @param networkConfig the server network configuration
* @param bufferAllocator the buffer allocator
* @param sslContext SSL context
* @return a new mTLS server network
* @since 10.0.0
*/
public static ServerNetwork<StringDataMtlsServerConnection> stringDataMtlsServerNetwork(
ServerNetworkConfig networkConfig,
BufferAllocator bufferAllocator,
SSLContext sslContext) {
return serverNetwork(
networkConfig,
(network, channel) -> {
StringDataMtlsServerConnection connection = new StringDataMtlsServerConnection(network, channel, bufferAllocator, sslContext);
connection.beginHandshake();
return connection;
});
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package javasabr.rlib.network.exception;

public class ConnectionClosedException extends NetworkException {

public ConnectionClosedException(String remoteAddress) {
super("Connection closed: " + remoteAddress);
}

public ConnectionClosedException(String remoteAddress, Throwable cause) {
super("Connection closed: " + remoteAddress, cause);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import javasabr.rlib.network.Connection;
import javasabr.rlib.network.Network;
import javasabr.rlib.network.UnsafeConnection;
import javasabr.rlib.network.exception.ConnectionClosedException;
import javasabr.rlib.network.packet.NetworkPacketReader;
import javasabr.rlib.network.packet.NetworkPacketWriter;
import javasabr.rlib.network.packet.ReadableNetworkPacket;
Expand Down Expand Up @@ -64,6 +65,7 @@ public WritablePacketWithFeedback(CompletableFuture<Boolean> attachment, Writabl

final MutableArray<BiConsumer<C, ? super ReadableNetworkPacket<C>>> validPacketSubscribers;
final MutableArray<BiConsumer<C, ? super ReadableNetworkPacket<C>>> invalidPacketSubscribers;
final MutableArray<FluxSink<?>> activeSinks;

final int maxPacketsByRead;

Expand All @@ -84,6 +86,7 @@ public AbstractConnection(
this.closed = new AtomicBoolean(false);
this.validPacketSubscribers = ArrayFactory.copyOnModifyArray(BiConsumer.class);
this.invalidPacketSubscribers = ArrayFactory.copyOnModifyArray(BiConsumer.class);
this.activeSinks = ArrayFactory.copyOnModifyArray(FluxSink.class);
this.remoteAddress = String.valueOf(NetworkUtils.getRemoteAddress(channel));
}

Expand Down Expand Up @@ -134,10 +137,12 @@ protected void registerFluxOnReceivedEvents(

validPacketSubscribers.add(validListener);
invalidPacketSubscribers.add(invalidListener);
activeSinks.add(sink);

sink.onDispose(() -> {
validPacketSubscribers.remove(validListener);
validPacketSubscribers.remove(invalidListener);
activeSinks.remove(sink);
});

network.inNetworkThread(() -> packetReader().startRead());
Expand All @@ -146,14 +151,22 @@ protected void registerFluxOnReceivedEvents(
protected void registerFluxOnReceivedValidPackets(FluxSink<? super ReadableNetworkPacket<C>> sink) {
BiConsumer<C, ReadableNetworkPacket<C>> listener = (connection, packet) -> sink.next(packet);
validPacketSubscribers.add(listener);
sink.onDispose(() -> validPacketSubscribers.remove(listener));
activeSinks.add(sink);
sink.onDispose(() -> {
validPacketSubscribers.remove(listener);
activeSinks.remove(sink);
});
network.inNetworkThread(() -> packetReader().startRead());
}

protected void registerFluxOnReceivedInvalidPackets(FluxSink<? super ReadableNetworkPacket<C>> sink) {
BiConsumer<C, ReadableNetworkPacket<C>> listener = (connection, packet) -> sink.next(packet);
invalidPacketSubscribers.add(listener);
sink.onDispose(() -> invalidPacketSubscribers.remove(listener));
activeSinks.add(sink);
sink.onDispose(() -> {
invalidPacketSubscribers.remove(listener);
activeSinks.remove(sink);
});
network.inNetworkThread(() -> packetReader().startRead());
}

Expand Down Expand Up @@ -184,6 +197,24 @@ protected void doClose() {
clearWaitPackets();
packetReader().close();
packetWriter().close();
notifySinksOnError();
}

protected void notifySinksOnError() {
if (activeSinks.isEmpty()) {
return;
}
ConnectionClosedException error = new ConnectionClosedException(remoteAddress);
activeSinks
.iterations()
.forEach(error, (sink, exc) -> {
try {
sink.error(exc);
} catch (RuntimeException e) {
log.error("Failed to notify sink of connection closure: " + e.getMessage());
}
});
activeSinks.clear();
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ public AbstractSslConnection(
super(network, channel, bufferAllocator, maxPacketsByRead);
this.sslEngine = sslContext.createSSLEngine();
this.sslEngine.setUseClientMode(clientMode);
}

public void beginHandshake() {
try {
this.sslEngine.beginHandshake();
} catch (SSLException e) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package javasabr.rlib.network.impl;

import javasabr.rlib.network.BufferAllocator;
import javasabr.rlib.network.Network;
import javasabr.rlib.network.packet.impl.StringReadableNetworkPacket;

import javax.net.ssl.SSLContext;
import java.nio.channels.AsynchronousSocketChannel;

/**
* @author crazyrokr
*/
public class StringDataMtlsServerConnection extends DefaultDataSslConnection<StringDataMtlsServerConnection> {

public StringDataMtlsServerConnection(
Network<StringDataMtlsServerConnection> network,
AsynchronousSocketChannel channel,
BufferAllocator bufferAllocator,
SSLContext sslContext) {
super(network, channel, bufferAllocator, sslContext, 100, 2, false);
sslEngine.setNeedClientAuth(true);
}

@Override
protected StringReadableNetworkPacket<StringDataMtlsServerConnection> createReadablePacket() {
return new StringReadableNetworkPacket<>();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -461,10 +461,14 @@ protected void handleFailedReceiving(Throwable exception, ByteBuffer readingBuff
retryReadLater();
}
}
case AsynchronousCloseException ex ->
log.info(remoteAddress(), "[%s] Connection was closed"::formatted);
case ClosedChannelException ex ->
log.info(remoteAddress(), "[%s] Connection was closed"::formatted);
case AsynchronousCloseException ex -> {
log.info(remoteAddress(), "[%s] Connection was closed"::formatted);
connection.close();
}
case ClosedChannelException ex -> {
log.info(remoteAddress(), "[%s] Connection was closed"::formatted);
connection.close();
}
default -> {
log.error(exception);
connection.close();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ protected AbstractSslNetworkPacketReader(
protected void handleReceivedData(int receivedBytes, ByteBuffer readingBuffer) {
if (receivedBytes == -1) {
doHandshake(sslNetworkBuffer(), -1);
connection.close();
return;
}
super.handleReceivedData(receivedBytes, readingBuffer);
Expand Down Expand Up @@ -159,6 +160,9 @@ protected int doHandshake(ByteBuffer networkBuffer, int receivedBytes) {
case NEED_WRAP: {
log.debug(remoteAddress, "[%s] Send command to wrap data"::formatted);
packetWriter.accept(SslWrapRequestNetworkPacket.getInstance());
if (networkBuffer.hasRemaining()) {
return decryptAndRead(networkBuffer);
}
NetworkUtils.cleanNetworkBuffer(networkBuffer);
return SKIP_READ_PACKETS;
}
Expand Down Expand Up @@ -203,6 +207,10 @@ protected int decryptAndRead(ByteBuffer receivedBuffer) {
}
switch (result.getStatus()) {
case OK: {
if (result.bytesConsumed() == 0 && result.bytesProduced() == 0) {
log.debug(remoteAddress, "[%s] No progress during decryption, stop processing"::formatted);
return SKIP_READ_PACKETS;
}
sslDataBuffer.flip();
logDataAfterDecrypt(remoteAddress, sslDataBuffer);
total += readPackets(sslDataBuffer, sslDataPendingBuffer);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ protected ByteBuffer doHandshake(HandshakeStatus handshakeStatus) {
break;
}
case NEED_UNWRAP: {
break;
return EMPTY_BUFFER;
}
default: {
throw new IllegalStateException("Invalid SSL status:" + handshakeStatus);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.IntStream;
import javasabr.rlib.common.util.ObjectUtils;
import javasabr.rlib.common.util.StringUtils;
import javasabr.rlib.common.util.Utils;
import javasabr.rlib.network.client.ClientNetwork;
import javasabr.rlib.network.exception.ConnectionClosedException;
import javasabr.rlib.network.impl.DefaultBufferAllocator;
import javasabr.rlib.network.impl.StringDataMtlsServerConnection;
import javasabr.rlib.network.impl.StringDataSslConnection;
import javasabr.rlib.network.packet.ReadableNetworkPacket;
import javasabr.rlib.network.packet.impl.StringReadableNetworkPacket;
Expand Down Expand Up @@ -328,6 +331,63 @@ void shouldReceiveManyPacketsFromSmallToBigSize() {
}
}

@Test
@SneakyThrows
void shouldRejectClientWithoutCertificateWithinMutualTls() {
InputStream serverKeystoreFile = StringSslNetworkTest.class.getResourceAsStream("/ssl/rlib_test_cert.p12");
SSLContext serverSslContext = NetworkUtils.createSslContext(serverKeystoreFile, "test");
ServerNetworkConfig serverConfig = ServerNetworkConfig.SimpleServerNetworkConfig.builder().build();
BufferAllocator bufferAllocator = new DefaultBufferAllocator(serverConfig);

ServerNetwork<StringDataMtlsServerConnection> serverNetwork =
NetworkFactory.stringDataMtlsServerNetwork(serverConfig, bufferAllocator, serverSslContext);

InetSocketAddress serverAddress = serverNetwork.start();
CountDownLatch dataReceivedByServer = new CountDownLatch(1);

serverNetwork
.accepted()
.flatMap(Connection::receivedEvents)
.subscribe(event -> dataReceivedByServer.countDown());

SSLContext clientWithoutCertContext = NetworkUtils.createAllTrustedClientSslContext();
ClientNetwork<StringDataSslConnection> clientNetwork = NetworkFactory.stringDataSslClientNetwork(
NetworkConfig.DEFAULT_CLIENT,
new DefaultBufferAllocator(NetworkConfig.DEFAULT_CLIENT),
clientWithoutCertContext);

AtomicReference<Throwable> connectionError = new AtomicReference<>();
CountDownLatch errorReceived = new CountDownLatch(1);

try {
clientNetwork
.connectReactive(serverAddress)
.doOnNext(connection -> connection.sendInBackground(new StringWritableNetworkPacket<>("no cert")))
.flatMapMany(Connection::receivedEvents)
.subscribe(
event -> {},
ex -> {
connectionError.set(ex);
errorReceived.countDown();
});

assertThat(errorReceived.await(5, TimeUnit.SECONDS))
.as("Client subscriber must receive an error when the server closes the mTLS connection.")
.isTrue();

assertThat(connectionError.get())
.as("Client must receive ConnectionClosedException, not a timeout.")
.isInstanceOf(ConnectionClosedException.class);

assertThat(dataReceivedByServer.getCount())
.as("Server must not receive data from an unauthenticated client.")
.isEqualTo(1);
} finally {
serverNetwork.shutdown();
clientNetwork.shutdown();
}
}

private static StringWritableNetworkPacket<StringDataSslConnection> newMessage(int minMessageLength, int maxMessageLength) {
return new StringWritableNetworkPacket<>(StringUtils.generate(minMessageLength, maxMessageLength));
}
Expand Down
Loading
Loading