Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

26 changes: 26 additions & 0 deletions packages/orchestrator/pkg/sandbox/uffd/userfaultfd/deferred.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package userfaultfd

import "sync"

// deferredFaults collects pagefaults that couldn't be handled (EAGAIN)
// and need to be retried on the next poll iteration. Safe for concurrent push.
type deferredFaults struct {
	mu sync.Mutex       // guards pf
	pf []*UffdPagefault // pending faults, in push order; emptied by drain
}

// push appends pf to the deferred list. Safe to call concurrently.
func (d *deferredFaults) push(pf *UffdPagefault) {
	d.mu.Lock()
	defer d.mu.Unlock()

	d.pf = append(d.pf, pf)
}

// drain returns all accumulated pagefaults and resets the internal list.
// drain returns all accumulated pagefaults and resets the internal list.
func (d *deferredFaults) drain() []*UffdPagefault {
	d.mu.Lock()
	defer d.mu.Unlock()

	drained := d.pf
	d.pf = nil

	return drained
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ func configureApi(f Fd, pagesize uint64) error {
}

features |= UFFD_FEATURE_WP_ASYNC
features |= UFFD_FEATURE_EVENT_REMOVE

api := newUffdioAPI(UFFD_API, features)
ret, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(f), UFFDIO_API, uintptr(unsafe.Pointer(&api)))
Expand All @@ -42,9 +43,9 @@ func configureApi(f Fd, pagesize uint64) error {
}

