Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/StackExchange.Redis/Enums/CommandFlags.cs
Original file line number Diff line number Diff line change
Expand Up @@ -109,5 +109,7 @@ public enum CommandFlags
// 1024: Removed - was used for async timeout checks; never user-specified, so not visible on the public API

// 2048: Use subscription connection type; never user-specified, so not visible on the public API

// 4096: Identifies handshake completion messages; never user-specified, so not visible on the public API
}
}
7 changes: 6 additions & 1 deletion src/StackExchange.Redis/Message.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ internal abstract partial class Message : ICompletable

private const CommandFlags AskingFlag = (CommandFlags)32,
ScriptUnavailableFlag = (CommandFlags)256,
DemandSubscriptionConnection = (CommandFlags)2048;
DemandSubscriptionConnection = (CommandFlags)2048,
HandshakeCompletionFlag = (CommandFlags)4096;

private const CommandFlags MaskPrimaryServerPreference = CommandFlags.DemandMaster
| CommandFlags.DemandReplica
Expand Down Expand Up @@ -720,6 +721,8 @@ internal void SetWriteTime()

public virtual string CommandString => Command.ToString();

public bool IsHandshakeCompletion => (Flags & HandshakeCompletionFlag) != 0;

/// <summary>
/// Sends this command to the subscription connection rather than the interactive.
/// </summary>
Expand All @@ -742,6 +745,8 @@ internal void SetAsking(bool value)
else Flags &= ~AskingFlag; // and the bits taketh away
}

internal void SetHandshakeCompletion() => Flags |= HandshakeCompletionFlag;

internal void SetNoRedirect() => Flags |= CommandFlags.NoRedirect;

internal void SetPreferPrimary() =>
Expand Down
5 changes: 3 additions & 2 deletions src/StackExchange.Redis/PhysicalBridge.cs
Original file line number Diff line number Diff line change
Expand Up @@ -617,13 +617,14 @@ internal void OnHeartbeat(bool ifConnectedOnly)
// We need to time that out and cleanup the PhysicalConnection if needed, otherwise that reader and socket will remain open
// for the lifetime of the application due to being orphaned, yet still referenced by the active task doing the pipe read.
case (int)State.ConnectedEstablished:
// Track that we should reset the count on the next disconnect, but not do so in a loop
shouldResetConnectionRetryCount = true;
var tmp = physical;
if (tmp != null)
{
if (state == (int)State.ConnectedEstablished)
{
// Track that we should reset the count on the next disconnect, but not do so in a loop, reset
// the connect-retry-count (used for backoff decay etc), and remove any non-responsive flag.
shouldResetConnectionRetryCount = true;
Interlocked.Exchange(ref connectTimeoutRetryCount, 0);
tmp.BridgeCouldBeNull?.ServerEndPoint?.ClearUnselectable(UnselectableFlags.DidNotRespond);
}
Expand Down
10 changes: 8 additions & 2 deletions src/StackExchange.Redis/PhysicalConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -282,15 +282,15 @@ internal void SetProtocol(RedisProtocol value)
}

[System.Diagnostics.CodeAnalysis.SuppressMessage("Microsoft.Usage", "CA2202:Do not dispose objects multiple times", Justification = "Trust me yo")]
internal void Shutdown()
internal void Shutdown(ConnectionFailureType failureType = ConnectionFailureType.ConnectionDisposed)
{
var ioPipe = Interlocked.Exchange(ref _ioPipe, null); // compare to the critical read
var socket = Interlocked.Exchange(ref _socket, null);

if (ioPipe != null)
{
Trace("Disconnecting...");
try { BridgeCouldBeNull?.OnDisconnected(ConnectionFailureType.ConnectionDisposed, this, out _, out _); } catch { }
try { BridgeCouldBeNull?.OnDisconnected(failureType, this, out _, out _); } catch { }
try { ioPipe.Input?.CancelPendingRead(); } catch { }
try { ioPipe.Input?.Complete(); } catch { }
try { ioPipe.Output?.CancelPendingFlush(); } catch { }
Expand Down Expand Up @@ -777,6 +777,12 @@ internal int OnBridgeHeartbeat()
multiplexer.OnAsyncTimeout();
result++;
}
else if (msg.IsHandshakeCompletion)
{
// Critical handshake validation timed out; note that this doesn't have a result-box,
// so doesn't get timed out via the above.
Shutdown(ConnectionFailureType.UnableToConnect);
}
}
else
{
Expand Down
3 changes: 3 additions & 0 deletions src/StackExchange.Redis/ServerEndPoint.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Net;
using System.Runtime.CompilerServices;
Expand Down Expand Up @@ -1068,8 +1069,10 @@ private async Task HandshakeAsync(PhysicalConnection connection, ILogger? log)
}

