Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/OpenClaw.Connection/GatewayClientFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public GatewayClientLifecycleAdapter(OpenClawGatewayClient client)
public event EventHandler<ConnectionStatus>? StatusChanged;
public event EventHandler<string>? AuthenticationFailed;

public Task ConnectAsync(CancellationToken ct) => _client.ConnectAsync();
public Task ConnectAsync(CancellationToken ct) => _client.ConnectAsync(ct);

/// <summary>Disposes the wrapped <c>_client</c>.</summary>
public void Dispose() => _client.Dispose();
}
37 changes: 28 additions & 9 deletions src/OpenClaw.Connection/GatewayConnectionManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,10 @@ private async Task ConnectCoreAsync(string? gatewayId = null)
catch (OperationCanceledException) { }
catch (Exception ex)
{
_logger.Error($"[ConnMgr] Connect failed: {ex.Message}");
if (Interlocked.Read(ref _generation) == gen)
{
_logger.Error($"[ConnMgr] Connect failed: {ex.Message}");
}
}
}, ct);
}
Expand All @@ -276,6 +279,11 @@ public async Task DisconnectAsync()
/// <summary>Core disconnect logic. Caller must hold <see cref="_transitionSemaphore"/>.</summary>
private void DisconnectCore()
{
Interlocked.Increment(ref _generation);
var oldCts = Interlocked.Exchange(ref _operationCts, null);
oldCts?.Cancel();
oldCts?.Dispose();

var prev = _stateMachine.Current.OverallState;
DisposeActiveClient();
_stateMachine.TryTransition(ConnectionTrigger.DisconnectRequested);
Expand Down Expand Up @@ -533,7 +541,7 @@ private async Task HandleHandshakeSucceededAsync(long gen)
// Start node connection outside the semaphore to avoid deadlocks
if (_nodeConnector != null && ShouldStartNodeConnection())
{
await StartNodeConnectionAsync();
await StartNodeConnectionAsync(gen);
}
}