// unregister tears down the UFFD registration over [addr, addr+size).
// Used in test cleanup so that any in-flight REMOVE events the kernel
// may have queued (once UFFD_FEATURE_EVENT_REMOVE is enabled in a
// follow-up) don't keep munmap blocked on un-acked events.
// Used in test cleanup so in-flight REMOVE events queued by the kernel
// (configureApi enables UFFD_FEATURE_EVENT_REMOVE on this branch) don't
// keep munmap blocked on un-acked events.
func unregister(f Fd, addr uintptr, size uint64) error {
r := newUffdioRange(CULong(addr), CULong(size))

Expand Down
163 changes: 146 additions & 17 deletions packages/orchestrator/pkg/sandbox/uffd/userfaultfd/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,18 @@ import (
"bytes"
"context"
"fmt"
"slices"
"io"
"net/rpc"
"os/exec"
"sync"
"testing"
"time"
"unsafe"

"github.com/RoaringBitmap/roaring/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sys/unix"

"github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/uffd/testutils"
)
Expand All @@ -26,68 +30,167 @@ type testConfig struct {
operations []operation
// alwaysWP makes the handler copy with UFFDIO_COPY_MODE_WP for all faults.
alwaysWP bool
// gated enables pause/resume control over the handler's serve loop.
gated bool
// barriers wires up the per-worker fault hooks in the child
// (used by race tests). Off by default so the worker hot path
// stays a single nil-pointer load + branch in non-race tests.
barriers bool
// sourcePatcher, if non-nil, is invoked on the random source data
// AFTER it's generated but BEFORE it's written to the on-disk
// content file the child reads. Tests can use this to plant
// deterministic sentinel bytes in the source so the post-test
// assertion can distinguish "post-fix zero-fault" from "pre-fix
// UFFDIO_COPY of stale src bytes" without depending on the
// happenstance value of randomly-generated bytes.
sourcePatcher func([]byte)
}

// operationMode selects which action executeOperation performs for a
// test step. Values are distinct bits (1 << iota).
type operationMode uint32

const (
	// operationModeRead reads one page of the memory area (executeRead).
	operationModeRead operationMode = 1 << iota
	// operationModeWrite dispatches to executeWrite.
	operationModeWrite
	// operationModeRemove drops one page via madvise(MADV_DONTNEED)
	// (executeRemove).
	operationModeRemove
	// operationModeServePause gates the child's UFFD serve loop via the
	// handler's servePause hook.
	operationModeServePause
	// operationModeServeResume resumes the child's serve loop via the
	// handler's serveResume hook.
	operationModeServeResume
	// operationModeSleep pauses for a short duration to let async goroutines
	// enter their blocking syscalls before proceeding.
	operationModeSleep
)

// operation is one scripted step of a test scenario, replayed in order
// by executeAll.
type operation struct {
	// Offset in bytes. Must be smaller than the (numberOfPages-1) * pagesize as it reads a page and it must be aligned to the pagesize from the testConfig.
	offset int64
	// mode selects which action executeOperation performs for this step.
	mode operationMode
	// async runs the operation in a background goroutine.
	async bool
}

// handlerPageStates is a snapshot of the pageTracker grouped by state. It
// lets tests assert on the set of pages that the handler observed in each
// state, rather than a flat list of "accessed" offsets.
type handlerPageStates struct {
	faulted []uint // sorted byte offsets of pages seen in the faulted state
	removed []uint // sorted byte offsets of pages seen in the removed state
}

// allAccessed returns the sorted union of offsets that the handler touched
// in any non-missing state.
//
// pageStatesOnce returns each per-state slice already sorted, and a page
// has exactly one state at a time in pageTracker, so the per-state slices
// are disjoint. We merge them with a simple sorted merge instead of a
// bitset — byte offsets make poor bit indices (a single hugepage offset
// would force ~1.8 MB of backing storage).
func (s handlerPageStates) allAccessed() []uint {
	merged := make([]uint, 0, len(s.faulted)+len(s.removed))

	// Two-pointer sorted merge of the (disjoint, pre-sorted) state slices.
	f, r := s.faulted, s.removed
	for len(f) > 0 && len(r) > 0 {
		if f[0] <= r[0] {
			merged = append(merged, f[0])
			f = f[1:]
		} else {
			merged = append(merged, r[0])
			r = r[1:]
		}
	}

	merged = append(merged, f...)
	merged = append(merged, r...)

	return merged
}

type testHandler struct {
memoryArea *[]byte
pagesize uint64
data *MemorySlicer
// pageStatesOnce returns a per-state snapshot of the handler's pageTracker.
// It can only be called once.
// Backed by the PageStates RPC; callable any number of times.
// The "Once" suffix is kept for source-stability with the existing
// test sites.
pageStatesOnce func() (handlerPageStates, error)
mutex sync.Mutex
// servePause and serveResume gate the UFFD event loop in the child process.
// Tests use them to deterministically drain a batch of REMOVE events
// before more faults are processed.
servePause func() error
serveResume func() error

// client is the RPC channel to the child helper process.
client *rpc.Client
conn io.Closer
cmd *exec.Cmd

mutex sync.Mutex
}

// installFaultBarrier asks the child to park the next worker that
// hits `point` for `addr`. Returns a token that must be passed to
// waitFaultHeld and releaseFault.
// installFaultBarrier asks the child to park the next worker that
// hits `point` for `addr`. Returns a token that must be passed to
// waitFaultHeld and releaseFault.
func (h *testHandler) installFaultBarrier(_ context.Context, addr uintptr, point barrierPoint) (uint64, error) {
	args := &FaultBarrierArgs{Addr: uint64(addr), Point: uint8(point)}

	var reply FaultBarrierReply
	if err := h.client.Call("Service.InstallFaultBarrier", args, &reply); err != nil {
		// Don't hand back a possibly half-filled reply alongside the
		// error; the token is meaningless when the call failed.
		return 0, err
	}

	return reply.Token, nil
}

// waitFaultHeld blocks until the child reports that a worker has
// reached the barrier identified by token. The wait is bounded via
// context by issuing the call on a goroutine and racing it against
// ctx; net/rpc's Call doesn't take a context directly.
// waitFaultHeld blocks until the child reports that a worker has
// reached the barrier identified by token. net/rpc's Call doesn't take
// a context, so the RPC is issued asynchronously and raced against
// ctx cancellation to bound the wait.
func (h *testHandler) waitFaultHeld(ctx context.Context, token uint64) error {
	pending := h.client.Go("Service.WaitFaultHeld", &TokenArgs{Token: token}, &Empty{}, nil)

	select {
	case finished := <-pending.Done:
		return finished.Error
	case <-ctx.Done():
		return ctx.Err()
	}
}

// releaseFault releases a parked worker so it proceeds past the
// barrier.
// releaseFault releases a parked worker so it proceeds past the
// barrier.
func (h *testHandler) releaseFault(_ context.Context, token uint64) error {
	args := TokenArgs{Token: token}

	return h.client.Call("Service.ReleaseFault", &args, &Empty{})
}

// executeAll replays operations in order. Synchronous steps fail the
// test immediately; async steps run in background goroutines whose
// errors are collected (bounded by the test context) after the loop.
func (h *testHandler) executeAll(t *testing.T, operations []operation) {
	t.Helper()

	var pending []chan error

	for step, op := range operations {
		if !op.async {
			require.NoError(t, h.executeOperation(t.Context(), op), "step %d: %v at offset %d", step, op.mode, op.offset)

			continue
		}

		done := make(chan error, 1)
		pending = append(pending, done)

		go func(op operation) {
			done <- h.executeOperation(t.Context(), op)
		}(op)
	}

	for _, done := range pending {
		select {
		case err := <-done:
			require.NoError(t, err, "async operation")
		case <-t.Context().Done():
			t.Fatal("timed out waiting for async operation")
		}
	}
}

type pageExpectation uint8

const (
expectClean pageExpectation = iota // read-only: present + WP set
expectDirty // written: present + WP cleared
expectClean pageExpectation = iota // read-only: present + WP set
expectDirty // written: present + WP cleared
expectRemoved // removed: not present
)

func (h *testHandler) checkDirtiness(t *testing.T, operations []operation) {
Expand All @@ -100,17 +203,25 @@ func (h *testHandler) checkDirtiness(t *testing.T, operations []operation) {
memStart := uintptr(unsafe.Pointer(&(*h.memoryArea)[0]))

// Track the final expected state per offset by replaying operations in order.
// A remove after a read/write makes the page not present.
// A read/write after a remove makes it present again.
expected := make(map[uint]pageExpectation)

for _, op := range operations {
off := uint(op.offset)
switch op.mode {
case operationModeRead:
if _, seen := expected[off]; !seen {
curr, seen := expected[off]
// If we haven't seen this page before or the page
// has previously been removed then the page should be clean
// after this read operation.
if !seen || curr == expectRemoved {
expected[off] = expectClean
}
case operationModeWrite:
expected[off] = expectDirty
case operationModeRemove:
expected[off] = expectRemoved
}
}

Expand All @@ -119,6 +230,8 @@ func (h *testHandler) checkDirtiness(t *testing.T, operations []operation) {
require.NoError(t, err, "pagemap read at offset %d", off)

switch expect {
case expectRemoved:
assert.False(t, entry.IsPresent(), "removed page at offset %d should not be present", off)
case expectDirty:
assert.True(t, entry.IsPresent(), "written page at offset %d should be present", off)
assert.False(t, entry.IsWriteProtected(), "written page at offset %d should be dirty", off)
Expand All @@ -135,11 +248,27 @@ func (h *testHandler) executeOperation(ctx context.Context, op operation) error
return h.executeRead(ctx, op)
case operationModeWrite:
return h.executeWrite(ctx, op)
case operationModeRemove:
return h.executeRemove(op)
case operationModeServePause:
return h.servePause()
case operationModeServeResume:
return h.serveResume()
case operationModeSleep:
time.Sleep(50 * time.Millisecond)

return nil
default:
return fmt.Errorf("invalid operation mode: %d", op.mode)
}
}

// executeRemove drops the page at op.offset with madvise(MADV_DONTNEED),
// making it non-present in the child's memory area.
func (h *testHandler) executeRemove(op operation) error {
	start := op.offset
	end := start + int64(h.pagesize)

	return unix.Madvise((*h.memoryArea)[start:end], unix.MADV_DONTNEED)
}

func (h *testHandler) executeRead(ctx context.Context, op operation) error {
readBytes := (*h.memoryArea)[op.offset : op.offset+int64(h.pagesize)]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ type pageState uint8
const (
	// missing is the zero/default state: get returns it for any address
	// that was never recorded in the tracker.
	missing pageState = iota
	// faulted marks a page the handler has faulted in.
	faulted
	// removed marks a page reported removed — NOTE(review): presumably set
	// when a UFFD REMOVE event (e.g. after MADV_DONTNEED) is observed;
	// confirm against the handler's event loop.
	removed
)

type pageTracker struct {
Expand All @@ -23,6 +24,18 @@ func newPageTracker(pageSize uintptr) *pageTracker {
}
}

// get reports the tracked state for addr. Addresses that were never
// recorded default to missing.
func (pt *pageTracker) get(addr uintptr) pageState {
	pt.mu.RLock()
	defer pt.mu.RUnlock()

	if state, ok := pt.m[addr]; ok {
		return state
	}

	return missing
}

func (pt *pageTracker) setState(start, end uintptr, state pageState) {
pt.mu.Lock()
defer pt.mu.Unlock()
Expand Down
Loading
Loading