Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 21 additions & 6 deletions internal/procutil/process_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,33 @@ import (
"fmt"
"os"
"path/filepath"
"strings"
)

// IsExpectedProcess checks if the process at pid is running the expected binary.
// On Linux, reads /proc/<pid>/exe to verify the binary path. Returns false if
// the process does not exist or is running a different binary, preventing
// signals to recycled PIDs.
// IsExpectedProcess checks if the process at pid is running the expected
// binary. On Linux, reads /proc/<pid>/exe to verify the binary path. Returns
// false if the process does not exist, cannot be inspected (for example for
// lack of permission), or is running a different binary, preventing signals
// to recycled PIDs.
//
// When expectedBinary is an absolute path, the comparison is against the
// full resolved exe path — this is the strong guarantee. Because the kernel
// reports the exe path with all symlinks resolved, an expectedBinary that
// reaches the binary through a symlink is also accepted: if the literal
// comparison fails, expectedBinary is resolved with filepath.EvalSymlinks
// and compared again.
//
// When expectedBinary is just a name (or any non-absolute path), the
// comparison falls back to the base name; two unrelated binaries with the
// same base name on the same host would collide under the fallback, so
// callers should pass absolute paths when possible.
//
// The "(deleted)" suffix that the kernel appends when the underlying
// binary has been unlinked post-exec is stripped so that processes still
// running from a since-removed extract directory are correctly identified.
func IsExpectedProcess(pid int, expectedBinary string) bool {
	exePath, err := os.Readlink(fmt.Sprintf("/proc/%d/exe", pid))
	if err != nil {
		return false // Process gone or no permission
	}
	exePath = strings.TrimSuffix(exePath, " (deleted)")

	if !filepath.IsAbs(expectedBinary) {
		// Weak fallback: only the final path component is compared.
		return filepath.Base(exePath) == filepath.Base(expectedBinary)
	}

	actual := filepath.Clean(exePath)
	if actual == filepath.Clean(expectedBinary) {
		return true
	}

	// The literal paths differ, but expectedBinary may still name the same
	// file through a symlink. Resolution fails when the expected file no
	// longer exists (e.g. the "(deleted)" case); then only the literal
	// comparison above applies and the answer is false.
	resolved, err := filepath.EvalSymlinks(expectedBinary)
	if err != nil {
		return false
	}
	return resolved == actual
}
24 changes: 21 additions & 3 deletions internal/procutil/process_linux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,34 @@ func TestIsExpectedProcess_Self(t *testing.T) {
}
}

