Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 28 additions & 12 deletions internal/command/command_plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ type (
commandServerType model.ServerType
subscribeMutex sync.Mutex
agentConfigMutex sync.RWMutex
subscribeWg sync.WaitGroup
}
)

Expand Down Expand Up @@ -195,7 +196,11 @@ func (cp *CommandPlugin) createConnection(ctx context.Context, resource *mpi.Res
subscribeCtx, cp.subscribeCancel = context.WithCancel(ctx)
cp.subscribeMutex.Unlock()

go cp.commandService.Subscribe(subscribeCtx)
cp.subscribeWg.Add(1)
go func() {
defer cp.subscribeWg.Done()
cp.commandService.Subscribe(subscribeCtx)
}()

cp.messagePipe.Process(ctx, &bus.Message{
Topic: bus.ConnectionCreatedTopic,
Expand Down Expand Up @@ -294,18 +299,14 @@ func (cp *CommandPlugin) processConnectionReset(ctx context.Context, msg *bus.Me
cp.subscribeMutex.Lock()
defer cp.subscribeMutex.Unlock()

// Update the command service with the new client first
err := cp.commandService.UpdateClient(ctxWithMetadata, newConnection.CommandServiceClient())
if err != nil {
slog.ErrorContext(ctx, "Failed to reset connection", "error", err)
return
}

// Once the command service is updated, we close the old connection
slog.InfoContext(ctx, "Canceling old subscribe stream after connection reset")
// Cancel the old subscribe stream and close the old connection first, so the server removes
// the connection from its tracker before we call CreateConnection
// with the same UUID in UpdateClient. Without this ordering, the server would track the UUID
// from CreateConnection and then immediately remove it when the old stream exits.
slog.InfoContext(ctx, "Canceling old subscribe stream before connection reset")
if cp.subscribeCancel != nil {
cp.subscribeCancel()
slog.InfoContext(ctxWithMetadata, "Successfully canceled old subscribe stream after connection reset")
slog.InfoContext(ctxWithMetadata, "Successfully canceled old subscribe stream before connection reset")
}

connectionErr := cp.conn.Close(ctx)
Expand All @@ -314,9 +315,24 @@ func (cp *CommandPlugin) processConnectionReset(ctx context.Context, msg *bus.Me
}

cp.conn = newConnection

// Wait for the old Subscribe goroutine to fully exit before creating a new one with the new connection. .
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitpick: Trailing dot.

cp.subscribeWg.Wait()

// Update the command service with the new client after the old stream has been torn down
err := cp.commandService.UpdateClient(ctxWithMetadata, newConnection.CommandServiceClient())
if err != nil {
slog.ErrorContext(ctx, "Failed to reset connection", "error", err)
return
}

slog.InfoContext(ctxWithMetadata, "Starting new subscribe stream after connection reset")
subscribeCtx, cp.subscribeCancel = context.WithCancel(ctxWithMetadata)
go cp.commandService.Subscribe(subscribeCtx)
cp.subscribeWg.Add(1)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

question: Are we missing a cp.subscribeWg.Wait() statement?

go func() {
defer cp.subscribeWg.Done()
cp.commandService.Subscribe(subscribeCtx)
}()

slog.InfoContext(ctx, "Command plugin connection reset finished successfully")
}
Expand Down
39 changes: 26 additions & 13 deletions internal/file/file_service_operator.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,11 @@ import (

// FileServiceOperator handles requests to the grpc file service
type FileServiceOperator struct {
fileServiceClient mpi.FileServiceClient
agentConfig *config.Config
fileOperator fileOperator
isConnected *atomic.Bool
fileServiceClient mpi.FileServiceClient
agentConfig *config.Config
fileOperator fileOperator
isConnected *atomic.Bool
fileServiceClientMu sync.RWMutex
}

var _ fileServiceOperatorInterface = (*FileServiceOperator)(nil)
Expand All @@ -56,7 +57,9 @@ func NewFileServiceOperator(agentConfig *config.Config, fileServiceClient mpi.Fi
}

func (fso *FileServiceOperator) UpdateClient(ctx context.Context, fileServiceClient mpi.FileServiceClient) {
fso.fileServiceClientMu.Lock()
fso.fileServiceClient = fileServiceClient
fso.fileServiceClientMu.Unlock()
slog.DebugContext(ctx, "File service operator updated client")
}

Expand All @@ -82,7 +85,7 @@ func (fso *FileServiceOperator) File(
grpcCtx, cancel := context.WithTimeout(ctx, fso.agentConfig.Client.FileDownloadTimeout)
defer cancel()

return fso.fileServiceClient.GetFile(grpcCtx, &mpi.GetFileRequest{
return fso.getFileServiceClient().GetFile(grpcCtx, &mpi.GetFileRequest{
MessageMeta: &mpi.MessageMeta{
MessageId: id.GenerateMessageID(),
CorrelationId: logger.CorrelationID(ctx),
Expand Down Expand Up @@ -172,7 +175,8 @@ func (fso *FileServiceOperator) UpdateOverview(
}

sendUpdateOverview := func() (*mpi.UpdateOverviewResponse, error) {
if fso.fileServiceClient == nil {
client := fso.getFileServiceClient()
if client == nil {
return nil, errors.New("file service client is not initialized")
}

Expand All @@ -188,7 +192,7 @@ func (fso *FileServiceOperator) UpdateOverview(
grpcCtx, cancel := context.WithTimeout(ctx, fso.agentConfig.Client.Grpc.ResponseTimeout)
defer cancel()

response, updateError := fso.fileServiceClient.UpdateOverview(grpcCtx, request)
response, updateError := client.UpdateOverview(grpcCtx, request)

validatedError := internalgrpc.ValidateGrpcError(updateError)

Expand Down Expand Up @@ -234,7 +238,7 @@ func (fso *FileServiceOperator) ChunkedFile(
grpcCtx, cancel := context.WithTimeout(ctx, fso.agentConfig.Client.FileDownloadTimeout)
defer cancel()

stream, err := fso.fileServiceClient.GetFileStream(grpcCtx, &mpi.GetFileRequest{
stream, err := fso.getFileServiceClient().GetFileStream(grpcCtx, &mpi.GetFileRequest{
MessageMeta: &mpi.MessageMeta{
MessageId: id.GenerateMessageID(),
CorrelationId: logger.CorrelationID(ctx),
Expand Down Expand Up @@ -329,6 +333,14 @@ func (fso *FileServiceOperator) ValidateFileHash(ctx context.Context, filePath,
return nil
}

//nolint:ireturn // getFileServiceClient needs to return an interface
func (fso *FileServiceOperator) getFileServiceClient() mpi.FileServiceClient {
fso.fileServiceClientMu.RLock()
defer fso.fileServiceClientMu.RUnlock()

return fso.fileServiceClient
}

func (fso *FileServiceOperator) updateFiles(
ctx context.Context,
delta map[string]*mpi.File,
Expand Down Expand Up @@ -378,9 +390,10 @@ func (fso *FileServiceOperator) sendUpdateFileRequest(
defer backoffCancel()

sendUpdateFile := func() (*mpi.UpdateFileResponse, error) {
client := fso.getFileServiceClient()
slog.DebugContext(ctx, "Sending update file request", "request_file", request.GetFile(),
"request_message_meta", request.GetMessageMeta())
if fso.fileServiceClient == nil {
if client == nil {
return nil, errors.New("file service client is not initialized")
}

Expand All @@ -391,7 +404,7 @@ func (fso *FileServiceOperator) sendUpdateFileRequest(
grpcCtx, cancel := context.WithTimeout(ctx, fso.agentConfig.Client.FileDownloadTimeout)
defer cancel()

response, updateError := fso.fileServiceClient.UpdateFile(grpcCtx, request)
response, updateError := client.UpdateFile(grpcCtx, request)

validatedError := internalgrpc.ValidateGrpcError(updateError)

Expand Down Expand Up @@ -429,7 +442,7 @@ func (fso *FileServiceOperator) sendUpdateFileStream(
grpcCtx, cancel := context.WithTimeout(ctx, fso.agentConfig.Client.FileDownloadTimeout)
defer cancel()

updateFileStreamClient, err := fso.fileServiceClient.UpdateFileStream(grpcCtx)
updateFileStreamClient, err := fso.getFileServiceClient().UpdateFileStream(grpcCtx)
if err != nil {
return err
}
Expand Down Expand Up @@ -469,7 +482,7 @@ func (fso *FileServiceOperator) sendUpdateFileStreamHeader(

sendUpdateFileHeader := func() error {
slog.DebugContext(ctx, "Sending update file stream header", "header", header)
if fso.fileServiceClient == nil {
if fso.getFileServiceClient() == nil {
return errors.New("file service client is not initialized")
}

Expand Down Expand Up @@ -562,7 +575,7 @@ func (fso *FileServiceOperator) sendFileUpdateStreamChunk(

sendUpdateFileChunk := func() error {
slog.DebugContext(ctx, "Sending update file stream chunk", "chunk_id", chunk.Content.GetChunkId())
if fso.fileServiceClient == nil {
if fso.getFileServiceClient() == nil {
return errors.New("file service client is not initialized")
}

Expand Down
Loading