var tracer = GetTracerMessage(true);
tracer.SetHandshakeCompletion();
tracer = LoggingMessage.Create(log, tracer);
log?.LogInformationSendingCriticalTracer(new(this), tracer.CommandAndKey);
Debug.Assert(tracer.IsHandshakeCompletion, "Tracer message should identify as handshake completion");
await WriteDirectOrQueueFireAndForgetAsync(connection, tracer, ResultProcessor.EstablishConnection).ForAwait();

// Note: this **must** be the last thing on the subscription handshake, because after this
Expand Down
28 changes: 25 additions & 3 deletions tests/StackExchange.Redis.Tests/InProcessTestServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,25 @@ public override void OnClientConnected(RedisClient client, object state)
base.OnClientConnected(client, state);
}

public override void OnClientCompleted(RedisClient client, Exception? fault)
{
if (fault is null)
{
_log?.WriteLine($"[{client}] completed");
}
else
{
_log?.WriteLine($"[{client}] faulted: {fault.Message} ({fault.GetType().Name})");
}
base.OnClientCompleted(client, fault);
}

protected override void OnSkippedReply(RedisClient client)
{
_log?.WriteLine($"[{client}] skipped reply");
base.OnSkippedReply(client);
}

private sealed class InProcTunnel(
InProcessTestServer server,
PipeOptions? pipeOptions = null) : Tunnel
Expand All @@ -189,14 +208,15 @@ private sealed class InProcTunnel(
return base.GetSocketConnectEndpointAsync(endpoint, cancellationToken);
}

public override ValueTask<Stream?> BeforeAuthenticateAsync(
public override async ValueTask<Stream?> BeforeAuthenticateAsync(
EndPoint endpoint,
ConnectionType connectionType,
Socket? socket,
CancellationToken cancellationToken)
{
if (server.TryGetNode(endpoint, out var node))
{
await server.OnAcceptClientAsync(endpoint);
var clientToServer = new Pipe(pipeOptions ?? PipeOptions.Default);
var serverToClient = new Pipe(pipeOptions ?? PipeOptions.Default);
var serverSide = new Duplex(clientToServer.Reader, serverToClient.Writer);
Expand All @@ -211,9 +231,9 @@ private sealed class InProcTunnel(
var readStream = serverToClient.Reader.AsStream();
var writeStream = clientToServer.Writer.AsStream();
var clientSide = new DuplexStream(readStream, writeStream);
return new(clientSide);
return clientSide;
}
return base.BeforeAuthenticateAsync(endpoint, connectionType, socket, cancellationToken);
return await base.BeforeAuthenticateAsync(endpoint, connectionType, socket, cancellationToken);
}

private sealed class Duplex(PipeReader input, PipeWriter output) : IDuplexPipe
Expand All @@ -230,6 +250,8 @@ public ValueTask Dispose()
}
}

protected virtual ValueTask OnAcceptClientAsync(EndPoint endpoint) => default;

