Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion controlplane/control.go
Original file line number Diff line number Diff line change
Expand Up @@ -1100,7 +1100,7 @@ func (cp *ControlPlane) handleConnection(conn net.Conn) {
cp.cfg.WorkerQueueTimeout,
server.BackendKey{Pid: pid, SecretKey: secretKey},
func(ctx context.Context) (int32, *flightclient.FlightExecutor, error) {
return sessions.CreateSession(ctx, username, clientSearchPath, pid, memLimit, threads)
return sessions.CreateSession(ctx, username, pid, memLimit, threads)
},
)
if err != nil {
Expand Down Expand Up @@ -1144,6 +1144,26 @@ func (cp *ControlPlane) handleConnection(conn net.Conn) {
_ = writer.Flush()
return
}

// Apply the client's connect-time search_path (from the startup `options`
// parameter, e.g. `options=-c search_path=iceberg.public`) AFTER metadata
// init. It must run here, not on the worker at session create: (1)
// InitSessionDatabaseMetadata's defer resets search_path to the ducklake
// default, so an earlier value is clobbered; and (2) running metadata init
// while the session default points at the iceberg REST catalog fails. We
// append memory.main so pg_catalog macros stay resolvable; on failure
// (e.g. the schema doesn't exist) we log and keep the default.
if clientSearchPath != "" {
sp := clientSearchPath
if !strings.Contains(strings.ToLower(sp), "memory.main") {
sp += ",memory.main"
}
spCtx, spCancel := context.WithTimeout(context.Background(), cp.cfg.SessionInitTimeout)
if _, err := executor.ExecContext(spCtx, fmt.Sprintf("SET search_path = '%s'", sp)); err != nil {
slog.Warn("Failed to apply client connect-time search_path; using default.", "user", username, "org", orgID, "search_path", clientSearchPath, "error", err)
}
spCancel()
}
}

// Register the TCP connection so OnWorkerCrash can close it to unblock
Expand Down
4 changes: 2 additions & 2 deletions controlplane/flight_ingress.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ type flightSessionProvider struct {
}

