Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions api/src/main/java/io/grpc/EquivalentAddressGroup.java
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,15 @@ public final class EquivalentAddressGroup {
*/
public static final Attributes.Key<String> ATTR_LOCALITY_NAME =
Attributes.Key.create("io.grpc.EquivalentAddressGroup.LOCALITY");
/**
* Endpoint weight for load balancing purposes. While the type is Long, it must be a valid uint32.
* Must not be zero. The weight is proportional to the other endpoints; if an endpoint's weight is
* twice that of another endpoint, it is intended to receive twice the load.
*/
@Attr
static final Attributes.Key<Long> ATTR_WEIGHT =
Attributes.Key.create("io.grpc.EquivalentAddressGroup.ATTR_WEIGHT");

private final List<SocketAddress> addrs;
private final Attributes attrs;

Expand Down
29 changes: 29 additions & 0 deletions api/src/main/java/io/grpc/InternalEquivalentAddressGroup.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Copyright 2026 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.grpc;

@Internal
public final class InternalEquivalentAddressGroup {
private InternalEquivalentAddressGroup() {}

/**
* Endpoint weight for load balancing purposes. While the type is Long, it must be a valid uint32.
* Must not be zero. The weight is proportional to the other endpoints; if an endpoint's weight is
* twice that of another endpoint, it is intended to receive twice the load.
*/
public static final Attributes.Key<Long> ATTR_WEIGHT = EquivalentAddressGroup.ATTR_WEIGHT;
}
50 changes: 47 additions & 3 deletions core/src/main/java/io/grpc/internal/PickFirstLeafLoadBalancer.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,12 @@
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.errorprone.annotations.CheckReturnValue;
import io.grpc.Attributes;
import io.grpc.ConnectivityState;
import io.grpc.ConnectivityStateInfo;
import io.grpc.EquivalentAddressGroup;
import io.grpc.InternalEquivalentAddressGroup;
import io.grpc.LoadBalancer;
import io.grpc.Status;
import io.grpc.SynchronizationContext.ScheduledHandle;
Expand Down Expand Up @@ -61,6 +63,8 @@ final class PickFirstLeafLoadBalancer extends LoadBalancer {
static final int CONNECTION_DELAY_INTERVAL_MS = 250;
private final boolean enableHappyEyeballs = !isSerializingRetries()
&& PickFirstLoadBalancerProvider.isEnabledHappyEyeballs();
static boolean weightedShuffling =
GrpcUtil.getFlag("GRPC_EXPERIMENTAL_PF_WEIGHTED_SHUFFLING", true);
private final Helper helper;
private final Map<SocketAddress, SubchannelData> subchannels = new HashMap<>();
private final Index addressIndex = new Index(ImmutableList.of(), this.enableHappyEyeballs);
Expand Down Expand Up @@ -128,13 +132,13 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) {
PickFirstLeafLoadBalancerConfig config
= (PickFirstLeafLoadBalancerConfig) resolvedAddresses.getLoadBalancingPolicyConfig();
if (config.shuffleAddressList != null && config.shuffleAddressList) {
Collections.shuffle(cleanServers,
config.randomSeed != null ? new Random(config.randomSeed) : new Random());
cleanServers = shuffle(
cleanServers, config.randomSeed != null ? new Random(config.randomSeed) : new Random());
}
}

final ImmutableList<EquivalentAddressGroup> newImmutableAddressGroups =
ImmutableList.<EquivalentAddressGroup>builder().addAll(cleanServers).build();
ImmutableList.copyOf(cleanServers);

if (rawConnectivityState == READY
|| (rawConnectivityState == CONNECTING
Expand Down Expand Up @@ -224,6 +228,46 @@ private static List<EquivalentAddressGroup> deDupAddresses(List<EquivalentAddres
return newGroups;
}

// Also used by PickFirstLoadBalancer
@CheckReturnValue
static List<EquivalentAddressGroup> shuffle(List<EquivalentAddressGroup> eags, Random random) {
if (weightedShuffling) {
List<WeightEntry> weightedEntries = new ArrayList<>(eags.size());
for (EquivalentAddressGroup eag : eags) {
weightedEntries.add(new WeightEntry(eag, eagToWeight(eag, random)));
}
Collections.sort(weightedEntries, Collections.reverseOrder() /* descending */);
return Lists.transform(weightedEntries, entry -> entry.eag);
} else {
List<EquivalentAddressGroup> eagsCopy = new ArrayList<>(eags);
Collections.shuffle(eagsCopy, random);
return eagsCopy;
}
}

private static double eagToWeight(EquivalentAddressGroup eag, Random random) {
Long weight = eag.getAttributes().get(InternalEquivalentAddressGroup.ATTR_WEIGHT);
if (weight == null) {
weight = 1L;
}
return Math.pow(random.nextDouble(), 1.0 / weight);
}

private static final class WeightEntry implements Comparable<WeightEntry> {
final EquivalentAddressGroup eag;
final double weight;

public WeightEntry(EquivalentAddressGroup eag, double weight) {
this.eag = eag;
this.weight = weight;
}

@Override
public int compareTo(WeightEntry entry) {
return Double.compare(this.weight, entry.weight);
}
}

@Override
public void handleNameResolutionError(Status error) {
if (rawConnectivityState == SHUTDOWN) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@
import io.grpc.EquivalentAddressGroup;
import io.grpc.LoadBalancer;
import io.grpc.Status;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.concurrent.atomic.AtomicBoolean;
Expand Down Expand Up @@ -65,9 +63,8 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) {
PickFirstLoadBalancerConfig config
= (PickFirstLoadBalancerConfig) resolvedAddresses.getLoadBalancingPolicyConfig();
if (config.shuffleAddressList != null && config.shuffleAddressList) {
servers = new ArrayList<EquivalentAddressGroup>(servers);
Collections.shuffle(servers,
config.randomSeed != null ? new Random(config.randomSeed) : new Random());
servers = PickFirstLeafLoadBalancer.shuffle(
servers, config.randomSeed != null ? new Random(config.randomSeed) : new Random());
}
}

Expand Down
146 changes: 139 additions & 7 deletions core/src/test/java/io/grpc/internal/PickFirstLeafLoadBalancerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import static io.grpc.ConnectivityState.READY;
import static io.grpc.ConnectivityState.SHUTDOWN;
import static io.grpc.ConnectivityState.TRANSIENT_FAILURE;
import static io.grpc.InternalEquivalentAddressGroup.ATTR_WEIGHT;
import static io.grpc.LoadBalancer.HAS_HEALTH_PRODUCER_LISTENER_KEY;
import static io.grpc.LoadBalancer.HEALTH_CONSUMER_LISTENER_ARG_KEY;
import static io.grpc.LoadBalancer.IS_PETIOLE_POLICY;
Expand Down Expand Up @@ -70,10 +71,13 @@
import io.grpc.internal.PickFirstLeafLoadBalancer.PickFirstLeafLoadBalancerConfig;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Queue;
import java.util.Random;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import org.junit.After;
Expand Down Expand Up @@ -149,6 +153,7 @@ public void uncaughtException(Thread t, Throwable e) {

private String originalHappyEyeballsEnabledValue;
private String originalSerializeRetriesValue;
private boolean originalWeightedShuffling;

private long backoffMillis;

Expand All @@ -165,6 +170,8 @@ public void setUp() {
System.setProperty(PickFirstLoadBalancerProvider.GRPC_PF_USE_HAPPY_EYEBALLS,
Boolean.toString(enableHappyEyeballs));

originalWeightedShuffling = PickFirstLeafLoadBalancer.weightedShuffling;

for (int i = 1; i <= 5; i++) {
SocketAddress addr = new FakeSocketAddress("server" + i);
servers.add(new EquivalentAddressGroup(addr));
Expand Down Expand Up @@ -207,6 +214,7 @@ public void tearDown() {
System.setProperty(PickFirstLoadBalancerProvider.GRPC_PF_USE_HAPPY_EYEBALLS,
originalHappyEyeballsEnabledValue);
}
PickFirstLeafLoadBalancer.weightedShuffling = originalWeightedShuffling;

loadBalancer.shutdown();
verifyNoMoreInteractions(mockArgs);
Expand Down Expand Up @@ -242,6 +250,12 @@ public void pickAfterResolved() {
verifyNoMoreInteractions(mockHelper);
}

@Test
public void pickAfterResolved_shuffle_oppositeWeightedShuffling() {
PickFirstLeafLoadBalancer.weightedShuffling = !PickFirstLeafLoadBalancer.weightedShuffling;
pickAfterResolved_shuffle();
}

@Test
public void pickAfterResolved_shuffle() {
servers.remove(4);
Expand Down Expand Up @@ -305,6 +319,103 @@ public void pickAfterResolved_noShuffle() {
assertNotNull(pickerCaptor.getValue().pickSubchannel(mockArgs));
}

@Test
public void pickAfterResolved_shuffleImplicitUniform_oppositeWeightedShuffling() {
PickFirstLeafLoadBalancer.weightedShuffling = !PickFirstLeafLoadBalancer.weightedShuffling;
pickAfterResolved_shuffleImplicitUniform();
}

@Test
public void pickAfterResolved_shuffleImplicitUniform() {
EquivalentAddressGroup eag1 = new EquivalentAddressGroup(new FakeSocketAddress("server1"));
EquivalentAddressGroup eag2 = new EquivalentAddressGroup(new FakeSocketAddress("server2"));
EquivalentAddressGroup eag3 = new EquivalentAddressGroup(new FakeSocketAddress("server3"));

int[] counts = countAddressSelections(99, Arrays.asList(eag1, eag2, eag3));
assertThat(counts[0]).isWithin(7).of(33);
assertThat(counts[1]).isWithin(7).of(33);
assertThat(counts[2]).isWithin(7).of(33);
}

@Test
public void pickAfterResolved_shuffleExplicitUniform_oppositeWeightedShuffling() {
PickFirstLeafLoadBalancer.weightedShuffling = !PickFirstLeafLoadBalancer.weightedShuffling;
pickAfterResolved_shuffleExplicitUniform();
}

@Test
public void pickAfterResolved_shuffleExplicitUniform() {
EquivalentAddressGroup eag1 = new EquivalentAddressGroup(
new FakeSocketAddress("server1"), Attributes.newBuilder().set(ATTR_WEIGHT, 111L).build());
EquivalentAddressGroup eag2 = new EquivalentAddressGroup(
new FakeSocketAddress("server2"), Attributes.newBuilder().set(ATTR_WEIGHT, 111L).build());
EquivalentAddressGroup eag3 = new EquivalentAddressGroup(
new FakeSocketAddress("server3"), Attributes.newBuilder().set(ATTR_WEIGHT, 111L).build());

int[] counts = countAddressSelections(99, Arrays.asList(eag1, eag2, eag3));
assertThat(counts[0]).isWithin(7).of(33);
assertThat(counts[1]).isWithin(7).of(33);
assertThat(counts[2]).isWithin(7).of(33);
}

@Test
public void pickAfterResolved_shuffleWeighted_noWeightedShuffling() {
PickFirstLeafLoadBalancer.weightedShuffling = false;
EquivalentAddressGroup eag1 = new EquivalentAddressGroup(
new FakeSocketAddress("server1"), Attributes.newBuilder().set(ATTR_WEIGHT, 12L).build());
EquivalentAddressGroup eag2 = new EquivalentAddressGroup(
new FakeSocketAddress("server2"), Attributes.newBuilder().set(ATTR_WEIGHT, 3L).build());
EquivalentAddressGroup eag3 = new EquivalentAddressGroup(
new FakeSocketAddress("server3"), Attributes.newBuilder().set(ATTR_WEIGHT, 1L).build());

int[] counts = countAddressSelections(100, Arrays.asList(eag1, eag2, eag3));
assertThat(counts[0]).isWithin(7).of(33);
assertThat(counts[1]).isWithin(7).of(33);
assertThat(counts[2]).isWithin(7).of(33);
}

@Test
public void pickAfterResolved_shuffleWeighted_weightedShuffling() {
PickFirstLeafLoadBalancer.weightedShuffling = true;
EquivalentAddressGroup eag1 = new EquivalentAddressGroup(
new FakeSocketAddress("server1"), Attributes.newBuilder().set(ATTR_WEIGHT, 12L).build());
EquivalentAddressGroup eag2 = new EquivalentAddressGroup(
new FakeSocketAddress("server2"), Attributes.newBuilder().set(ATTR_WEIGHT, 3L).build());
EquivalentAddressGroup eag3 = new EquivalentAddressGroup(
new FakeSocketAddress("server3"), Attributes.newBuilder().set(ATTR_WEIGHT, 1L).build());

int[] counts = countAddressSelections(100, Arrays.asList(eag1, eag2, eag3));
assertThat(counts[0]).isWithin(7).of(75); // 100*12/16
assertThat(counts[1]).isWithin(7).of(19); // 100*3/16
assertThat(counts[2]).isWithin(7).of(6); // 100*1/16
}

/** Returns int[index_of_eag] array with number of times each eag was selected. */
private int[] countAddressSelections(int trials, List<EquivalentAddressGroup> eags) {
int[] counts = new int[eags.size()];
Random random = new Random(1);
for (int i = 0; i < trials; i++) {
RecordingHelper helper = new RecordingHelper();
LoadBalancer lb = new PickFirstLeafLoadBalancer(helper);
assertThat(lb.acceptResolvedAddresses(ResolvedAddresses.newBuilder()
.setAddresses(eags)
.setAttributes(affinity)
.setLoadBalancingPolicyConfig(
new PickFirstLeafLoadBalancerConfig(true, random.nextLong()))
.build()))
.isSameInstanceAs(Status.OK);
helper.subchannels.remove().listener.onSubchannelState(
ConnectivityStateInfo.forNonError(READY));

assertThat(helper.state).isEqualTo(READY);
Subchannel subchannel = helper.picker.pickSubchannel(mockArgs).getSubchannel();
counts[eags.indexOf(subchannel.getAddresses())]++;

lb.shutdown();
}
return counts;
}

@Test
public void requestConnectionPicker() {
// Set up
Expand Down Expand Up @@ -2945,13 +3056,7 @@ public String toString() {
}
}

private class MockHelperImpl extends LoadBalancer.Helper {
private final List<Subchannel> subchannels;

public MockHelperImpl(List<? extends Subchannel> subchannels) {
this.subchannels = new ArrayList<Subchannel>(subchannels);
}

private class BaseHelper extends LoadBalancer.Helper {
@Override
public ManagedChannel createOobChannel(EquivalentAddressGroup eag, String authority) {
return null;
Expand Down Expand Up @@ -2981,6 +3086,14 @@ public ScheduledExecutorService getScheduledExecutorService() {
public void refreshNameResolution() {
// noop
}
}

private class MockHelperImpl extends BaseHelper {
private final List<Subchannel> subchannels;

public MockHelperImpl(List<? extends Subchannel> subchannels) {
this.subchannels = new ArrayList<Subchannel>(subchannels);
}

@Override
public Subchannel createSubchannel(CreateSubchannelArgs args) {
Expand All @@ -2997,4 +3110,23 @@ public Subchannel createSubchannel(CreateSubchannelArgs args) {
throw new IllegalArgumentException("Unexpected addresses: " + args.getAddresses());
}
}

class RecordingHelper extends BaseHelper {
ConnectivityState state;
SubchannelPicker picker;
final Queue<FakeSubchannel> subchannels = new ArrayDeque<>();

@Override
public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) {
this.state = newState;
this.picker = newPicker;
}

@Override
public Subchannel createSubchannel(CreateSubchannelArgs args) {
FakeSubchannel subchannel = new FakeSubchannel(args.getAddresses(), args.getAttributes());
subchannels.add(subchannel);
return subchannel;
}
}
}
Loading