92 changes: 74 additions & 18 deletions controlplane/session_mgr.go
@@ -21,14 +21,28 @@ type SessionProgress struct {
Stalled bool
}

// SessionConn is the client transport for a managed session. It supports both
// writing pgwire packets (used to deliver a FATAL ErrorResponse on worker
// crash before closing) and closing the underlying TCP. *tls.Conn — the type
// the pgwire handshake hands us — satisfies this interface naturally.
type SessionConn interface {
io.Writer
io.Closer
}
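
// Compile-time check, a common Go idiom that would pin the claim above; it
// is hypothetical (not part of this diff) and assumes "crypto/tls" is
// imported:
//
//	var _ SessionConn = (*tls.Conn)(nil)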

// ManagedSession tracks a client session bound to a worker.
type ManagedSession struct {
PID int32
WorkerID int
Protocol string // "postgres" or "flight"
SessionToken string
Executor *flightclient.FlightExecutor
connCloser io.Closer // TCP connection, closed on worker crash to unblock the message loop
// conn is the client TCP/TLS connection. Used by the worker-crash path
// to deliver a FATAL ErrorResponse and then close the socket — the
// FATAL is what lets clients (libpq, dbt's psycopg2 adapter) cleanly
// surface "your session was lost" instead of waiting forever on a
// half-open TCP that just got reset.
conn SessionConn

// Cached query progress from worker health checks.
queryProgress atomic.Value // stores *SessionProgress (or nil)
@@ -296,23 +310,42 @@ func (sm *SessionManager) OnWorkerCrash(workerID int, errorFn func(pid int32)) {
errorFn(pid)
sm.mu.Lock()
session, ok := sm.sessions[pid]
var executor *flightclient.FlightExecutor
var conn SessionConn
if ok {
delete(sm.sessions, pid)
if session.Executor != nil {
_ = session.Executor.Close()
}
// Close the TCP connection to unblock the message loop's read.
// This causes the session goroutine to exit instead of looping
// with ErrWorkerDead on every query. The deferred close in
// handleConnection will also call Close() on the same conn;
// that's harmless (net.Conn.Close on a closed socket returns
// an error which is discarded).
if session.connCloser != nil {
_ = session.connCloser.Close()
}
executor = session.Executor
conn = session.conn
}
remainingSessions := len(sm.sessions)
sm.mu.Unlock()

if executor != nil {
_ = executor.Close()
}
// Deliver a pgwire FATAL ErrorResponse before closing the TCP.
// Without the FATAL, libpq-based clients (psql, dbt's psycopg2
// adapter) can hang on a half-open socket — psql happens to
// handle the bare TCP close OK because its read loop returns,
// but dbt's libpq-async + disabled keepalives setup leaves
// PQconsumeInput parked indefinitely. The FATAL gives every
// client a structured "your session was lost" they can surface.
//
// Concurrency: the message loop also writes to this conn via
// its own bufio.Writer, but *tls.Conn / net.Conn serialize
// underlying Write calls internally — so we may interleave at
// the message boundary (corrupting an in-flight DataRow), but
// not at the byte level. A client that sees a malformed packet
// followed by a FATAL still surfaces an error cleanly, which
// is strictly better than a silent half-open socket.
//
// Write and Close happen outside sm.mu so a slow or wedged client
// socket cannot block unrelated session-manager operations.
if conn != nil {
_ = server.WriteErrorResponse(conn, "FATAL", "08006",
fmt.Sprintf("worker %d for this session became unresponsive and was reaped", workerID))
_ = conn.Close()
}
slog.Info("Worker crash session cleanup completed.",
"pid", pid,
"worker", workerID,
@@ -332,17 +365,40 @@ func (sm *SessionManager) OnWorkerCrash(workerID int, errorFn func(pid int32)) {
}
}

// SetConnCloser registers the client's TCP connection so it can be closed
// when the backing worker crashes. This unblocks the message loop's read,
// causing it to exit cleanly instead of looping on ErrWorkerDead.
func (sm *SessionManager) SetConnCloser(pid int32, closer io.Closer) {
// SetSessionConn registers the client's TCP/TLS transport so the worker-crash
// path can deliver a FATAL ErrorResponse and close the socket. *tls.Conn from
// the pgwire handshake satisfies SessionConn (io.Writer + io.Closer).
func (sm *SessionManager) SetSessionConn(pid int32, conn SessionConn) {
sm.mu.Lock()
defer sm.mu.Unlock()
if s, ok := sm.sessions[pid]; ok {
s.connCloser = closer
s.conn = conn
}
}

// SetConnCloser is a back-compat shim for callers that previously passed an
// io.Closer. The real type they pass (*tls.Conn) is also an io.Writer, so we
// upcast to SessionConn here. New callers should use SetSessionConn directly.
//
// Deprecated: use SetSessionConn.
func (sm *SessionManager) SetConnCloser(pid int32, closer io.Closer) {
conn, ok := closer.(SessionConn)
if !ok {
// Caller passed a closer that isn't also a Writer — we can still
// close on crash, just can't deliver a FATAL. Wrap in a discarding
// writer so the type satisfies SessionConn.
conn = closeOnlyConn{closer}
}
sm.SetSessionConn(pid, conn)
}

// closeOnlyConn adapts an io.Closer with no Writer into a SessionConn whose
// Write is a no-op. Used by the deprecated SetConnCloser path for callers that
// genuinely don't have a Writer; modern callers pass *tls.Conn directly.
type closeOnlyConn struct{ io.Closer }

func (closeOnlyConn) Write(p []byte) (int, error) { return len(p), nil }
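
// Registration sketch: hypothetical, since handleConnection is referenced in
// the comments above but is not part of this diff. After the pgwire
// handshake, the session loop would hand its *tls.Conn straight to
// SetSessionConn:
//
//	func handleConnection(sm *SessionManager, conn *tls.Conn, pid int32) {
//		defer conn.Close()           // double-close on the crash path is harmless
//		sm.SetSessionConn(pid, conn) // *tls.Conn is io.Writer + io.Closer
//		// ... pgwire message loop reads from conn until error or EOF ...
//	}
//
// A legacy caller holding a bare io.Closer still compiles against the
// deprecated SetConnCloser; it keeps close-on-crash, but any FATAL is
// silently dropped by closeOnlyConn's no-op Write.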

// SessionCount returns the number of active sessions.
func (sm *SessionManager) SessionCount() int {
sm.mu.RLock()
113 changes: 101 additions & 12 deletions controlplane/session_mgr_test.go
@@ -3,24 +3,63 @@
package controlplane

import (
"bytes"
"errors"
"runtime"
"slices"
"strings"
"sync"
"sync/atomic"
"testing"

"github.com/posthog/duckgres/server/flightclient"
)

// mockCloser tracks whether Close was called.
// mockCloser stands in for the client TCP/TLS conn — captures bytes written
// (so tests can assert a FATAL ErrorResponse was delivered) and tracks whether
// Close was called.
type mockCloser struct {
closed atomic.Bool
closed atomic.Bool
writeMu sync.Mutex
written []byte
events []string
}

func (m *mockCloser) Write(p []byte) (int, error) {
m.writeMu.Lock()
defer m.writeMu.Unlock()
if m.closed.Load() {
return 0, errors.New("write after close")
}
m.events = append(m.events, "write")
m.written = append(m.written, p...)
return len(p), nil
}

func (m *mockCloser) Bytes() []byte {
m.writeMu.Lock()
defer m.writeMu.Unlock()
out := make([]byte, len(m.written))
copy(out, m.written)
return out
}

func (m *mockCloser) Close() error {
m.writeMu.Lock()
defer m.writeMu.Unlock()
m.events = append(m.events, "close")
m.closed.Store(true)
return nil
}

func (m *mockCloser) Events() []string {
m.writeMu.Lock()
defer m.writeMu.Unlock()
out := make([]string, len(m.events))
copy(out, m.events)
return out
}

func TestOnWorkerCrash_MarksExecutorsDead(t *testing.T) {
pool := &FlightWorkerPool{
workers: make(map[int]*ManagedWorker),
@@ -72,10 +111,10 @@ func TestOnWorkerCrash_ClosesConnections(t *testing.T) {

sm.mu.Lock()
sm.sessions[pid] = &ManagedSession{
PID: pid,
WorkerID: 7,
Executor: executor,
connCloser: conn,
PID: pid,
WorkerID: 7,
Executor: executor,
conn: conn,
}
sm.byWorker[7] = []int32{pid}
sm.mu.Unlock()
@@ -99,8 +138,8 @@ func TestOnWorkerCrash_MultipleSessions(t *testing.T) {
conn2 := &mockCloser{}

sm.mu.Lock()
sm.sessions[1001] = &ManagedSession{PID: 1001, WorkerID: 3, Executor: exec1, connCloser: conn1}
sm.sessions[1002] = &ManagedSession{PID: 1002, WorkerID: 3, Executor: exec2, connCloser: conn2}
sm.sessions[1001] = &ManagedSession{PID: 1001, WorkerID: 3, Executor: exec1, conn: conn1}
sm.sessions[1002] = &ManagedSession{PID: 1002, WorkerID: 3, Executor: exec2, conn: conn2}
sm.byWorker[3] = []int32{1001, 1002}
sm.mu.Unlock()

@@ -117,6 +156,56 @@
}
}

func TestOnWorkerCrash_WritesFATALBeforeClose(t *testing.T) {
// Asserts the new behavior: when a worker is reaped, the CP delivers a
// pgwire FATAL ErrorResponse on the client conn before closing it. This
// is the difference between psql cleanly surfacing "connection lost" and
// dbt's libpq state machine hanging silently on a half-open socket.
pool := &FlightWorkerPool{workers: make(map[int]*ManagedWorker)}
sm := NewSessionManager(pool, nil)

conn := &mockCloser{}
pid := int32(1500)

sm.mu.Lock()
sm.sessions[pid] = &ManagedSession{
PID: pid,
WorkerID: 42,
Executor: &flightclient.FlightExecutor{},
conn: conn,
}
sm.byWorker[42] = []int32{pid}
sm.mu.Unlock()

sm.OnWorkerCrash(42, func(int32) {})

if !conn.closed.Load() {
t.Fatal("conn was not closed on worker crash")
}

// Inspect the bytes the crash handler wrote: pgwire ErrorResponse starts
// with the byte 'E', followed by a 4-byte length, then field-tagged
// strings. We just assert FATAL + the worker ID appear so we know a
// FATAL packet was emitted before the close — not testing the wire
// encoding in detail (server/wire owns that).
got := conn.Bytes()
if len(got) == 0 {
t.Fatal("no bytes written before close — FATAL not delivered")
}
if got[0] != 'E' {
t.Errorf("expected first byte 'E' (ErrorResponse), got %q", got[0])
}
if !bytes.Contains(got, []byte("FATAL")) {
t.Errorf("expected 'FATAL' in payload, got %q", got)
}
if !bytes.Contains(got, []byte("42")) {
t.Errorf("expected worker id '42' in payload, got %q", got)
}
if got := conn.Events(); !slices.Equal(got, []string{"write", "write", "write", "close"}) {
t.Fatalf("expected FATAL message writes before close, got event order %v", got)
}
}
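
// server.WriteErrorResponse lives outside this diff, so the test above checks
// only coarse properties of its output. For reference, a minimal sketch of
// the pgwire ErrorResponse framing it must produce. This is hypothetical
// code, not the real helper (which evidently splits the packet across three
// Writes, per the event-order assertion above):
func encodeErrorResponse(severity, sqlstate, msg string) []byte {
	var body bytes.Buffer
	for _, f := range []struct {
		tag byte
		val string
	}{{'S', severity}, {'C', sqlstate}, {'M', msg}} {
		body.WriteByte(f.tag)   // one-byte field code: S=severity, C=SQLSTATE, M=message
		body.WriteString(f.val) // field value ...
		body.WriteByte(0)       // ... is NUL-terminated
	}
	body.WriteByte(0) // a lone zero byte terminates the field list

	n := 4 + body.Len() // Int32 length counts itself but not the type byte
	out := []byte{'E', byte(n >> 24), byte(n >> 16), byte(n >> 8), byte(n)}
	return append(out, body.Bytes()...)
}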

func TestSetConnCloser(t *testing.T) {
pool := &FlightWorkerPool{
workers: make(map[int]*ManagedWorker),
@@ -233,10 +322,10 @@ func TestDestroySessionAfterOnWorkerCrash(t *testing.T) {

sm.mu.Lock()
sm.sessions[pid] = &ManagedSession{
PID: pid,
WorkerID: 9,
Executor: executor,
connCloser: conn,
PID: pid,
WorkerID: 9,
Executor: executor,
conn: conn,
}
sm.byWorker[9] = []int32{pid}
sm.mu.Unlock()