Skip to content

Commit 48085b0

Browse files
authored
Fix race condition where Prompt() returns before SessionUpdate handlers complete (#5)
* test: demonstrate Prompt() returns before SessionUpdates complete Add a failing test that surfaces a race condition where Prompt() can return before all SessionUpdate notification handlers have finished processing. The issue occurs because: - SessionUpdate notifications are handled asynchronously (goroutines) - PromptResponse is handled synchronously - The receive loop spawns notification handlers but doesn't track them When a server sends multiple SessionUpdate notifications followed by a PromptResponse, the client's Prompt() call returns immediately upon receiving the response, even though notification handlers may still be queued or running. The test expects all SessionUpdate handlers to complete before Prompt() returns, which represents the intended semantic contract: a prompt operation includes all its updates. Currently fails with 0/10 handlers completed when Prompt() returns. This will be fixed in a subsequent commit. * fix: ensure Prompt() waits for SessionUpdate handlers to complete Fix race condition where Prompt() could return before all SessionUpdate notification handlers finished processing. The issue occurred because notification/request handlers were spawned asynchronously while responses were processed synchronously. This meant the receive loop would: 1. Read SessionUpdate, then spawn goroutine G1 2. Read SessionUpdate, then spawn goroutine G2 3. Read PromptResponse, then handle synchronously, unblock Prompt() At step 3, goroutines G1/G2 _could_ be queued, running, or complete. Solution: - Add notificationWg to Connection to track in-flight handlers - Wrap notification handlers with WaitGroup Add/Done - Call notificationWg.Wait() in SendRequest/SendRequestNoResult after receiving response but before returning to caller This ensures the semantic contract that a prompt operation includes all its updates: when Prompt() returns, all SessionUpdate notifications sent before the PromptResponse have been fully processed. Fixes the test added in previous commit.
1 parent cc23bb6 commit 48085b0

File tree

2 files changed

+132
-1
lines changed

2 files changed

+132
-1
lines changed

acp_test.go

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"io"
66
"slices"
77
"sync"
8+
"sync/atomic"
89
"testing"
910
"time"
1011
)
@@ -633,3 +634,115 @@ func TestPromptCancellationSendsCancelAndAllowsNewSession(t *testing.T) {
633634
t.Fatalf("newSession after cancel: %v", err)
634635
}
635636
}
637+
638+
// TestPromptWaitsForSessionUpdatesComplete verifies that Prompt() waits for all SessionUpdate
639+
// notification handlers to complete before returning. This ensures that when a server sends
640+
// SessionUpdate notifications followed by a PromptResponse, the client-side Prompt() call will not
641+
// return until all notification handlers have finished processing. This is the expected semantic
642+
// contract: the prompt operation includes all its updates.
643+
func TestPromptWaitsForSessionUpdatesComplete(t *testing.T) {
644+
const numUpdates = 10
645+
const handlerDelay = 50 * time.Millisecond
646+
647+
var (
648+
updateStarted atomic.Int64
649+
updateCompleted atomic.Int64
650+
)
651+
652+
c2aR, c2aW := io.Pipe()
653+
a2cR, a2cW := io.Pipe()
654+
655+
// Client side with SessionUpdate handler that tracks execution
656+
c := NewClientSideConnection(&clientFuncs{
657+
WriteTextFileFunc: func(context.Context, WriteTextFileRequest) (WriteTextFileResponse, error) {
658+
return WriteTextFileResponse{}, nil
659+
},
660+
ReadTextFileFunc: func(context.Context, ReadTextFileRequest) (ReadTextFileResponse, error) {
661+
return ReadTextFileResponse{Content: "test"}, nil
662+
},
663+
RequestPermissionFunc: func(context.Context, RequestPermissionRequest) (RequestPermissionResponse, error) {
664+
return RequestPermissionResponse{Outcome: RequestPermissionOutcome{Selected: &RequestPermissionOutcomeSelected{OptionId: "allow"}}}, nil
665+
},
666+
SessionUpdateFunc: func(_ context.Context, n SessionNotification) error {
667+
updateStarted.Add(1)
668+
// Simulate processing time
669+
time.Sleep(handlerDelay)
670+
updateCompleted.Add(1)
671+
return nil
672+
},
673+
}, c2aW, a2cR)
674+
675+
// Agent side that sends multiple SessionUpdate notifications before responding
676+
var wg sync.WaitGroup
677+
wg.Add(1)
678+
679+
var ag *AgentSideConnection
680+
ag = NewAgentSideConnection(agentFuncs{
681+
InitializeFunc: func(context.Context, InitializeRequest) (InitializeResponse, error) {
682+
return InitializeResponse{ProtocolVersion: ProtocolVersionNumber, AgentCapabilities: AgentCapabilities{LoadSession: false}, AuthMethods: []AuthMethod{}}, nil
683+
},
684+
NewSessionFunc: func(context.Context, NewSessionRequest) (NewSessionResponse, error) {
685+
return NewSessionResponse{SessionId: "test-session"}, nil
686+
},
687+
LoadSessionFunc: func(context.Context, LoadSessionRequest) (LoadSessionResponse, error) {
688+
return LoadSessionResponse{}, nil
689+
},
690+
AuthenticateFunc: func(context.Context, AuthenticateRequest) (AuthenticateResponse, error) {
691+
return AuthenticateResponse{}, nil
692+
},
693+
PromptFunc: func(ctx context.Context, p PromptRequest) (PromptResponse, error) {
694+
defer wg.Done()
695+
696+
// Send multiple SessionUpdate notifications
697+
for i := 0; i < numUpdates; i++ {
698+
_ = ag.SessionUpdate(ctx, SessionNotification{
699+
SessionId: p.SessionId,
700+
Update: SessionUpdate{
701+
AgentMessageChunk: &SessionUpdateAgentMessageChunk{
702+
Content: TextBlock("chunk"),
703+
},
704+
},
705+
})
706+
}
707+
708+
// Small delay to ensure notifications are queued
709+
time.Sleep(10 * time.Millisecond)
710+
711+
// Return response (this will unblock client's Prompt() call)
712+
return PromptResponse{StopReason: "end_turn"}, nil
713+
},
714+
CancelFunc: func(context.Context, CancelNotification) error { return nil },
715+
}, a2cW, c2aR)
716+
717+
if _, err := c.Initialize(context.Background(), InitializeRequest{ProtocolVersion: ProtocolVersionNumber}); err != nil {
718+
t.Fatalf("initialize: %v", err)
719+
}
720+
sess, err := c.NewSession(context.Background(), NewSessionRequest{Cwd: "/", McpServers: []McpServer{}})
721+
if err != nil {
722+
t.Fatalf("newSession: %v", err)
723+
}
724+
725+
_, err = c.Prompt(context.Background(), PromptRequest{
726+
SessionId: sess.SessionId,
727+
Prompt: []ContentBlock{TextBlock("test")},
728+
})
729+
if err != nil {
730+
t.Fatalf("prompt: %v", err)
731+
}
732+
733+
wg.Wait()
734+
735+
// Verify the expected behavior: at this point, Prompt() has returned, and all SessionUpdate
736+
// handlers should have completed their processing.
737+
// started := updateStarted.Load() ; Currently unsused but useful for debugging
738+
completed := updateCompleted.Load()
739+
740+
// ASSERT: when Prompt() returns, all SessionUpdate notifications that were sent
741+
// before the PromptResponse must have been fully processed. This is the semantic
742+
// contract: the prompt operation includes all its updates.
743+
if completed != numUpdates {
744+
t.Fatalf("Prompt() returned with only %d/%d SessionUpdate "+
745+
"handlers completed. Expected all handlers to complete before Prompt() "+
746+
"returns.", completed, numUpdates)
747+
}
748+
}

