diff --git a/packages/orchestrator/pkg/sandbox/block/tracker.go b/packages/orchestrator/pkg/sandbox/block/tracker.go
new file mode 100644
index 0000000000..18b750263b
--- /dev/null
+++ b/packages/orchestrator/pkg/sandbox/block/tracker.go
@@ -0,0 +1,86 @@
+package block
+
+import (
+	"sync"
+
+	"github.com/RoaringBitmap/roaring/v2"
+)
+
+// State is the tracking state of a single block index within this layer.
+type State uint8
+
+const (
+	// NotPresent: fall back to the previous layer.
+	NotPresent State = iota
+	// Dirty: this layer holds materialized data.
+	Dirty
+	// Zero: known-zero; no need to consult the previous layer.
+	Zero
+)
+
+// Tracker records a State per uint32 index using two roaring bitmaps, one
+// for Dirty and one for Zero; an index in neither bitmap is NotPresent.
+// Every method takes mu, so calls may be made from multiple goroutines.
+type Tracker struct {
+	mu          sync.RWMutex
+	dirty, zero *roaring.Bitmap
+}
+
+// NewTracker returns an empty Tracker in which every index is NotPresent.
+func NewTracker() *Tracker {
+	return &Tracker{
+		dirty: roaring.New(),
+		zero:  roaring.New(),
+	}
+}
+
+// SetRange sets state for every index in the half-open range [start, end).
+// Empty and inverted ranges (end <= start) are no-ops. Because end is an
+// exclusive uint32 bound, the index math.MaxUint32 can never be set.
+func (t *Tracker) SetRange(start, end uint32, state State) {
+	if end <= start {
+		return
+	}
+
+	t.mu.Lock()
+	defer t.mu.Unlock()
+
+	s, e := uint64(start), uint64(end)
+	// Keep the bitmaps disjoint: adding to one always removes from the other.
+	switch state {
+	case Dirty:
+		t.dirty.AddRange(s, e)
+		t.zero.RemoveRange(s, e)
+	case Zero:
+		t.zero.AddRange(s, e)
+		t.dirty.RemoveRange(s, e)
+	case NotPresent:
+		t.dirty.RemoveRange(s, e)
+		t.zero.RemoveRange(s, e)
+	}
+}
+
+// Get returns the state recorded for idx. SetRange keeps the dirty and zero
+// bitmaps disjoint, so at most one of them can contain idx.
+func (t *Tracker) Get(idx uint32) State {
+	t.mu.RLock()
+	defer t.mu.RUnlock()
+
+	switch {
+	case t.dirty.Contains(idx):
+		return Dirty
+	case t.zero.Contains(idx):
+		return Zero
+	default:
+		return NotPresent
+	}
+}
+
+// Export returns clones of the dirty and zero bitmaps, so callers may mutate
+// the results without affecting the Tracker.
+func (t *Tracker) Export() (dirty, zero *roaring.Bitmap) {
+	t.mu.RLock()
+	defer t.mu.RUnlock()
+
+	return t.dirty.Clone(), t.zero.Clone()
+}
diff --git a/packages/orchestrator/pkg/sandbox/block/tracker_test.go b/packages/orchestrator/pkg/sandbox/block/tracker_test.go
new file mode 100644
index 0000000000..a681b7e4e8
--- /dev/null
+++ b/packages/orchestrator/pkg/sandbox/block/tracker_test.go
@@ -0,0 +1,79 @@
+package block
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestTracker(t *testing.T) {
+	t.Parallel()
+
+	t.Run("transitions", func(t *testing.T) {
+		t.Parallel()
+		s := NewTracker()
+
+		s.SetRange(0, 1, Dirty)
+		assert.Equal(t, Dirty, s.Get(0))
+
+		s.SetRange(0, 1, Zero)
+		assert.Equal(t, Zero, s.Get(0), "dirty→zero should flip the page")
+		bmDirty, bmZero := s.Export()
+		assert.False(t, bmDirty.Contains(0), "dirty→zero must clear dirty bitmap")
+		assert.True(t, bmZero.Contains(0), "dirty→zero must add to zero bitmap")
+
+		s.SetRange(0, 1, Dirty)
+		assert.Equal(t, Dirty, s.Get(0), "zero→dirty should flip back")
+
+		s.SetRange(0, 1, NotPresent)
+		assert.Equal(t, NotPresent, s.Get(0), "→not-present must clear")
+		bmDirty, bmZero = s.Export()
+		assert.False(t, bmDirty.Contains(0))
+		assert.False(t, bmZero.Contains(0))
+
+		s.SetRange(0, 1, Dirty)
+		s.SetRange(0, 1, Dirty)
+		assert.Equal(t, Dirty, s.Get(0), "dirty→dirty is idempotent")
+	})
+
+	t.Run("partial overlap moves only the overlapping pages", func(t *testing.T) {
+		t.Parallel()
+		s := NewTracker()
+
+		s.SetRange(0, 10, Dirty)
+		s.SetRange(3, 7, Zero)
+
+		for i := range uint32(3) {
+			assert.Equal(t, Dirty, s.Get(i), "page %d outside overlap stays dirty", i)
+		}
+		for i := range uint32(4) {
+			page := i + 3
+			assert.Equal(t, Zero, s.Get(page), "page %d in overlap moves to zero", page)
+		}
+		for i := range uint32(3) {
+			page := i + 7
+			assert.Equal(t, Dirty, s.Get(page), "page %d outside overlap stays dirty", page)
+		}
+	})
+
+	t.Run("empty and inverted ranges are no-ops", func(t *testing.T) {
+		t.Parallel()
+		s := NewTracker()
+
+		s.SetRange(5, 5, Dirty)
+		s.SetRange(7, 3, Zero)
+		bmDirty, bmZero := s.Export()
+		assert.True(t, bmDirty.IsEmpty())
+		assert.True(t, bmZero.IsEmpty())
+	})
+
+	t.Run("Export returns clones", func(t *testing.T) {
+		t.Parallel()
+		s := NewTracker()
+
+		s.SetRange(0, 1, Dirty)
+		bmDirty, _ := s.Export()
+		bmDirty.Add(42)
+		assert.Equal(t, NotPresent, s.Get(42), "mutating export must not leak into tracker")
+	})
+}
diff --git a/packages/orchestrator/pkg/sandbox/uffd/testutils/testharness/wire.go b/packages/orchestrator/pkg/sandbox/uffd/testutils/testharness/wire.go
index 7a475b775d..7cc7111955 100644
--- a/packages/orchestrator/pkg/sandbox/uffd/testutils/testharness/wire.go
+++ b/packages/orchestrator/pkg/sandbox/uffd/testutils/testharness/wire.go
@@ -19,7 +19,7 @@ type BootstrapArgs struct {
 
 type BootstrapReply struct{}
 
-// PageStateEntry is the wire form of the parent package's pageState enum.
+// PageStateEntry is the wire form of a block.State for a single page offset.
 type PageStateEntry struct {
 	State  uint8
 	Offset uint64
diff --git a/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/helpers_test.go b/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/helpers_test.go
index 443e5cef83..045c68f586 100644
--- a/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/helpers_test.go
+++ b/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/helpers_test.go
@@ -14,6 +14,7 @@ import (
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 
+	"github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/block"
 	"github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/uffd/testutils"
 	"github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/uffd/testutils/testharness"
 )
@@ -76,7 +77,7 @@ func (h *testHandler) pageStates() (handlerPageStates, error) {
 
 	var states handlerPageStates
 	for _, e := range entries {
-		if pageState(e.State) == faulted {
+		if block.State(e.State) == block.Dirty {
 			states.faulted = append(states.faulted, uint(e.Offset))
 		}
 	}
diff --git a/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/page_tracker.go b/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/page_tracker.go
deleted file mode 100644
index da76d310a8..0000000000
--- a/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/page_tracker.go
+++ /dev/null
@@ -1,33 +0,0 @@
-package userfaultfd
-
-import "sync"
-
-type pageState uint8
-
-const (
-	missing pageState = iota
-	faulted
-)
-
-type pageTracker struct {
-	pageSize uintptr
-
-	m  map[uintptr]pageState
-	mu sync.RWMutex
-}
-
-func newPageTracker(pageSize uintptr) *pageTracker {
-	return &pageTracker{
-		pageSize: pageSize,
-		m:        make(map[uintptr]pageState),
-	}
-}
-
-func (pt *pageTracker) setState(start, end uintptr, state pageState) {
-	pt.mu.Lock()
-	defer pt.mu.Unlock()
-
-	for addr := start; addr < end; addr += pt.pageSize {
-		pt.m[addr] = state
-	}
-}
diff --git a/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/rpc_services_test.go b/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/rpc_services_test.go
index 380e9c1d2f..d6a01b8f2e 100644
--- a/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/rpc_services_test.go
+++ b/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/rpc_services_test.go
@@ -10,6 +10,9 @@ import (
 	"os"
 	"sync"
 
+	"github.com/RoaringBitmap/roaring/v2"
+
+	"github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/block"
 	"github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/uffd/fdexit"
 	"github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/uffd/memory"
 	"github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/uffd/testutils/testharness"
@@ -186,24 +189,24 @@ func (p *Paging) Resume(_ *testharness.Empty, _ *testharness.Empty) error {
 }
 
 // pageStateEntries returns a wire-format snapshot of pageTracker.
-// settleRequests.Lock drains fault workers (mirrors PrefetchData);
-// pageTracker.mu.RLock is defensive against a future REMOVE writer
-// that mutates pageTracker.m outside settleRequests.
+// settleRequests.Lock drains fault workers (mirrors PrefetchData) so
+// the snapshot is consistent w.r.t. concurrent installs.
 func (u *Userfaultfd) pageStateEntries() ([]testharness.PageStateEntry, error) {
 	u.settleRequests.Lock()
 	defer u.settleRequests.Unlock()
 
-	u.pageTracker.mu.RLock()
-	defer u.pageTracker.mu.RUnlock()
-
-	entries := make([]testharness.PageStateEntry, 0, len(u.pageTracker.m))
-	for addr, state := range u.pageTracker.m {
-		offset, err := u.ma.GetOffset(addr)
-		if err != nil {
-			return nil, fmt.Errorf("address %#x not in mapping: %w", addr, err)
+	bmDirty, bmZero := u.pageTracker.Export()
+	entries := make([]testharness.PageStateEntry, 0, bmDirty.GetCardinality()+bmZero.GetCardinality())
+	emit := func(bm *roaring.Bitmap, state block.State) {
+		for _, idx := range bm.ToArray() {
+			entries = append(entries, testharness.PageStateEntry{
+				State:  uint8(state),
+				Offset: uint64(idx) * uint64(u.pageSize),
+			})
 		}
-		entries = append(entries, testharness.PageStateEntry{State: uint8(state), Offset: uint64(offset)})
 	}
+	emit(bmDirty, block.Dirty)
+	emit(bmZero, block.Zero)
 
 	return entries, nil
 }
diff --git a/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/userfaultfd.go b/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/userfaultfd.go
index 43b9fee492..ec432d51e5 100644
--- a/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/userfaultfd.go
+++ b/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/userfaultfd.go
@@ -22,6 +22,7 @@ import (
 	"github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/uffd/fdexit"
 	"github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/uffd/memory"
 	"github.com/e2b-dev/infra/packages/shared/pkg/logger"
+	"github.com/e2b-dev/infra/packages/shared/pkg/storage/header"
 )
 
 var tracer = otel.Tracer("github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/uffd/userfaultfd")
@@ -52,7 +53,7 @@ type Userfaultfd struct {
 	src         block.Slicer
 	ma          *memory.Mapping
 	pageSize    uintptr
-	pageTracker *pageTracker
+	pageTracker *block.Tracker
 
 	// We use the settleRequests to guard the pageTracker so we can access a consistent state of the pageTracker after the requests are finished.
 	settleRequests sync.RWMutex
@@ -92,7 +93,7 @@ func NewUserfaultfdFromFd(fd uintptr, src block.Slicer, m *memory.Mapping, logge
 		fd:              Fd(fd),
 		src:             src,
 		pageSize:        uintptr(blockSize),
-		pageTracker:     newPageTracker(uintptr(blockSize)),
+		pageTracker:     block.NewTracker(),
 		prefetchTracker: block.NewPrefetchTracker(blockSize),
 		ma:              m,
 		logger:          logger,
@@ -418,7 +419,8 @@ retryLoop:
 		return fmt.Errorf("failed uffdio copy: %w", joinedErr)
 	}
 
-	u.pageTracker.setState(addr, addr+u.pageSize, faulted)
+	idx := uint32(header.BlockIdx(offset, int64(u.pageSize)))
+	u.pageTracker.SetRange(idx, idx+1, block.Dirty)
 	u.prefetchTracker.Add(offset, accessType)
 
 	return nil