Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ public enum DisconnectReason {
AUTH_PROVIDER_NOT_FOUND("auth provider not found"),
FAILED_HANDSHAKE("Unsuccessful handshake"),
CLIENT_RATE_LIMIT("Client hits rate limiting threshold"),
CLIENT_CNX_LIMIT("Client hits connection limiting threshold");
CLIENT_CNX_LIMIT("Client hits connection limiting threshold"),
SHED_CONNECTIONS_COMMAND("shed_connections_command");

String disconnectReason;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.Supplier;
import javax.management.JMException;
import javax.security.auth.callback.CallbackHandler;
Expand Down Expand Up @@ -160,15 +161,71 @@ public final void setZooKeeperServer(ZooKeeperServer zks) {

public abstract void closeAll(ServerCnxn.DisconnectReason reason);

/**
 * Attempts to shed approximately the specified percentage of connections.
 * <p>
 * For percentages strictly between 0 and 100 every connection in {@code cnxns}
 * is selected independently with probability {@code percentage / 100.0}, so the
 * number of connections closed varies from call to call. A percentage of 100
 * closes every connection without consulting the random generator, and a
 * percentage of 0 (or an empty connection set) closes nothing.
 * <p>
 * A failure to close one connection is logged and does not stop the remaining
 * connections from being processed.
 *
 * @param percentage [0-100] percentage of connections to shed
 * @return number of connections whose {@code close} call completed without
 *         throwing (may vary due to randomness)
 * @throws IllegalArgumentException if percentage not in [0, 100]
 */
public int shedConnections(final int percentage) {
    if (percentage < 0 || percentage > 100) {
        throw new IllegalArgumentException("percentage must be between 0 and 100, got: " + percentage);
    }

    // Size is sampled once up front and is only used for the early exit and for logging.
    final int totalConnections = cnxns.size();
    if (percentage == 0 || totalConnections == 0) {
        LOG.info("No connections to shed: percentage={}, totalConnections={}", percentage, totalConnections);
        return 0;
    }

    final ServerCnxn.DisconnectReason reason = ServerCnxn.DisconnectReason.SHED_CONNECTIONS_COMMAND;
    // 100% is handled deterministically: no random draw is made, every connection is closed.
    final boolean shedAll = percentage == 100;
    final double probability = percentage / 100.0;
    final ThreadLocalRandom random = ThreadLocalRandom.current();

    int actualShedCount = 0;
    for (final ServerCnxn cnxn : cnxns) {
        if (!shedAll && random.nextDouble() >= probability) {
            continue; // not selected in this round
        }
        try {
            cnxn.close(reason);
            actualShedCount++; // Count only successful closes
        } catch (final Exception e) {
            // Pass the exception itself so the type and stack trace are not lost.
            LOG.warn("Failed to close connection for session 0x{}",
                Long.toHexString(cnxn.getSessionId()), e);
        }
    }

    LOG.info("Shed {} out of {} connections ({}%)", actualShedCount, totalConnections, percentage);
    return actualShedCount;
}

public static ServerCnxnFactory createFactory() throws IOException {
String serverCnxnFactoryName = System.getProperty(ZOOKEEPER_SERVER_CNXN_FACTORY);
if (serverCnxnFactoryName == null) {
serverCnxnFactoryName = NIOServerCnxnFactory.class.getName();
}
try {
ServerCnxnFactory serverCnxnFactory = (ServerCnxnFactory) Class.forName(serverCnxnFactoryName)
.getDeclaredConstructor()
.newInstance();
.getDeclaredConstructor()
.newInstance();
LOG.info("Using {} as server connection factory", serverCnxnFactoryName);
return serverCnxnFactory;
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,12 @@

import static org.apache.zookeeper.server.persistence.FileSnap.SNAPSHOT_FILE_PREFIX;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.InetSocketAddress;
import java.nio.charset.StandardCharsets;
Expand Down Expand Up @@ -292,6 +295,7 @@ public static Command getCommand(String cmdName) {
registerCommand(new RestoreCommand());
registerCommand(new RuokCommand());
registerCommand(new SetTraceMaskCommand());
registerCommand(new ShedConnectionsCommand());
registerCommand(new SnapshotCommand());
registerCommand(new SrvrCommand());
registerCommand(new StatCommand());
Expand Down Expand Up @@ -863,6 +867,105 @@ public CommandResponse runGet(ZooKeeperServer zkServer, Map<String, String> kwar

}

/**
 * Attempts to shed approximately the specified percentage of connections.
 * <p>
 * Request: JSON input stream via HTTP POST
 * Required JSON fields:
 * - "percentage": Integer [0-100] - percentage of connections to attempt shedding.
 *   The value must be a JSON integer; strings, fractions, null and numbers that do
 *   not fit in an int are rejected with 400 instead of being coerced.
 * <p>
 * Response: JSON output stream containing:
 * - "connections_shed": Integer - actual number of connections successfully closed
 *   (may vary due to randomness)
 * - "percentage_requested": Integer - the percentage that was requested
 * <p>
 * On failure the response carries an "error" entry and a 400 (bad input) or
 * 500 (unexpected exception) status code.
 */
public static class ShedConnectionsCommand extends PostCommand {
    private static final String PARAM_PERCENTAGE = "percentage";

    // Shared parser: avoids building a new ObjectMapper for every request.
    private static final ObjectMapper MAPPER = new ObjectMapper();

    public ShedConnectionsCommand() {
        super(Arrays.asList("shed_connections", "shed"), true, new AuthRequest(ZooDefs.Perms.ALL, ROOT_PATH));
    }

    /**
     * Parses the JSON request body, validates the "percentage" field and sheds
     * connections on the server's connection factories.
     *
     * @param zkServer the server whose connections are shed
     * @param inputStream the request body; a null stream yields a 400 response
     * @return the response with either the shed statistics or an "error" entry
     */
    @Override
    public CommandResponse runPost(final ZooKeeperServer zkServer, final InputStream inputStream) {
        final CommandResponse response = initializeResponse();

        if (inputStream == null) {
            response.setStatusCode(HttpServletResponse.SC_BAD_REQUEST);
            response.put("error", "Request body is required");
            return response;
        }

        try {
            final JsonNode jsonNode = MAPPER.readTree(inputStream);

            // A null tree (empty body on some Jackson versions) is treated like a missing field.
            if (jsonNode == null || !jsonNode.has(PARAM_PERCENTAGE)) {
                response.setStatusCode(HttpServletResponse.SC_BAD_REQUEST);
                response.put("error", "Missing required parameter: " + PARAM_PERCENTAGE);
                return response;
            }

            // asInt() silently maps non-numeric values to 0 and truncates fractions and
            // oversized numbers, so the node type is checked before converting.
            final JsonNode percentageNode = jsonNode.get(PARAM_PERCENTAGE);
            if (!percentageNode.isIntegralNumber() || !percentageNode.canConvertToInt()) {
                response.setStatusCode(HttpServletResponse.SC_BAD_REQUEST);
                response.put("error", "Parameter " + PARAM_PERCENTAGE + " must be an integer between 0 and 100");
                return response;
            }

            final int percentage = percentageNode.asInt();
            if (percentage < 0 || percentage > 100) {
                response.setStatusCode(HttpServletResponse.SC_BAD_REQUEST);
                response.put("error", "Percentage must be between 0 and 100");
                return response;
            }

            // perform connection shedding
            final int connectionsShed = shedConnections(zkServer, percentage);

            // populate response
            response.put("connections_shed", connectionsShed);
            response.put("percentage_requested", percentage);

            LOG.info("Shed {} connections ({}%)", connectionsShed, percentage);
        } catch (final IOException e) {
            response.setStatusCode(HttpServletResponse.SC_BAD_REQUEST);
            response.put("error", "Invalid JSON or failed to read request body: " + e.getMessage());
        } catch (final Exception e) {
            response.setStatusCode(HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
            response.put("error", "Exception occurred during connection shedding: " + e.getMessage());
            LOG.error("Exception occurred during connection shedding", e);
        }
        return response;
    }

    /**
     * Sheds connections on both the regular and the secure connection factory,
     * whichever are present. A failure in one factory is logged and does not
     * prevent the other factory from being processed.
     *
     * @param zkServer the server providing the connection factories
     * @param percentage validated percentage in [0, 100]
     * @return total number of connections reported as shed by the factories
     */
    private int shedConnections(final ZooKeeperServer zkServer, final int percentage) {
        final ServerCnxnFactory factory = zkServer.getServerCnxnFactory();
        final ServerCnxnFactory secureFactory = zkServer.getSecureServerCnxnFactory();

        if (factory == null && secureFactory == null) {
            LOG.warn("No connection factories available for shedding connections");
            return 0;
        }

        // If percentage is 0, don't call shedConnections on factories
        if (percentage == 0) {
            return 0;
        }

        int connectionsShed = 0;
        if (factory != null) {
            try {
                connectionsShed += factory.shedConnections(percentage);
            } catch (final Exception e) {
                LOG.warn("Failed to shed connections from regular factory", e);
            }
        }

        if (secureFactory != null) {
            try {
                connectionsShed += secureFactory.shedConnections(percentage);
            } catch (final Exception e) {
                LOG.warn("Failed to shed connections from secure factory", e);
            }
        }
        return connectionsShed;
    }
}

/**
* Same as SrvrCommand but has extra "connections" entry.
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.zookeeper.server;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockingDetails;
import java.io.IOException;
import java.util.Arrays;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;

/**
 * Unit tests for {@link ServerCnxnFactory#shedConnections(int)}, run against
 * both the NIO and the Netty factory implementations using mocked connections.
 */
public class ServerCnxnFactoryTest {
    /** Factory implementations every test is parameterized over. */
    public enum FactoryType {
        NIO, NETTY
    }

    private ServerCnxnFactory factory;

    @AfterEach
    public void tearDown() {
        if (factory != null) {
            try {
                factory.shutdown();
            } catch (Exception ignored) {
                // Ignore all shutdown exceptions in tests since factories may not be fully initialized
                // This includes NullPointerException when ServerSocketChannel is null
                // and any other exceptions from uninitialized factory state
            }
        }
    }

    @ParameterizedTest
    @EnumSource(FactoryType.class)
    public void testShedConnections_InvalidPercentage(final FactoryType factoryType) throws IOException {
        factory = createFactory(factoryType);
        assertThrows(IllegalArgumentException.class, () -> factory.shedConnections(-1));
        assertThrows(IllegalArgumentException.class, () -> factory.shedConnections(101));
    }

    @ParameterizedTest
    @EnumSource(FactoryType.class)
    public void testShedConnections_ValidPercentages(final FactoryType factoryType) throws IOException {
        // No connections are registered, so every valid percentage must shed nothing.
        factory = createFactory(factoryType);
        assertEquals(0, factory.shedConnections(0));
        assertEquals(0, factory.shedConnections(50));
        assertEquals(0, factory.shedConnections(100));
    }

    @ParameterizedTest
    @EnumSource(FactoryType.class)
    public void testShedConnections_DeterministicBehavior(final FactoryType factoryType) throws Exception {
        factory = createFactory(factoryType);

        // Create 4 mock connections for testing deterministic edge cases
        final ServerCnxn[] mockCnxns = new ServerCnxn[4];
        for (int i = 0; i < 4; i++) {
            mockCnxns[i] = mock(ServerCnxn.class);
            factory.cnxns.add(mockCnxns[i]);
        }

        // Test 0% shedding - should shed exactly 0 connections (deterministic)
        int shedCount = factory.shedConnections(0);
        assertEquals(0, shedCount, "0% shedding should shed exactly 0 connections");

        // Verify no connections were actually closed
        int actualClosedCount = countConnectionsShed(mockCnxns);
        assertEquals(0, actualClosedCount, "No connections should be closed for 0% shedding");

        // Test 100% shedding - should shed exactly all connections (deterministic)
        shedCount = factory.shedConnections(100);
        assertEquals(4, shedCount, "100% shedding should shed exactly all 4 connections");

        // Verify all connections were actually closed with correct reason
        actualClosedCount = countConnectionsShed(mockCnxns);
        assertEquals(4, actualClosedCount, "All 4 connections should be closed for 100% shedding");
    }

    @ParameterizedTest
    @EnumSource(FactoryType.class)
    public void testShedConnections_SmallPercentageRoundsToZero(final FactoryType factoryType) throws Exception {
        factory = createFactory(factoryType);

        // Add single mock connection
        final ServerCnxn mockCnxn = mock(ServerCnxn.class);
        factory.cnxns.add(mockCnxn);

        // Shedding is probabilistic, not rounded: with 1% the single connection is
        // almost always kept but is occasionally closed. Asserting a fixed 0 would make
        // this test fail intermittently, so check instead that the reported count agrees
        // with what actually happened to the connection (0 or 1 close calls).
        final int shedCount = factory.shedConnections(1);
        final int actualClosedCount = countConnectionsShed(new ServerCnxn[] {mockCnxn});
        assertEquals(actualClosedCount, shedCount,
            "Reported shed count should match the number of connections actually closed");
    }

    @ParameterizedTest
    @EnumSource(FactoryType.class)
    public void testShedConnections_ErrorHandling(final FactoryType factoryType) throws Exception {
        factory = createFactory(factoryType);

        // Create mock connections where one will fail to close
        final ServerCnxn[] mockCnxns = new ServerCnxn[4];
        for (int i = 0; i < 4; i++) {
            mockCnxns[i] = mock(ServerCnxn.class);
            factory.cnxns.add(mockCnxns[i]);
        }

        // Make the second connection throw an exception when closed
        doThrow(new RuntimeException("Connection close failed"))
            .when(mockCnxns[1]).close(ServerCnxn.DisconnectReason.SHED_CONNECTIONS_COMMAND);

        // Test 100% shedding to ensure error handling works deterministically
        final int shedCount = factory.shedConnections(100);

        // With 100% shedding, all 4 connections should be attempted to close
        // The method returns the count of connections successfully closed
        // Since one connection throws an exception, only 3 should be successfully closed
        assertEquals(3, shedCount, "Should successfully close 3 connections, 1 should fail");

        // Verify that close() was invoked on all 4 connections, including the one that threw
        final int actualClosedCount = countConnectionsShed(mockCnxns);
        assertEquals(4, actualClosedCount, "All 4 connections should have close() called, even if one throws exception");
    }

    /** Instantiates the factory implementation under test without configuring or starting it. */
    private ServerCnxnFactory createFactory(final FactoryType type) {
        switch (type) {
            case NIO:
                return new NIOServerCnxnFactory();
            case NETTY:
                return new NettyServerCnxnFactory();
            default:
                throw new IllegalArgumentException("Unknown factory type: " + type);
        }
    }

    /**
     * Counts the mocks on which {@code close(SHED_CONNECTIONS_COMMAND)} was invoked
     * at least once, whether or not the invocation threw.
     */
    private int countConnectionsShed(final ServerCnxn[] connections) {
        return (int) Arrays.stream(connections)
            .filter(cnxn -> mockingDetails(cnxn).getInvocations().stream()
                .anyMatch(invocation ->
                    invocation.getMethod().getName().equals("close")
                        && invocation.getArguments().length == 1
                        && invocation.getArguments()[0].equals(ServerCnxn.DisconnectReason.SHED_CONNECTIONS_COMMAND)
                ))
            .count();
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,16 @@ public void testStatCommandSecureOnly() {
assertThat(response.toMap().containsKey("secure_connections"), is(true));
}

/**
 * Posts a 25% shed request to the shed_connections command with digest
 * authorization and expects a 200 response carrying both integer result fields.
 */
@Test
public void testShedConnections() throws IOException, InterruptedException {
    final Map<String, String> kwargs = new HashMap<>();
    // Encode the body explicitly as UTF-8 rather than relying on the platform default charset.
    final InputStream inputStream = new ByteArrayInputStream(
        "{\"percentage\": 25}".getBytes(java.nio.charset.StandardCharsets.UTF_8));
    final String authInfo = CommandAuthTest.buildAuthorizationForDigest();
    testCommand("shed_connections", kwargs, inputStream, authInfo, new HashMap<>(), HttpServletResponse.SC_OK,
        new Field("percentage_requested", Integer.class),
        new Field("connections_shed", Integer.class));
}

private void testSnapshot(final boolean streaming) throws IOException, InterruptedException {
System.setProperty(ADMIN_SNAPSHOT_ENABLED, "true");
System.setProperty(ADMIN_RATE_LIMITER_INTERVAL, "0");
Expand Down
Loading