Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions packages/orchestrator/pkg/sandbox/block/tracker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package block

import (
"sync"

"github.com/RoaringBitmap/roaring/v2"
)

type State uint8

const (
// NotPresent: fall back to the previous layer.
NotPresent State = iota
// Dirty: this layer holds materialized data.
Dirty
// Zero: known-zero; no need to consult the previous layer.
Zero
)

type Tracker struct {
mu sync.RWMutex
dirty, zero *roaring.Bitmap
}

func NewTracker() *Tracker {
return &Tracker{
dirty: roaring.New(),
zero: roaring.New(),
}
}

// SetRange sets state for indices in [start, end). The index math.MaxUint32
// is unaddressable: end is the half-open upper bound and capped at MaxUint32.
func (t *Tracker) SetRange(start, end uint32, state State) {
if end <= start {
return
}

t.mu.Lock()
defer t.mu.Unlock()

s, e := uint64(start), uint64(end)
switch state {
case Dirty:
t.dirty.AddRange(s, e)
t.zero.RemoveRange(s, e)
case Zero:
t.zero.AddRange(s, e)
t.dirty.RemoveRange(s, e)
case NotPresent:
t.dirty.RemoveRange(s, e)
t.zero.RemoveRange(s, e)
}
}

func (t *Tracker) Get(idx uint32) State {
t.mu.RLock()
defer t.mu.RUnlock()

switch {
case t.dirty.Contains(idx):
return Dirty
case t.zero.Contains(idx):
return Zero
default:
return NotPresent
}
}