func TestIsExpectedProcess_SelfBaseName(t *testing.T) {
func TestIsExpectedProcess_SelfBaseNameFallback(t *testing.T) {
// When expectedBinary is NOT absolute (just a bare name), the fallback
// base-name comparison applies.
pid := os.Getpid()

selfExe, err := os.Executable()
if err != nil {
t.Fatalf("failed to get self executable: %v", err)
}
baseName := filepath.Base(selfExe)
if !IsExpectedProcess(pid, "/some/other/path/"+baseName) {
t.Errorf("IsExpectedProcess with different dir but same base name should return true")
if !IsExpectedProcess(pid, baseName) {
t.Errorf("IsExpectedProcess with bare base name should match self")
}
}

func TestIsExpectedProcess_AbsolutePathMismatch(t *testing.T) {
	// An absolute expectedBinary is compared as a full path, so pointing at
	// a different directory must be rejected even though the final path
	// component is identical to the running test binary's. This guards
	// against unrelated binaries that merely share a base name.
	exe, err := os.Executable()
	if err != nil {
		t.Fatalf("failed to get self executable: %v", err)
	}

	impostor := "/some/other/path/" + filepath.Base(exe)
	if matched := IsExpectedProcess(os.Getpid(), impostor); matched {
		t.Errorf("absolute path mismatch must not match even with same base name")
	}
}

Expand Down
28 changes: 25 additions & 3 deletions microvm.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,10 @@ func Run(ctx context.Context, imageRef string, opts ...Option) (*VM, error) {
slog.Warn("egress policy overrides firewall default action to Deny")
}
cfg.firewallDefaultAction = firewall.Deny
if cfg.netProvider == nil {
cfg.netProvider = hosted.NewProvider()
}
}

wireDefaultProvider(cfg)

// 1. Preflight checks.
{
ctx, span := tracer.Start(ctx, "microvm.Preflight")
Expand Down Expand Up @@ -312,6 +311,7 @@ func Run(ctx context.Context, imageRef string, opts ...Option) (*VM, error) {
ls.State.Name = cfg.name
if pid, pidErr := pidFromID(handle.ID()); pidErr == nil {
ls.State.PID = pid
ls.State.PIDStartTime = time.Now().UTC()
} else {
slog.Warn("could not persist VM PID", "id", handle.ID(), "error", pidErr)
}
Expand Down Expand Up @@ -354,6 +354,21 @@ const (
staleTermPoll = 250 * time.Millisecond
)

// wireDefaultProvider installs a hosted network provider on cfg when the
// caller configured any firewall behaviour but did not supply a provider.
//
// "Firewall behaviour" means at least one of: an egress policy, one or more
// static firewall rules, or a default action other than firewall.Allow. The
// default runner-side networking path does not enforce firewall rules, so
// leaving the provider unset in those cases would let a deny-default
// silently degrade to allow-all.
//
// An explicitly supplied provider is never replaced.
func wireDefaultProvider(cfg *config) {
	if cfg.netProvider != nil {
		return
	}

	switch {
	case cfg.egressPolicy != nil,
		len(cfg.firewallRules) > 0,
		cfg.firewallDefaultAction != firewall.Allow:
		cfg.netProvider = hosted.NewProvider()
	}
}

func cleanDataDir(cfg *config) error {
if cfg.dataDir == "" {
return nil
Expand Down Expand Up @@ -409,6 +424,13 @@ func terminateStaleRunner(cfg *config) {
slog.Debug("stale runner already dead", "pid", st.PID)
return
}
if !cfg.processIsExpected(st.PID) {
// PID has been recycled onto an unrelated binary since we wrote
// the state file. Signalling it would kill the wrong process
// group (or fail silently if we lack permission). Bail out.
slog.Warn("stale PID does not match expected runner binary, skipping termination", "pid", st.PID)
return
}

// Use negative PID to signal the entire process group (PGID == PID
// because the runner starts with Setsid: true). This ensures any
Expand Down
92 changes: 92 additions & 0 deletions microvm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,20 @@ import (
"github.com/stacklok/go-microvm/hypervisor"
"github.com/stacklok/go-microvm/image"
"github.com/stacklok/go-microvm/internal/testutil"
propnet "github.com/stacklok/go-microvm/net"
"github.com/stacklok/go-microvm/net/firewall"
"github.com/stacklok/go-microvm/preflight"
"github.com/stacklok/go-microvm/state"
)

// sentinelProvider is a do-nothing net.Provider stand-in. Tests install it
// as the caller-supplied provider and then assert that auto-wiring left the
// very same instance in place instead of swapping in a default.
type sentinelProvider struct{}

// Start ignores its arguments and always reports success.
func (*sentinelProvider) Start(context.Context, propnet.Config) error {
	return nil
}

// SocketPath always returns the empty string.
func (*sentinelProvider) SocketPath() string {
	return ""
}

// Stop is a no-op.
func (*sentinelProvider) Stop() {
}

// --- Pure function tests ---

func TestBuildInitConfig_NilOCIConfig(t *testing.T) {
Expand Down Expand Up @@ -660,6 +669,53 @@ func TestBuildNetConfig_Empty(t *testing.T) {

// --- Egress validation tests ---

func TestWireDefaultProvider(t *testing.T) {
t.Parallel()

t.Run("no firewall config leaves provider nil", func(t *testing.T) {
t.Parallel()
cfg := defaultConfig()
wireDefaultProvider(cfg)
assert.Nil(t, cfg.netProvider)
})

t.Run("egress policy auto-wires provider", func(t *testing.T) {
t.Parallel()
cfg := defaultConfig()
cfg.egressPolicy = &EgressPolicy{}
wireDefaultProvider(cfg)
assert.NotNil(t, cfg.netProvider)
})

t.Run("firewall rules alone auto-wire provider", func(t *testing.T) {
t.Parallel()
cfg := defaultConfig()
cfg.firewallRules = []firewall.Rule{{Direction: firewall.Egress, Action: firewall.Allow}}
wireDefaultProvider(cfg)
assert.NotNil(t, cfg.netProvider,
"firewall-only config must auto-wire a provider; otherwise rules go unenforced")
})

t.Run("deny default alone auto-wires provider", func(t *testing.T) {
t.Parallel()
cfg := defaultConfig()
cfg.firewallDefaultAction = firewall.Deny
wireDefaultProvider(cfg)
assert.NotNil(t, cfg.netProvider,
"deny-default config must auto-wire a provider to actually deny")
})

t.Run("explicit provider is not overwritten", func(t *testing.T) {
t.Parallel()
existing := &sentinelProvider{}
cfg := defaultConfig()
cfg.netProvider = existing
cfg.firewallDefaultAction = firewall.Deny
wireDefaultProvider(cfg)
assert.Same(t, existing, cfg.netProvider)
})
}

func TestRun_EgressPolicy_EmptyHosts_DenyAll(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -862,6 +918,7 @@ func TestTerminateStaleRunner_AliveProcess_GracefulExit(t *testing.T) {
// (after SIGTERM + first poll).
return aliveCount <= 1
}
cfg.processIsExpected = func(_ int) bool { return true }

terminateStaleRunner(cfg)

Expand Down Expand Up @@ -899,6 +956,7 @@ func TestTerminateStaleRunner_AliveProcess_RequiresKill(t *testing.T) {
}
// Process never exits on its own.
cfg.processAlive = func(_ int) bool { return true }
cfg.processIsExpected = func(_ int) bool { return true }

terminateStaleRunner(cfg)

Expand Down Expand Up @@ -941,6 +999,7 @@ func TestTerminateStaleRunner_SendsToProcessGroup(t *testing.T) {
aliveCount++
return aliveCount <= 1
}
cfg.processIsExpected = func(_ int) bool { return true }

terminateStaleRunner(cfg)

Expand All @@ -950,6 +1009,38 @@ func TestTerminateStaleRunner_SendsToProcessGroup(t *testing.T) {
assert.Equal(t, -55555, receivedPIDs[0], "killProcess should receive negative PID for process group")
}

func TestTerminateStaleRunner_RecycledPID_Skipped(t *testing.T) {
	t.Parallel()

	// Scenario: persisted state names a PID that is alive, yet the
	// processIsExpected hook reports that the binary behind it is not the
	// runner — exactly what a kernel-recycled PID looks like. No signal may
	// be delivered in that situation.
	dir := t.TempDir()

	// Seed the state file with an active VM owned by PID 77777.
	mgr := state.NewManager(dir)
	locked, err := mgr.LoadAndLock(context.Background())
	require.NoError(t, err)
	locked.State.Active = true
	locked.State.PID = 77777
	require.NoError(t, locked.Save())
	locked.Release()

	signalled := false

	cfg := defaultConfig()
	cfg.dataDir = dir
	cfg.processAlive = func(int) bool { return true }
	cfg.processIsExpected = func(int) bool { return false }
	cfg.killProcess = func(int, syscall.Signal) error {
		signalled = true
		return nil
	}

	terminateStaleRunner(cfg)

	assert.False(t, signalled, "must not signal a recycled PID belonging to an unrelated binary")
}

func TestTerminateStaleRunner_PID1_Skipped(t *testing.T) {
t.Parallel()

Expand All @@ -973,6 +1064,7 @@ func TestTerminateStaleRunner_PID1_Skipped(t *testing.T) {
return nil
}
cfg.processAlive = func(_ int) bool { return true }
cfg.processIsExpected = func(_ int) bool { return true }

terminateStaleRunner(cfg)
assert.False(t, killCalled, "should not attempt to kill PID 1")
Expand Down
10 changes: 10 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

"github.com/stacklok/go-microvm/hypervisor"
"github.com/stacklok/go-microvm/image"
"github.com/stacklok/go-microvm/internal/procutil"
"github.com/stacklok/go-microvm/net"
"github.com/stacklok/go-microvm/net/firewall"
"github.com/stacklok/go-microvm/preflight"
Expand Down Expand Up @@ -103,6 +104,7 @@ type config struct {
stat func(string) (os.FileInfo, error)
killProcess func(pid int, sig syscall.Signal) error
processAlive func(pid int) bool
processIsExpected func(pid int) bool
}

func defaultConfig() *config {
Expand All @@ -126,9 +128,17 @@ func defaultConfig() *config {
}
return proc.Signal(syscall.Signal(0)) == nil
},
processIsExpected: func(pid int) bool {
return procutil.IsExpectedProcess(pid, runnerBinaryName)
},
}
}

// runnerBinaryName is the base name of the runner executable — used by
// the default processIsExpected check to distinguish the go-microvm
// runner from an unrelated process that happens to be at the same PID.
//
// This is a bare name, not an absolute path, so the Linux implementation
// of procutil.IsExpectedProcess compares only the base name of the
// process's executable. Any binary with this file name passes the check,
// wherever it is installed.
const runnerBinaryName = "go-microvm-runner"

func defaultDataDir() string {
if dir := os.Getenv("GO_MICROVM_DATA_DIR"); dir != "" {
return dir
Expand Down
22 changes: 18 additions & 4 deletions runner/process_linux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,31 @@ func TestIsExpectedProcess_Self(t *testing.T) {
}
}

func TestIsExpectedProcess_SelfBaseName(t *testing.T) {
// Should match by base name even if full paths differ.
func TestIsExpectedProcess_BaseNameFallbackForRelative(t *testing.T) {
// A relative/bare binary name still matches by base name — that is the
// documented fallback. Absolute mismatches must no longer pass.
pid := os.Getpid()

selfExe, err := os.Executable()
if err != nil {
t.Fatalf("failed to get self executable: %v", err)
}
baseName := selfExe[len(selfExe)-len("runner.test"):] // last component
if !isExpectedProcess(pid, "/some/other/path/"+baseName) {
t.Errorf("isExpectedProcess with different dir but same base name should return true")
if !isExpectedProcess(pid, baseName) {
t.Errorf("isExpectedProcess with bare base name should match self")
}
}

func TestIsExpectedProcess_AbsolutePathMismatchFails(t *testing.T) {
	// An absolute expected path in a different directory must be rejected
	// even when its final component equals the running test binary's.
	exe, err := os.Executable()
	if err != nil {
		t.Fatalf("failed to get self executable: %v", err)
	}

	// NOTE(review): the last path component is taken by slicing off a
	// fixed-length suffix, which assumes the test binary is named
	// "runner.test" — confirm, or derive the base name from the path.
	suffixLen := len("runner.test")
	name := exe[len(exe)-suffixLen:]

	if matched := isExpectedProcess(os.Getpid(), "/some/other/path/"+name); matched {
		t.Errorf("absolute path with different dir must not match even when base name matches")
	}
}

Expand Down
7 changes: 7 additions & 0 deletions state/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ type State struct {
// PID is the process ID of the VM runner, or 0 if not running.
PID int `json:"pid,omitempty"`

// PIDStartTime records wall-clock time when PID was recorded. Used
// to disambiguate a recycled PID from the original runner in contexts
// where /proc/PID/exe comparison is unavailable (e.g. macOS) or as
// belt-and-suspenders alongside the exe-path check on Linux.
// Zero time on state files written before this field was introduced.
PIDStartTime time.Time `json:"pid_start_time,omitempty"`

// CreatedAt is the time the VM state was first created.
CreatedAt time.Time `json:"created_at"`
}
Expand Down
Loading
Loading