func (p *flightSessionProvider) CreateSession(ctx context.Context, username string, pid int32, memoryLimit string, threads int) (int32, *flightclient.FlightExecutor, error) {
workerPID, executor, err := p.sm.CreateSession(ctx, username, "", pid, memoryLimit, threads)
workerPID, executor, err := p.sm.CreateSession(ctx, username, pid, memoryLimit, threads)
if err != nil {
return 0, nil, err
}
Expand Down Expand Up @@ -88,7 +88,7 @@ func (p *orgRoutedSessionProvider) CreateSession(ctx context.Context, username s

// SessionManager.resolveSessionLimits handles rebalancer defaults,
// so pass memoryLimit/threads through as-is.
workerPID, executor, err := sessions.CreateSessionWithProtocol(ctx, username, "", pid, memoryLimit, threads, "flight")
workerPID, executor, err := sessions.CreateSessionWithProtocol(ctx, username, pid, memoryLimit, threads, "flight")
if err != nil {
return 0, nil, err
}
Expand Down
14 changes: 7 additions & 7 deletions controlplane/session_mgr.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,11 +228,11 @@ func (sm *SessionManager) removeWaiterLocked(waiter *connectionWaiter) {
// CreateSession acquires a worker from the configured pool, creates a session
// on it, and rebalances memory/thread limits across all active sessions.
// If pid is 0, a new one is generated.
func (sm *SessionManager) CreateSession(ctx context.Context, username, searchPath string, pid int32, memoryLimit string, threads int) (int32, *flightclient.FlightExecutor, error) {
return sm.CreateSessionWithProtocol(ctx, username, searchPath, pid, memoryLimit, threads, "postgres")
func (sm *SessionManager) CreateSession(ctx context.Context, username string, pid int32, memoryLimit string, threads int) (int32, *flightclient.FlightExecutor, error) {
return sm.CreateSessionWithProtocol(ctx, username, pid, memoryLimit, threads, "postgres")
}

func (sm *SessionManager) CreateSessionWithProtocol(ctx context.Context, username, searchPath string, pid int32, memoryLimit string, threads int, protocol string) (int32, *flightclient.FlightExecutor, error) {
func (sm *SessionManager) CreateSessionWithProtocol(ctx context.Context, username string, pid int32, memoryLimit string, threads int, protocol string) (int32, *flightclient.FlightExecutor, error) {
ctx, endCreation, err := sm.beginSessionCreation(ctx)
if err != nil {
return 0, nil, err
Expand Down Expand Up @@ -287,7 +287,7 @@ func (sm *SessionManager) CreateSessionWithProtocol(ctx context.Context, usernam
acquireSpan.End()
slog.Debug("Worker acquired.", "pid", pid, "worker", worker.ID, "user", username, "duration", time.Since(acquireStart))

pid, exec, err := sm.createSessionOnWorker(ctx, username, searchPath, pid, memoryLimit, threads, worker, protocol, true, lease)
pid, exec, err := sm.createSessionOnWorker(ctx, username, pid, memoryLimit, threads, worker, protocol, true, lease)
if err != nil {
return 0, nil, err
}
Expand Down Expand Up @@ -334,7 +334,7 @@ func (sm *SessionManager) ReconnectFlightSession(ctx context.Context, username s
if err != nil {
return 0, nil, fmt.Errorf("reconnect worker %d: %w", workerID, err)
}
pid, exec, err := sm.createSessionOnWorker(ctx, username, "", 0, "", 0, worker, "flight", false, lease)
pid, exec, err := sm.createSessionOnWorker(ctx, username, 0, "", 0, worker, "flight", false, lease)
if err != nil {
return 0, nil, err
}
Expand All @@ -346,7 +346,7 @@ func (sm *SessionManager) beginSessionCreation(ctx context.Context) (context.Con
return sm.lifecycle.begin(ctx)
}

func (sm *SessionManager) createSessionOnWorker(ctx context.Context, username, searchPath string, pid int32, memoryLimit string, threads int, worker *ManagedWorker, protocol string, retireOnFailure bool, lease connectionLease) (int32, *flightclient.FlightExecutor, error) {
func (sm *SessionManager) createSessionOnWorker(ctx context.Context, username string, pid int32, memoryLimit string, threads int, worker *ManagedWorker, protocol string, retireOnFailure bool, lease connectionLease) (int32, *flightclient.FlightExecutor, error) {
createStart := time.Now()
slog.Info("Creating session on worker.",
"pid", pid,
Expand All @@ -358,7 +358,7 @@ func (sm *SessionManager) createSessionOnWorker(ctx context.Context, username, s
"owner_cp_instance_id", worker.OwnerCPInstanceID(),
"owner_epoch", worker.OwnerEpoch(),
)
sessionToken, err := worker.CreateSession(ctx, username, memoryLimit, searchPath, threads)
sessionToken, err := worker.CreateSession(ctx, username, memoryLimit, threads)
if err != nil {
slog.Warn("Failed to create session on worker.",
"pid", pid,
Expand Down
6 changes: 3 additions & 3 deletions controlplane/session_mgr_drain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,7 @@ func TestDestroyAllSessionsRejectsInFlightCreateBeforeRegistration(t *testing.T)

createErr := make(chan error, 1)
go func() {
_, _, err := sm.CreateSessionWithProtocol(context.Background(), "root", "", 1010, "", 0, "postgres")
_, _, err := sm.CreateSessionWithProtocol(context.Background(), "root", 1010, "", 0, "postgres")
createErr <- err
}()

Expand Down Expand Up @@ -722,7 +722,7 @@ func TestDestroyAllSessionsWaitsForCreateBlockedInLimiterAcquire(t *testing.T) {

createErr := make(chan error, 1)
go func() {
_, _, err := sm.CreateSessionWithProtocol(context.Background(), "root", "", 1010, "", 0, "postgres")
_, _, err := sm.CreateSessionWithProtocol(context.Background(), "root", 1010, "", 0, "postgres")
createErr <- err
}()

Expand Down Expand Up @@ -771,7 +771,7 @@ func TestDestroyAllSessionsCancelsCreateBlockedInAcquireWorker(t *testing.T) {

createErr := make(chan error, 1)
go func() {
_, _, err := sm.CreateSessionWithProtocol(callerCtx, "root", "", 1010, "", 0, "postgres")
_, _, err := sm.CreateSessionWithProtocol(callerCtx, "root", 1010, "", 0, "postgres")
createErr <- err
}()

Expand Down
2 changes: 1 addition & 1 deletion controlplane/session_mgr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func TestCreateSessionObservesWarmCapacityExhaustion(t *testing.T) {
err: NewWarmCapacityExhaustedError(30 * time.Second),
}, nil)

_, _, err := sm.CreateSession(context.Background(), "root", "", 1001, "", 0)
_, _, err := sm.CreateSession(context.Background(), "root", 1001, "", 0)
var capacityErr *WarmCapacityExhaustedError
if !errors.As(err, &capacityErr) {
t.Fatalf("expected warm capacity error, got %v", err)
Expand Down
3 changes: 1 addition & 2 deletions controlplane/worker_mgr.go
Original file line number Diff line number Diff line change
Expand Up @@ -1200,7 +1200,7 @@ func recoverWorkerPanic(err *error) {
}

// CreateSession creates a new session on the given worker.
func (w *ManagedWorker) CreateSession(ctx context.Context, username, memoryLimit, searchPath string, threads int) (token string, err error) {
func (w *ManagedWorker) CreateSession(ctx context.Context, username, memoryLimit string, threads int) (token string, err error) {
defer recoverWorkerPanic(&err)

body, _ := json.Marshal(server.WorkerCreateSessionPayload{
Expand All @@ -1211,7 +1211,6 @@ func (w *ManagedWorker) CreateSession(ctx context.Context, username, memoryLimit
},
Username: username,
MemoryLimit: memoryLimit,
SearchPath: searchPath,
Threads: threads,
})

Expand Down
2 changes: 1 addition & 1 deletion duckdbservice/flight_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ func (h *FlightSQLHandler) doCreateSession(body []byte, stream flight.FlightServ
}
}

session, err := h.pool.CreateSession(req.Username, req.MemoryLimit, req.SearchPath, req.Threads)
session, err := h.pool.CreateSession(req.Username, req.MemoryLimit, req.Threads)
if err != nil {
return status.Errorf(codes.ResourceExhausted, "create session: %v", err)
}
Expand Down
30 changes: 4 additions & 26 deletions duckdbservice/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -435,10 +435,8 @@ func (svc *DuckDBService) Shutdown() {
svc.pool.CloseAll()
}

// CreateSession creates a new DuckDB session for the given username. When
// searchPath is non-empty it is the client's connect-time search_path
// (sanitized control-plane side) and overrides the username-based default.
func (p *SessionPool) CreateSession(username, memoryLimit, searchPath string, threads int) (*Session, error) {
// CreateSession creates a new DuckDB session for the given username.
func (p *SessionPool) CreateSession(username, memoryLimit string, threads int) (*Session, error) {
start := time.Now()
// Reserve a slot under the lock to prevent TOCTOU race on maxSessions.
p.mu.Lock()
Expand Down Expand Up @@ -525,7 +523,7 @@ func (p *SessionPool) CreateSession(username, memoryLimit, searchPath string, th

// Initialize the session connection with username-specific state if needed.
// Since the DB is shared, we must set session-local parameters here.
initSearchPath(conn, username, searchPath)
initSearchPath(conn, username)

// Re-apply DuckDB profiling settings on the freshly-pooled connection.
// ConfigureMainDB sets these once at warmup, but in cluster mode the
Expand Down Expand Up @@ -922,27 +920,7 @@ func dropTemporary(conn *sql.Conn, query, dropFmt string) bool {
// format_type, etc.) remain resolvable when the default catalog is ducklake.
// Without it, DuckDB restricts function resolution to the ducklake catalog
// and psql commands like \dt fail.
func initSearchPath(conn *sql.Conn, username, clientSearchPath string) {
// Honor a client-supplied connect-time search_path (from the startup
// `options` parameter, e.g. `options=-c search_path=iceberg.public`) so a
// session can pick its default catalog/schema at connect — the standard
// mechanism BI tools and JDBC (currentSchema) use. The value is sanitized
// control-plane side. memory.main is appended so pg_catalog macros stay
// resolvable. On failure (e.g. the schema doesn't exist) we fall through to
// the username-based default rather than leaving the session unconfigured.
if clientSearchPath != "" {
sp := clientSearchPath
if !strings.Contains(strings.ToLower(sp), "memory.main") {
sp += ",memory.main"
}
if _, err := conn.ExecContext(context.Background(), fmt.Sprintf("SET search_path = '%s'", sp)); err == nil {
return
} else {
slog.Warn("Client-requested search_path rejected; using default.", "user", username, "search_path", clientSearchPath, "error", err)
_, _ = conn.ExecContext(context.Background(), "ROLLBACK")
}
}

func initSearchPath(conn *sql.Conn, username string) {
if _, err := conn.ExecContext(context.Background(), fmt.Sprintf("SET search_path = '%s,main,memory.main'", username)); err != nil {
slog.Debug("User schema not found, using default search_path.", "user", username)
// Clear the aborted transaction state before retrying.
Expand Down
46 changes: 2 additions & 44 deletions duckdbservice/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func TestInitSearchPath(t *testing.T) {
defer func() { _ = conn.Close() }()

// "nonexistent_user" is not a schema — should fall back to 'main' without error
initSearchPath(conn, "nonexistent_user", "")
initSearchPath(conn, "nonexistent_user")

var searchPath string
if err := conn.QueryRowContext(context.Background(), "SELECT current_setting('search_path')").Scan(&searchPath); err != nil {
Expand All @@ -51,7 +51,7 @@ func TestInitSearchPath(t *testing.T) {
t.Fatalf("failed to create schema: %v", err)
}

initSearchPath(conn, "myuser", "")
initSearchPath(conn, "myuser")

var searchPath string
if err := conn.QueryRowContext(context.Background(), "SELECT current_setting('search_path')").Scan(&searchPath); err != nil {
Expand All @@ -61,48 +61,6 @@ func TestInitSearchPath(t *testing.T) {
t.Errorf("expected search_path 'myuser,main,memory.main', got %q", searchPath)
}
})

t.Run("honors client search_path and appends memory.main", func(t *testing.T) {
conn, err := db.Conn(context.Background())
if err != nil {
t.Fatalf("failed to get connection: %v", err)
}
defer func() { _ = conn.Close() }()

if _, err := conn.ExecContext(context.Background(), "CREATE SCHEMA IF NOT EXISTS chosen"); err != nil {
t.Fatalf("failed to create schema: %v", err)
}

// Client picked `chosen` at connect; memory.main must be appended.
initSearchPath(conn, "ignored_user", "chosen")

var searchPath string
if err := conn.QueryRowContext(context.Background(), "SELECT current_setting('search_path')").Scan(&searchPath); err != nil {
t.Fatalf("failed to query search_path: %v", err)
}
if searchPath != "chosen,memory.main" {
t.Errorf("expected search_path 'chosen,memory.main', got %q", searchPath)
}
})

t.Run("falls back to default when client search_path is invalid", func(t *testing.T) {
conn, err := db.Conn(context.Background())
if err != nil {
t.Fatalf("failed to get connection: %v", err)
}
defer func() { _ = conn.Close() }()

// A schema that doesn't exist: SET fails, fall back to the username default.
initSearchPath(conn, "nonexistent_user", "no_such_schema")

var searchPath string
if err := conn.QueryRowContext(context.Background(), "SELECT current_setting('search_path')").Scan(&searchPath); err != nil {
t.Fatalf("failed to query search_path: %v", err)
}
if searchPath != "main,memory.main" {
t.Errorf("expected fallback search_path 'main,memory.main', got %q", searchPath)
}
})
}

func TestCleanupSessionState(t *testing.T) {
Expand Down
4 changes: 0 additions & 4 deletions server/wire/worker_proto.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,6 @@ type WorkerCreateSessionPayload struct {
Username string `json:"username"`
MemoryLimit string `json:"memory_limit"`
Threads int `json:"threads"`
// SearchPath, when set, is the client's connect-time search_path (parsed
// from the startup `options` parameter, e.g. `-c search_path=iceberg.public`).
// Empty means use the worker's default. Sanitized control-plane side.
SearchPath string `json:"search_path,omitempty"`
}

// WorkerDestroySessionPayload is the control plane request body for
Expand Down
Loading