/*

private readonly RespServer _server;
Expand Down
205 changes: 205 additions & 0 deletions tests/StackExchange.Redis.Tests/RetryPolicyUnitTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Linq;
using System.Net;
using System.Net.Sockets;
using System.Threading;
using System.Threading.Tasks;
using StackExchange.Redis.Server;
using Xunit;

namespace StackExchange.Redis.Tests;

public class RetryPolicyUnitTests(ITestOutputHelper log)
{
[Theory]
[InlineData(FailureMode.Success)]
[InlineData(FailureMode.ConnectionRefused)]
[InlineData(FailureMode.SlowNonConnect)]
[InlineData(FailureMode.NoResponses)]
[InlineData(FailureMode.GarbageResponses)]
public async Task RetryPolicyFailureCases(FailureMode failureMode)
{
using var server = new NonResponsiveServer(log);
var options = server.GetClientConfig(withPubSub: false);
var policy = new CountingRetryPolicy();
options.ConnectRetry = 5;
options.SyncTimeout = options.AsyncTimeout = options.ConnectTimeout = 1_000;
options.ReconnectRetryPolicy = policy;

// connect while the server is stable
await using var conn = await ConnectionMultiplexer.ConnectAsync(options);
var db = conn.GetDatabase();
db.Ping();
Assert.Equal(0, policy.Clear());

// now tell the server to become non-responsive to the next 2, and kill the current
server.FailNext(2, failureMode);
server.ForAllClients(x => x.Kill());

for (int i = 0; i < 10; i++)
{
try
{
await db.PingAsync();
break;
}
catch (Exception ex)
{
log.WriteLine($"{nameof(db.PingAsync)} attempt {i}: {ex.GetType().Name}: {ex.Message}");
}
}
var counts = policy.GetRetryCounts();
if (failureMode is FailureMode.Success)
{
Assert.Empty(counts);
}
else
{
Assert.Equal("0,1", string.Join(",", counts));
}
}

private sealed class CountingRetryPolicy : IReconnectRetryPolicy
{
private readonly struct RetryRequest(int currentRetryCount, int timeElapsedMillisecondsSinceLastRetry)
{
public int CurrentRetryCount { get; } = currentRetryCount;
public int TimeElapsedMillisecondsSinceLastRetry { get; } = timeElapsedMillisecondsSinceLastRetry;
}
private readonly List<RetryRequest> retryCounts = [];

public int Clear()
{
lock (retryCounts)
{
int count = retryCounts.Count;
retryCounts.Clear();
return count;
}
}

public int[] GetRetryCounts()
{
lock (retryCounts)
{
return retryCounts.Select(x => x.CurrentRetryCount).ToArray();
}
}

public bool ShouldRetry(long currentRetryCount, int timeElapsedMillisecondsSinceLastRetry)
{
lock (retryCounts)
{
retryCounts.Add(new(checked((int)currentRetryCount), timeElapsedMillisecondsSinceLastRetry));
}
return true;
}
}

public enum FailureMode
{
Success,
SlowNonConnect,
ConnectionRefused,
NoResponses,
GarbageResponses,
}
private sealed class NonResponsiveServer(ITestOutputHelper log) : InProcessTestServer(log)
{
private int _failNext;
private FailureMode _failureMode;

public void FailNext(int count, FailureMode failureMode)
{
_failNext = count;
_failureMode = failureMode;
}

protected override ValueTask OnAcceptClientAsync(EndPoint endpoint)
{
switch (_failureMode)
{
case FailureMode.SlowNonConnect when ShouldIgnoreClient():
Log($"(leaving pending connect to {endpoint})");
return TimeoutEventually();
case FailureMode.ConnectionRefused when ShouldIgnoreClient():
Log($"(rejecting connection to {endpoint})");
throw new SocketException((int)SocketError.ConnectionRefused);
default:
return base.OnAcceptClientAsync(endpoint);
}

static async ValueTask TimeoutEventually()
{
await Task.Delay(TimeSpan.FromMinutes(5)).ConfigureAwait(false);
throw new TimeoutException();
}
}

private bool ShouldIgnoreClient()
{
while (true)
{
var oldValue = Volatile.Read(ref _failNext);
if (oldValue <= 0) return false;
var newValue = oldValue - 1;
if (Interlocked.CompareExchange(ref _failNext, newValue, oldValue) == oldValue) return true;
}
}

private sealed class GarbageClient(Node node) : RedisClient(node)
{
protected override void WriteResponse(
IBufferWriter<byte> output,
TypedRedisValue value,
RedisProtocol protocol)
{
#if NET
var rand = Random.Shared;
#else
var rand = new Random();
#endif
var len = rand.Next(1, 1024);
var buffer = ArrayPool<byte>.Shared.Rent(len);
var span = buffer.AsSpan(0, len);
try
{
#if NET
rand.NextBytes(span);
#else
rand.NextBytes(buffer);
#endif
output.Write(span);
}
finally
{
ArrayPool<byte>.Shared.Return(buffer);
}
}
}

public override RedisClient CreateClient(Node node)
{
RedisClient client;
if (_failureMode is FailureMode.GarbageResponses && ShouldIgnoreClient())
{
client = new GarbageClient(node);
Log($"(accepting garbage-responsive connection to {node.Host}:{node.Port})");
return client;
}
client = base.CreateClient(node);
if (_failureMode is FailureMode.NoResponses && ShouldIgnoreClient())
{
Log($"(accepting non-responsive connection to {node.Host}:{node.Port})");
client.SkipAllReplies();
}
else
{
Log($"(accepting responsive connection to {node.Host}:{node.Port})");
}
return client;
}
}
}
Loading
Loading