Expand Down Expand Up @@ -653,7 +661,7 @@ void Handler(object? _, GatewayConnectionSnapshot s)
StateChanged += Handler;
try
{
var startAttempted = await StartNodeConnectionAsync();
var startAttempted = await StartNodeConnectionAsync(Interlocked.Read(ref _generation));

if (!startAttempted)
{
Expand Down Expand Up @@ -706,19 +714,26 @@ private bool ShouldStartNodeConnection()
return _isNodeEnabled?.Invoke() ?? false;
}

private async Task<bool> StartNodeConnectionAsync()
/// <summary>
/// Reports whether <paramref name="generation"/> is out of date relative to the current
/// <c>_generation</c> counter.
/// </summary>
/// <param name="generation">Generation captured by the caller, or <c>null</c> to opt out of the check.</param>
/// <returns><c>true</c> only when a generation was supplied and it no longer matches the current one.</returns>
private bool IsGenerationStale(long? generation)
{
    // A caller that passes no generation is never treated as stale.
    if (!generation.HasValue)
    {
        return false;
    }

    var current = Interlocked.Read(ref _generation);
    return current != generation.Value;
}

private async Task<bool> StartNodeConnectionAsync(long? generation = null)
{
if (_nodeConnector == null || _activeGatewayRecordId == null || _activeIdentityPath == null) return false;
if (IsGenerationStale(generation)) return false;

var record = _registry.GetById(_activeGatewayRecordId);
var activeGatewayRecordId = _activeGatewayRecordId;
var activeIdentityPath = _activeIdentityPath;
if (_nodeConnector == null || activeGatewayRecordId == null || activeIdentityPath == null) return false;

var record = _registry.GetById(activeGatewayRecordId);
if (record == null)
{
_logger.Warn("[ConnMgr] Cannot start node — gateway record not found");
return false;
}

// Use root identity path — clients always read/write from root, not per-gateway
var nodeCredential = _credentialResolver.ResolveNode(record, _activeIdentityPath!);
var nodeCredential = _credentialResolver.ResolveNode(record, activeIdentityPath);
if (nodeCredential == null)
{
_logger.Warn("[ConnMgr] No node credential available — skipping node connection");
Expand All @@ -731,6 +746,7 @@ private async Task<bool> StartNodeConnectionAsync()
await _transitionSemaphore.WaitAsync();
try
{
if (IsGenerationStale(generation)) return false;
_stateMachine.SetNodeEnabled(true);
}
finally
Expand All @@ -747,7 +763,9 @@ private async Task<bool> StartNodeConnectionAsync()

try
{
await _nodeConnector.ConnectAsync(nodeConnectUrl, nodeCredential, _activeIdentityPath,
if (IsGenerationStale(generation)) return false;

await _nodeConnector.ConnectAsync(nodeConnectUrl, nodeCredential, activeIdentityPath,
useV2Signature: _gatewayNeedsV2Signature);
}
catch (Exception ex)
Expand Down Expand Up @@ -809,6 +827,7 @@ private async void OnNodeStatusChanged(object? sender, ConnectionStatus status)

private async void OnNodePairingStatusChanged(object? sender, PairingStatusEventArgs e)
{
var handlerGeneration = Interlocked.Read(ref _generation);
_diagnostics.Record("node", $"Node pairing: {e.Status}");

await _transitionSemaphore.WaitAsync();
Expand Down Expand Up @@ -871,7 +890,7 @@ private async void OnNodePairingStatusChanged(object? sender, PairingStatusEvent
_lastAutoApprovedRequestId = e.RequestId;
_diagnostics.Record("node", "Node pairing auto-approved — reconnecting node");
await Task.Delay(1000); // brief delay for gateway to process
await StartNodeConnectionAsync();
await StartNodeConnectionAsync(handlerGeneration);
}
else
{
Expand Down
129 changes: 103 additions & 26 deletions src/OpenClaw.Shared/WebSocketClientBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ public abstract class WebSocketClientBase : IDisposable
private bool _disposed;
private int _reconnectAttempts;
private int _reconnectLoopActive;
private long _connectionGeneration;
private readonly SemaphoreSlim _sendSemaphore = new(1, 1);
private static readonly int[] BackoffMs = { 1000, 2000, 4000, 8000, 15000, 30000, 60000 };

protected readonly string _token;
Expand Down Expand Up @@ -102,57 +104,89 @@ protected WebSocketClientBase(string gatewayUrl, string token, IOpenClawLogger?
_cts = new CancellationTokenSource();
}

public async Task ConnectAsync()
public async Task ConnectAsync(CancellationToken cancellationToken = default)
{
if (_disposed)
{
_logger.Debug($"Skipping {ClientRole} connect: client already disposed");
return;
}

var connectGeneration = Interlocked.Increment(ref _connectionGeneration);
ClientWebSocket? ws = null;

try
{
RaiseStatusChanged(ConnectionStatus.Connecting);
_logger.Info($"Connecting to {ClientRole}: {GatewayUrlForDisplay}");

_webSocket = new ClientWebSocket();
_webSocket.Options.KeepAliveInterval = TimeSpan.FromSeconds(30);
ws = new ClientWebSocket();
ws.Options.KeepAliveInterval = TimeSpan.FromSeconds(30);
_webSocket = ws;

// Set Origin header (convert ws/wss to http/https)
var uri = new Uri(_gatewayUrl);
var originScheme = uri.Scheme == "wss" ? "https" : "http";
var origin = $"{originScheme}://{uri.Host}:{uri.Port}";
_webSocket.Options.SetRequestHeader("Origin", origin);
ws.Options.SetRequestHeader("Origin", origin);

if (!string.IsNullOrEmpty(_credentials))
{
var credentialsToEncode = GatewayUrlHelper.DecodeCredentials(_credentials);
_webSocket.Options.SetRequestHeader(
ws.Options.SetRequestHeader(
"Authorization",
$"Basic {Convert.ToBase64String(Encoding.UTF8.GetBytes(credentialsToEncode))}");
}

await _webSocket.ConnectAsync(uri, _cts.Token);
using var connectCts = CancellationTokenSource.CreateLinkedTokenSource(_cts.Token, cancellationToken);
var connectToken = connectCts.Token;

await ws.ConnectAsync(uri, connectToken);
if (!IsCurrentConnection(ws, connectGeneration) || connectToken.IsCancellationRequested)
{
DisposeStaleSocket(ws);
return;
}

// Don't reset _reconnectAttempts here — TCP connect succeeding doesn't mean
// auth will succeed. Reset only after the full application-level handshake
// completes (subclass calls ResetReconnectAttempts after hello-ok).
_logger.Info($"{ClientRole} connected, waiting for challenge...");

await OnConnectedAsync();
if (!IsCurrentConnection(ws, connectGeneration) || connectToken.IsCancellationRequested)
{
DisposeStaleSocket(ws);
return;
}

_ = Task.Run(() => ListenForMessagesAsync(), _cts.Token);
_ = Task.Run(() => ListenForMessagesAsync(ws, _cts.Token, connectGeneration), _cts.Token);
}
catch (OperationCanceledException)
{
_logger.Debug($"{ClientRole} connect canceled (likely shutdown)");
if (ws != null && ReferenceEquals(_webSocket, ws))
{
_webSocket = null;
try { ws.Dispose(); } catch { /* ignore dispose errors */ }
}
}
catch (ObjectDisposedException)
{
_logger.Debug($"{ClientRole} connect aborted after dispose");
if (ws != null && ReferenceEquals(_webSocket, ws))
{
_webSocket = null;
try { ws.Dispose(); } catch { /* ignore dispose errors */ }
}
}
catch (Exception ex)
{
if (ws != null && ReferenceEquals(_webSocket, ws))
{
_webSocket = null;
try { ws.Dispose(); } catch { /* ignore dispose errors */ }
}
_logger.Error($"{ClientRole} connection failed", ex);
RaiseStatusChanged(ConnectionStatus.Error);

Expand All @@ -163,7 +197,22 @@ public async Task ConnectAsync()
}
}

private async Task ListenForMessagesAsync()
/// <summary>
/// Checks whether <paramref name="ws"/> / <paramref name="generation"/> still identify the live connection.
/// </summary>
/// <param name="ws">Socket instance captured by the caller.</param>
/// <param name="generation">Connection generation captured by the caller.</param>
/// <returns>
/// <c>true</c> only when the client is not disposed, <paramref name="generation"/> equals the current
/// <c>_connectionGeneration</c>, and <paramref name="ws"/> is the same instance as <c>_webSocket</c>.
/// </returns>
private bool IsCurrentConnection(ClientWebSocket ws, long generation)
{
    if (_disposed)
    {
        return false;
    }

    if (Interlocked.Read(ref _connectionGeneration) != generation)
    {
        return false;
    }

    return ReferenceEquals(_webSocket, ws);
}

/// <summary>
/// Disposes <paramref name="ws"/>, first clearing <c>_webSocket</c> if it still refers to that same instance.
/// Exceptions thrown by the dispose call are swallowed.
/// </summary>
/// <param name="ws">Socket to release.</param>
private void DisposeStaleSocket(ClientWebSocket ws)
{
    var isPublishedSocket = ReferenceEquals(_webSocket, ws);
    if (isPublishedSocket)
    {
        _webSocket = null;
    }

    try
    {
        ws.Dispose();
    }
    catch
    {
        /* ignore dispose errors */
    }
}

private async Task ListenForMessagesAsync(ClientWebSocket ws, CancellationToken cancellationToken, long connectionGeneration)
{
// Rent a pooled buffer — consistent with the SendRawAsync hot path; avoids a large
// (16–64 KB) heap allocation per connection that would otherwise land on the LOH.
Expand All @@ -172,10 +221,10 @@ private async Task ListenForMessagesAsync()

try
{
while (_webSocket?.State == WebSocketState.Open && !_cts.Token.IsCancellationRequested)
while (ws.State == WebSocketState.Open && !cancellationToken.IsCancellationRequested)
{
var result = await _webSocket.ReceiveAsync(
new ArraySegment<byte>(buffer, 0, ReceiveBufferSize), _cts.Token);
var result = await ws.ReceiveAsync(
new ArraySegment<byte>(buffer, 0, ReceiveBufferSize), cancellationToken);

if (result.MessageType == WebSocketMessageType.Text)
{
Expand All @@ -197,40 +246,49 @@ private async Task ListenForMessagesAsync()
}
else if (result.MessageType == WebSocketMessageType.Close)
{
var closeStatus = _webSocket.CloseStatus?.ToString() ?? "unknown";
var closeDesc = _webSocket.CloseStatusDescription ?? "no description";
var closeStatus = ws.CloseStatus?.ToString() ?? "unknown";
var closeDesc = ws.CloseStatusDescription ?? "no description";
_logger.Info($"Server closed connection: {closeStatus} - {closeDesc}");
OnDisconnected();
RaiseStatusChanged(ConnectionStatus.Disconnected);
if (IsCurrentConnection(ws, connectionGeneration))
{
OnDisconnected();
RaiseStatusChanged(ConnectionStatus.Disconnected);
}
break;
}
}
}
catch (WebSocketException ex) when (ex.WebSocketErrorCode == WebSocketError.ConnectionClosedPrematurely)
{
_logger.Warn("Connection closed prematurely");
OnDisconnected();
RaiseStatusChanged(ConnectionStatus.Disconnected);
if (IsCurrentConnection(ws, connectionGeneration))
{
OnDisconnected();
RaiseStatusChanged(ConnectionStatus.Disconnected);
}
}
catch (OperationCanceledException) { }
catch (ObjectDisposedException) { /* CTS or WebSocket disposed during shutdown */ }
catch (Exception ex)
{
_logger.Error($"{ClientRole} listen error", ex);
OnError(ex);
RaiseStatusChanged(ConnectionStatus.Error);
if (IsCurrentConnection(ws, connectionGeneration))
{
OnError(ex);
RaiseStatusChanged(ConnectionStatus.Error);
}
}
finally
{
ArrayPool<byte>.Shared.Return(buffer);
}

// Auto-reconnect if not intentionally disposed
if (!_disposed)
if (!_disposed && IsCurrentConnection(ws, connectionGeneration))
{
try
{
if (!_cts.Token.IsCancellationRequested && ShouldAutoReconnect())
if (!cancellationToken.IsCancellationRequested && ShouldAutoReconnect())
{
await ReconnectWithBackoffAsync();
}
Expand Down Expand Up @@ -295,20 +353,19 @@ protected async Task ReconnectWithBackoffAsync()
/// <summary>Send a text message over the WebSocket. Thread-safe.</summary>
protected async Task SendRawAsync(string message)
{
// Capture local reference to avoid TOCTOU race with reconnect/dispose
var ws = _webSocket;
if (ws?.State != WebSocketState.Open) return;
await _sendSemaphore.WaitAsync(_cts.Token);

try
{
if (!CanSendRaw) return;

// Rent a pooled buffer to avoid per-send heap allocations on the hot send path.
var byteCount = Encoding.UTF8.GetByteCount(message);
var buffer = ArrayPool<byte>.Shared.Rent(byteCount);
try
{
var written = Encoding.UTF8.GetBytes(message, buffer);
await ws.SendAsync(buffer.AsMemory(0, written),
WebSocketMessageType.Text, true, _cts.Token);
await SendWebSocketTextAsync(buffer.AsMemory(0, written), _cts.Token);
}
finally
{
Expand All @@ -323,6 +380,25 @@ await ws.SendAsync(buffer.AsMemory(0, written),
{
_logger.Warn($"WebSocket send failed (state changed): {ex.Message}");
}
finally
{
_sendSemaphore.Release();
}
}

/// <summary>
/// Send gate, overridable as a test seam. The default implementation allows sending only while
/// the current <c>_webSocket</c> is non-null and in the <see cref="WebSocketState.Open"/> state.
/// </summary>
protected virtual bool CanSendRaw
{
    get
    {
        // Read the field once so the null check and the state check see the same instance.
        var socket = _webSocket;
        return socket != null && socket.State == WebSocketState.Open;
    }
}

/// <summary>
/// Performs the actual WebSocket text send; overridable as a test seam. Sends
/// <paramref name="payload"/> as a single, final text frame on the current <c>_webSocket</c>.
/// </summary>
/// <param name="payload">Encoded message bytes to send.</param>
/// <param name="cancellationToken">Token forwarded to the socket send.</param>
/// <returns>
/// The pending send, or an already-completed <see cref="ValueTask"/> when there is no socket
/// or it is not in the <see cref="WebSocketState.Open"/> state.
/// </returns>
protected virtual ValueTask SendWebSocketTextAsync(ReadOnlyMemory<byte> payload, CancellationToken cancellationToken)
{
    // Snapshot the field so the state check and the send use the same socket instance.
    var socket = _webSocket;
    return socket?.State == WebSocketState.Open
        ? socket.SendAsync(payload, WebSocketMessageType.Text, true, cancellationToken)
        : ValueTask.CompletedTask;
}

/// <summary>Gracefully close the WebSocket connection.</summary>
Expand All @@ -346,6 +422,7 @@ public void Dispose()

OnDisposing();

Interlocked.Increment(ref _connectionGeneration);
try { _cts.Cancel(); } catch { }

var ws = _webSocket;
Expand Down
Loading
Loading