Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lua/claudecode/server/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ local M = {}
---@field server table|nil The TCP server instance
---@field port number|nil The port server is running on
---@field auth_token string|nil The authentication token for validating connections
---@field clients table<string, WebSocketClient> A list of connected clients
---@field clients table<string, WebSocketClient> Mirrored view of connected clients (updated via tcp callbacks)
---@field handlers table Message handlers by method name
---@field ping_timer table|nil Timer for sending pings
M.state = {
Expand Down
52 changes: 44 additions & 8 deletions lua/claudecode/server/tcp.lua
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ local M = {}
---@field server table The vim.loop TCP server handle
---@field port number The port the server is listening on
---@field auth_token string|nil The authentication token for validating connections
---@field clients table<string, WebSocketClient> Table of connected clients
---@field clients table<string, WebSocketClient> Transport-level registry of connected clients (canonical)
---@field on_message function Callback for WebSocket messages
---@field on_connect function Callback for new connections
---@field on_disconnect function Callback for client disconnections
Expand Down Expand Up @@ -124,33 +124,69 @@ function M._handle_new_connection(server)
-- Set up data handler
client_tcp:read_start(function(err, data)
if err then
server.on_error("Client read error: " .. err)
M._remove_client(server, client)
local error_msg = "Client read error: " .. err
server.on_error(error_msg)
M._disconnect_client(server, client, 1006, error_msg)
return
end

if not data then
-- EOF - client disconnected
M._remove_client(server, client)
M._disconnect_client(server, client, 1006, "EOF")
return
end

-- Process incoming data
client_manager.process_data(client, data, function(cl, message)
server.on_message(cl, message)
end, function(cl, code, reason)
server.on_disconnect(cl, code, reason)
M._remove_client(server, cl)
M._disconnect_client(server, cl, code, reason)
end, function(cl, error_msg)
server.on_error("Client " .. cl.id .. " error: " .. error_msg)
M._remove_client(server, cl)
M._disconnect_client(server, cl, 1006, "Client error: " .. error_msg)
end, server.auth_token)
end)

-- Notify about new connection
server.on_connect(client)
end

---Disconnect a client and remove it from the server.
---This ensures `server.on_disconnect` is invoked for every disconnect path
---(EOF, read errors, protocol errors, timeouts), keeping higher-level client
---state in sync.
---@param server TCPServer The server object
---@param client WebSocketClient The client to disconnect
---@param code number|nil WebSocket close code
---@param reason string|nil WebSocket close reason
function M._disconnect_client(server, client, code, reason)
assert(type(server) == "table", "Expected server to be a table")
local on_disconnect_type = type(server.on_disconnect)
local on_disconnect_mt = on_disconnect_type == "table" and getmetatable(server.on_disconnect) or nil
assert(
on_disconnect_type == "function" or (on_disconnect_mt ~= nil and type(on_disconnect_mt.__call) == "function"),
"Expected server.on_disconnect to be callable"
)
assert(type(server.clients) == "table", "Expected server.clients to be a table")
assert(type(client) == "table", "Expected client to be a table")
assert(type(client.id) == "string", "Expected client.id to be a string")
if code ~= nil then
assert(type(code) == "number", "Expected code to be a number")
end
if reason ~= nil then
assert(type(reason) == "string", "Expected reason to be a string")
end

-- Idempotency: a client can hit multiple disconnect paths (e.g. CLOSE frame
-- followed by a TCP EOF). Only notify/remove once.
if not server.clients[client.id] then
return
end

server.on_disconnect(client, code, reason)
M._remove_client(server, client)
end

---Remove a client from the server
---@param server TCPServer The server object
---@param client WebSocketClient The client to remove
Expand Down Expand Up @@ -293,7 +329,7 @@ function M.start_ping_timer(server, interval)
string.format("Client %s keepalive timeout (%ds idle), closing connection", client.id, time_since_pong)
)
client_manager.close_client(client, 1006, "Connection timeout")
M._remove_client(server, client)
M._disconnect_client(server, client, 1006, "Connection timeout")
end
end
end
Expand Down
1 change: 1 addition & 0 deletions tests/mocks/vim.lua
Original file line number Diff line number Diff line change
Expand Up @@ -881,6 +881,7 @@ local vim = {
return true
end,
read_start = function(self, callback)
self._read_cb = callback
return true
end,
write = function(self, data, callback)
Expand Down
131 changes: 131 additions & 0 deletions tests/unit/server/tcp_spec.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
require("tests.busted_setup")

local client_manager = require("claudecode.server.client")

describe("TCP server disconnect handling", function()
local tcp
local original_process_data

before_each(function()
package.loaded["claudecode.server.tcp"] = nil
tcp = require("claudecode.server.tcp")
original_process_data = client_manager.process_data
end)

after_each(function()
client_manager.process_data = original_process_data
end)

it("should call on_disconnect and remove client on EOF", function()
local callbacks = {
on_message = spy.new(function() end),
on_connect = spy.new(function() end),
on_disconnect = spy.new(function() end),
on_error = spy.new(function() end),
}

local config = { port_range = { min = 10000, max = 10000 } }
local server, err = tcp.create_server(config, callbacks, nil)
assert.is_nil(err)
assert.is_table(server)

tcp._handle_new_connection(server)

assert.spy(callbacks.on_connect).was_called(1)
local client = callbacks.on_connect.calls[1].vals[1]
assert.is_table(client)
assert.is_table(client.tcp_handle)
assert.is_function(client.tcp_handle._read_cb)

-- Simulate client abruptly disconnecting (e.g. CLI terminated via Ctrl-C)
client.tcp_handle._read_cb(nil, nil)

assert.spy(callbacks.on_disconnect).was_called(1)
assert.spy(callbacks.on_disconnect).was_called_with(client, 1006, "EOF")
expect(server.clients[client.id]).to_be_nil()
end)

it("should call on_disconnect and remove client on TCP read error", function()
local callbacks = {
on_message = spy.new(function() end),
on_connect = spy.new(function() end),
on_disconnect = spy.new(function() end),
on_error = spy.new(function() end),
}

local config = { port_range = { min = 10000, max = 10000 } }
local server, err = tcp.create_server(config, callbacks, nil)
assert.is_nil(err)
assert.is_table(server)

tcp._handle_new_connection(server)

local client = callbacks.on_connect.calls[1].vals[1]
client.tcp_handle._read_cb("boom", nil)

assert.spy(callbacks.on_disconnect).was_called(1)
assert.spy(callbacks.on_disconnect).was_called_with(client, 1006, "Client read error: boom")
expect(server.clients[client.id]).to_be_nil()

assert.spy(callbacks.on_error).was_called(1)
assert.spy(callbacks.on_error).was_called_with("Client read error: boom")
end)

it("should call on_disconnect when client manager reports an error", function()
client_manager.process_data = function(cl, data, on_message, on_close, on_error, auth_token)
on_error(cl, "Protocol error")
end

local callbacks = {
on_message = spy.new(function() end),
on_connect = spy.new(function() end),
on_disconnect = spy.new(function() end),
on_error = spy.new(function() end),
}

local config = { port_range = { min = 10000, max = 10000 } }
local server, err = tcp.create_server(config, callbacks, nil)
assert.is_nil(err)
assert.is_table(server)

tcp._handle_new_connection(server)

local client = callbacks.on_connect.calls[1].vals[1]
client.tcp_handle._read_cb(nil, "some data")

assert.spy(callbacks.on_disconnect).was_called(1)
assert.spy(callbacks.on_disconnect).was_called_with(client, 1006, "Client error: Protocol error")
expect(server.clients[client.id]).to_be_nil()
end)

it("should only call on_disconnect once if multiple disconnect paths fire", function()
client_manager.process_data = function(cl, data, on_message, on_close, on_error, auth_token)
on_close(cl, 1000, "bye")
end

local callbacks = {
on_message = spy.new(function() end),
on_connect = spy.new(function() end),
on_disconnect = spy.new(function() end),
on_error = spy.new(function() end),
}

local config = { port_range = { min = 10000, max = 10000 } }
local server, err = tcp.create_server(config, callbacks, nil)
assert.is_nil(err)
assert.is_table(server)

tcp._handle_new_connection(server)

local client = callbacks.on_connect.calls[1].vals[1]
client.tcp_handle._read_cb(nil, "data")

assert.spy(callbacks.on_disconnect).was_called(1)
assert.spy(callbacks.on_disconnect).was_called_with(client, 1000, "bye")
expect(server.clients[client.id]).to_be_nil()

-- Simulate a later EOF after the CLOSE path already removed the client.
client.tcp_handle._read_cb(nil, nil)
assert.spy(callbacks.on_disconnect).was_called(1)
end)
end)