// Export returns clones of the dirty and zero bitmaps.
func (t *Tracker) Export() (dirty, zero *roaring.Bitmap) {
t.mu.RLock()
defer t.mu.RUnlock()

return t.dirty.Clone(), t.zero.Clone()
}
79 changes: 79 additions & 0 deletions packages/orchestrator/pkg/sandbox/block/tracker_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package block

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestTracker(t *testing.T) {
t.Parallel()

t.Run("transitions", func(t *testing.T) {
t.Parallel()
s := NewTracker()

s.SetRange(0, 1, Dirty)
assert.Equal(t, Dirty, s.Get(0))

s.SetRange(0, 1, Zero)
assert.Equal(t, Zero, s.Get(0), "dirty→zero should flip the page")
bmDirty, bmZero := s.Export()
assert.False(t, bmDirty.Contains(0), "dirty→zero must clear dirty bitmap")
assert.True(t, bmZero.Contains(0), "dirty→zero must add to zero bitmap")

s.SetRange(0, 1, Dirty)
assert.Equal(t, Dirty, s.Get(0), "zero→dirty should flip back")

s.SetRange(0, 1, NotPresent)
assert.Equal(t, NotPresent, s.Get(0), "→not-present must clear")
bmDirty, bmZero = s.Export()
assert.False(t, bmDirty.Contains(0))
assert.False(t, bmZero.Contains(0))

s.SetRange(0, 1, Dirty)
s.SetRange(0, 1, Dirty)
assert.Equal(t, Dirty, s.Get(0), "dirty→dirty is idempotent")
})

t.Run("partial overlap moves only the overlapping pages", func(t *testing.T) {
t.Parallel()
s := NewTracker()

s.SetRange(0, 10, Dirty)
s.SetRange(3, 7, Zero)

for i := range uint32(3) {
assert.Equal(t, Dirty, s.Get(i), "page %d outside overlap stays dirty", i)
}
for i := range uint32(4) {
page := i + 3
assert.Equal(t, Zero, s.Get(page), "page %d in overlap moves to zero", page)
}
for i := range uint32(3) {
page := i + 7
assert.Equal(t, Dirty, s.Get(page), "page %d outside overlap stays dirty", page)
}
})

t.Run("empty and inverted ranges are no-ops", func(t *testing.T) {
t.Parallel()
s := NewTracker()

s.SetRange(5, 5, Dirty)
s.SetRange(7, 3, Zero)
bmDirty, bmZero := s.Export()
assert.True(t, bmDirty.IsEmpty())
assert.True(t, bmZero.IsEmpty())
})

t.Run("Export returns clones", func(t *testing.T) {
t.Parallel()
s := NewTracker()

s.SetRange(0, 1, Dirty)
bmDirty, _ := s.Export()
bmDirty.Add(42)
assert.Equal(t, NotPresent, s.Get(42), "mutating export must not leak into tracker")
})
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ type BootstrapArgs struct {

type BootstrapReply struct{}

// PageStateEntry is the wire form of the parent package's pageState enum.
// PageStateEntry is the wire form of a block.State for a single page offset.
type PageStateEntry struct {
State uint8
Offset uint64
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/block"
"github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/uffd/testutils"
"github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/uffd/testutils/testharness"
)
Expand Down Expand Up @@ -76,7 +77,7 @@ func (h *testHandler) pageStates() (handlerPageStates, error) {

var states handlerPageStates
for _, e := range entries {
if pageState(e.State) == faulted {
if block.State(e.State) == block.Dirty {
states.faulted = append(states.faulted, uint(e.Offset))
}
}
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ import (
"os"
"sync"

"github.com/RoaringBitmap/roaring/v2"

"github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/block"
"github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/uffd/fdexit"
"github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/uffd/memory"
"github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/uffd/testutils/testharness"
Expand Down Expand Up @@ -186,24 +189,24 @@ func (p *Paging) Resume(_ *testharness.Empty, _ *testharness.Empty) error {
}

// pageStateEntries returns a wire-format snapshot of pageTracker.
// settleRequests.Lock drains fault workers (mirrors PrefetchData);
// pageTracker.mu.RLock is defensive against a future REMOVE writer
// that mutates pageTracker.m outside settleRequests.
// settleRequests.Lock drains fault workers (mirrors PrefetchData) so
// the snapshot is consistent w.r.t. concurrent installs.
func (u *Userfaultfd) pageStateEntries() ([]testharness.PageStateEntry, error) {
u.settleRequests.Lock()
defer u.settleRequests.Unlock()

u.pageTracker.mu.RLock()
defer u.pageTracker.mu.RUnlock()

entries := make([]testharness.PageStateEntry, 0, len(u.pageTracker.m))
for addr, state := range u.pageTracker.m {
offset, err := u.ma.GetOffset(addr)
if err != nil {
return nil, fmt.Errorf("address %#x not in mapping: %w", addr, err)
bmDirty, bmZero := u.pageTracker.Export()
entries := make([]testharness.PageStateEntry, 0, bmDirty.GetCardinality()+bmZero.GetCardinality())
emit := func(bm *roaring.Bitmap, state block.State) {
for _, idx := range bm.ToArray() {
entries = append(entries, testharness.PageStateEntry{
State: uint8(state),
Offset: uint64(idx) * uint64(u.pageSize),
})
}
entries = append(entries, testharness.PageStateEntry{State: uint8(state), Offset: uint64(offset)})
}
emit(bmDirty, block.Dirty)
emit(bmZero, block.Zero)

return entries, nil
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/uffd/fdexit"
"github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/uffd/memory"
"github.com/e2b-dev/infra/packages/shared/pkg/logger"
"github.com/e2b-dev/infra/packages/shared/pkg/storage/header"
)

var tracer = otel.Tracer("github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/uffd/userfaultfd")
Expand Down Expand Up @@ -52,7 +53,7 @@ type Userfaultfd struct {
src block.Slicer
ma *memory.Mapping
pageSize uintptr
pageTracker *pageTracker
pageTracker *block.Tracker

// We use the settleRequests to guard the pageTracker so we can access a consistent state of the pageTracker after the requests are finished.
settleRequests sync.RWMutex
Expand Down Expand Up @@ -92,7 +93,7 @@ func NewUserfaultfdFromFd(fd uintptr, src block.Slicer, m *memory.Mapping, logge
fd: Fd(fd),
src: src,
pageSize: uintptr(blockSize),
pageTracker: newPageTracker(uintptr(blockSize)),
pageTracker: block.NewTracker(),
prefetchTracker: block.NewPrefetchTracker(blockSize),
ma: m,
logger: logger,
Expand Down Expand Up @@ -418,7 +419,8 @@ retryLoop:
return fmt.Errorf("failed uffdio copy: %w", joinedErr)
}

u.pageTracker.setState(addr, addr+u.pageSize, faulted)
idx := uint32(header.BlockIdx(offset, int64(u.pageSize)))
u.pageTracker.SetRange(idx, idx+1, block.Dirty)
u.prefetchTracker.Add(offset, accessType)

return nil
Expand Down
Loading