connection.go

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ type Connection struct {
4141
cancel context.CancelCauseFunc
4242

4343
logger *slog.Logger
44+
45+
// notificationWg tracks in-flight notification handlers. This ensures SendRequest waits
46+
// for all notifications received before the response to complete processing.
47+
notificationWg sync.WaitGroup
4448
}
4549

4650
func NewConnection(handler MethodHandler, peerInput io.Writer, peerOutput io.Reader) *Connection {
@@ -94,7 +98,11 @@ func (c *Connection) receive() {
9498
case msg.ID != nil && msg.Method == "":
9599
c.handleResponse(&msg)
96100
case msg.Method != "":
97-
go c.handleInbound(&msg)
101+
c.notificationWg.Add(1)
102+
go func(m *anyMessage) {
103+
defer c.notificationWg.Done()
104+
c.handleInbound(m)
105+
}(&msg)
98106
default:
99107
c.loggerOrDefault().Error("received message with neither id nor method", "raw", string(line))
100108
}
@@ -193,6 +201,11 @@ func SendRequest[T any](c *Connection, ctx context.Context, method string, param
193201
return result, err
194202
}
195203

204+
// Wait for all notification handlers that were spawned before this response to complete
205+
// processing. This ensures that when a request returns, all notifications sent by the
206+
// server before the response have been fully processed.
207+
c.notificationWg.Wait()
208+
196209
if resp.Error != nil {
197210
return result, resp.Error
198211
}
@@ -266,6 +279,11 @@ func (c *Connection) SendRequestNoResult(ctx context.Context, method string, par
266279
return err
267280
}
268281

282+
// Wait for all notification handlers that were spawned before this response to complete
283+
// processing. This ensures that when a request returns, all notifications sent by the
284+
// server before the response have been fully processed.
285+
c.notificationWg.Wait()
286+
269287
if resp.Error != nil {
270288
return resp.Error
271289
}

0 commit comments

Comments
 (0)