Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 51 additions & 14 deletions pkg/unikontainers/ipc.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package unikontainers
import (
"errors"
"fmt"
"io"
"io/fs"
"net"
"os"
Expand Down Expand Up @@ -81,9 +82,21 @@ func SockAddrExists(sockAddr string) bool {

// SendIPCMessage creates a new connection to socketAddress, sends the message and closes the connection
func SendIPCMessage(socketAddress string, message IPCMessage) error {
conn, err := net.Dial("unix", socketAddress)
var conn net.Conn
var err error

// FIX #405: Backoff retry loop to handle IPC socket race conditions during reexec
deadline := time.Now().Add(5 * time.Second)
for time.Now().Before(deadline) {
conn, err = net.DialTimeout("unix", socketAddress, 100*time.Millisecond)
if err == nil {
break
}
time.Sleep(10 * time.Millisecond)
}

if err != nil {
return err
return fmt.Errorf("timeout waiting for ipc socket %s: %w", socketAddress, err)
}
defer conn.Close()

Expand Down Expand Up @@ -145,27 +158,51 @@ func createListener(socketAddress string, mustBeValid bool) (*net.UnixListener,
return listener, nil
}

// awaitMessage opens a new connection to socketAddress
// and waits for a given message
// AwaitMessage accepts a connection from the listener and waits for the
// expected IPC message. It implements a 10-second timeout to prevent
// the process from blocking indefinitely (Fixes #405)
func AwaitMessage(listener *net.UnixListener, expectedMessage IPCMessage) error {
timeout := 10 * time.Second
deadline := time.Now().Add(timeout)

// Set deadline for the initial connection (Accept)
if err := listener.SetDeadline(deadline); err != nil {
return fmt.Errorf("failed to set listener deadline: %w", err)
}

conn, err := listener.AcceptUnix()
if err != nil {
return err
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
return fmt.Errorf("IPC handshake timeout: no connection received within %v", timeout)
}
return fmt.Errorf("failed to accept IPC connection: %w", err)
}
defer func() {
err = conn.Close()
if err != nil {
logrus.WithError(err).Error("failed to close connection")
if cerr := conn.Close(); cerr != nil {
logrus.WithError(cerr).Error("failed to close IPC connection")
}
}()

// Set deadline for the actual data transfer (Read)
if err := conn.SetDeadline(deadline); err != nil {
return fmt.Errorf("failed to set connection deadline: %w", err)
}

// io.ReadFull ensures we don't return early with a partial message
buf := make([]byte, len(expectedMessage))
n, err := conn.Read(buf)
if err != nil {
return fmt.Errorf("failed to read from socket: %w", err)
if _, err := io.ReadFull(conn, buf); err != nil {
if errors.Is(err, io.ErrUnexpectedEOF) {
return fmt.Errorf("connection closed before full message was received")
}
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
return fmt.Errorf("IPC handshake timeout: message not received within %v", timeout)
}
return fmt.Errorf("failed to read from IPC socket: %w", err)
}
msg := string(buf[0:n])
if msg != string(expectedMessage) {
return fmt.Errorf("received unexpected message: %s (expected %s)", msg, expectedMessage)

if string(buf) != string(expectedMessage) {
return fmt.Errorf("received unexpected message: %q (expected %q)", string(buf), expectedMessage)
}

return nil
}