Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 57 additions & 15 deletions table/sorting.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ const (

var (
ErrInvalidSortOrderID = errors.New("invalid sort order ID")
ErrInvalidSortSourceID = errors.New("invalid sort source ID")
ErrInvalidTransform = errors.New("invalid transform, must be a valid transform string or a transform object")
ErrInvalidSortDirection = errors.New("invalid sort direction, must be 'asc' or 'desc'")
ErrInvalidNullOrder = errors.New("invalid null order, must be 'nulls-first' or 'nulls-last'")
Expand Down Expand Up @@ -133,10 +134,13 @@ func (s *SortField) UnmarshalJSON(b []byte) error {
return fmt.Errorf("%w: failed to unmarshal sort field", err)
}

if _, ok := raw["source-id"]; ok {
if _, ok := raw["source-ids"]; ok {
return errors.New("sort field cannot contain both source-id and source-ids")
}
_, hasSourceID := raw["source-id"]
_, hasSourceIDs := raw["source-ids"]
if hasSourceID && hasSourceIDs {
return errors.New("sort field cannot contain both source-id and source-ids")
}
if !hasSourceID && !hasSourceIDs {
return fmt.Errorf("%w: sort field must contain source-id or source-ids", ErrInvalidSortSourceID)
}

aux := struct {
Expand All @@ -154,9 +158,15 @@ func (s *SortField) UnmarshalJSON(b []byte) error {
s.Direction = aux.Direction
s.NullOrder = aux.NullOrder

if len(aux.SourceIDs) > 0 {
if hasSourceIDs {
if err := validateSortSourceIDs(aux.SourceIDs); err != nil {
return err
}
s.SourceIDs = aux.SourceIDs
} else {
if err := validateSortSourceID(aux.SourceID); err != nil {
return err
}
s.SourceIDs = []int{aux.SourceID}
}

Expand All @@ -180,6 +190,28 @@ func (s *SortField) UnmarshalJSON(b []byte) error {
return nil
}

func validateSortSourceID(id int) error {
if id <= 0 {
return fmt.Errorf("%w: source ID must be positive: %d", ErrInvalidSortSourceID, id)
}

return nil
}

func validateSortSourceIDs(ids []int) error {
if len(ids) == 0 {
return fmt.Errorf("%w: source-ids must not be empty", ErrInvalidSortSourceID)
}

for _, id := range ids {
if err := validateSortSourceID(id); err != nil {
return err
}
}

return nil
}

const (
InitialSortOrderID = 1
UnsortedSortOrderID = 0
Expand Down Expand Up @@ -271,6 +303,9 @@ func NewSortOrder(orderID int, fields []SortField) (SortOrder, error) {
fields = []SortField{}
}
for idx, field := range fields {
if err := validateSortSourceIDs(field.SourceIDs); err != nil {

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This validates that every entry in SourceIDs is present and positive - good. The remaining gap is existence in the schema: SortOrder.CheckCompatibility only looks up field.SourceID(), i.e. the first source id (sorting.go:298), so a multi-arg sort field such as SourceIDs: [1, 999] still passes when 999 isn't in the schema. Worth iterating all of field.SourceIDs there (looking each up via FindFieldByID and checking it's primitive), with a regression case pairing a valid id and a nonexistent one.

return SortOrder{}, fmt.Errorf("%w: sort field at index %d", err, idx)
}
if field.Transform == nil {
return SortOrder{}, fmt.Errorf("%w: sort field at index %d has no transform", ErrInvalidTransform, idx)
}
Expand All @@ -295,21 +330,28 @@ func (s *SortOrder) CheckCompatibility(schema *iceberg.Schema) error {
}

for _, field := range s.fields {
f, ok := schema.FindFieldByID(field.SourceID())
if !ok {
return fmt.Errorf("sort field with source id %d not found in schema", field.SourceID())
if field.Transform == nil {
return fmt.Errorf("%w: sort field with source id %d has no transform", ErrInvalidTransform, field.SourceID())
}

if _, ok := f.Type.(iceberg.PrimitiveType); !ok {
return fmt.Errorf("cannot sort by non-primitive source field: %s", f.Type.Type())
}
var firstField iceberg.NestedField
for idx, sourceID := range field.SourceIDs {
f, ok := schema.FindFieldByID(sourceID)
if !ok {
return fmt.Errorf("sort field with source id %d not found in schema", sourceID)
}

if field.Transform == nil {
return fmt.Errorf("%w: sort field with source id %d has no transform", ErrInvalidTransform, field.SourceID())
if _, ok := f.Type.(iceberg.PrimitiveType); !ok {
return fmt.Errorf("cannot sort by non-primitive source field: %s", f.Type.Type())
}

if idx == 0 {
firstField = f
}
}

if !field.Transform.CanTransform(f.Type) {
return fmt.Errorf("invalid source type %s for transform %s", f.Type.Type(), field.Transform)
if !field.Transform.CanTransform(firstField.Type) {
return fmt.Errorf("invalid source type %s for transform %s", firstField.Type.Type(), field.Transform)
}
}

Expand Down
96 changes: 96 additions & 0 deletions table/sorting_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,45 @@ func TestNewSortOrderRejectsNilTransform(t *testing.T) {
assert.Contains(t, err.Error(), "has no transform")
}

func TestNewSortOrderRejectsInvalidSourceIDs(t *testing.T) {
for _, tt := range []struct {
name string
sourceIDs []int
}{
{
name: "missing",
sourceIDs: nil,
},
{
name: "empty",
sourceIDs: []int{},
},
{
name: "zero",
sourceIDs: []int{0},
},
{
name: "negative",
sourceIDs: []int{-1},
},
{
name: "multi arg with zero",
sourceIDs: []int{1, 0},
},
} {
t.Run(tt.name, func(t *testing.T) {
_, err := table.NewSortOrder(1, []table.SortField{{
SourceIDs: tt.sourceIDs,
Transform: iceberg.IdentityTransform{},
NullOrder: table.NullsFirst,
Direction: table.SortASC,
}})
require.Error(t, err)
assert.ErrorIs(t, err, table.ErrInvalidSortSourceID)
})
}
}

func TestNewSortOrderAcceptsValidTransform(t *testing.T) {
sortOrder, err := table.NewSortOrder(1, []table.SortField{{
SourceIDs: []int{19},
Expand All @@ -92,6 +131,23 @@ func TestSortOrderCheckCompatibilityWithValidTransform(t *testing.T) {
require.NoError(t, sortOrder.CheckCompatibility(schema))
}

func TestSortOrderCheckCompatibilityRejectsInvalidMultiArgSourceID(t *testing.T) {
schema := iceberg.NewSchema(0,
iceberg.NestedField{ID: 19, Name: "id", Type: iceberg.PrimitiveTypes.Int64, Required: true},
)
sortOrder, err := table.NewSortOrder(1, []table.SortField{{
SourceIDs: []int{19, 999},
Transform: iceberg.IdentityTransform{},
NullOrder: table.NullsFirst,
Direction: table.SortASC,
}})
require.NoError(t, err)

err = sortOrder.CheckCompatibility(schema)
require.Error(t, err)
assert.ErrorContains(t, err, "sort field with source id 999 not found in schema")
}

func TestUnmarshalSortOrderDefaults(t *testing.T) {
var order table.SortOrder
require.NoError(t, json.Unmarshal([]byte(`{"fields": []}`), &order))
Expand Down Expand Up @@ -169,6 +225,46 @@ func TestSortFieldMultiArgSourceIDs(t *testing.T) {
assert.Contains(t, err.Error(), "cannot contain both source-id and source-ids")
})

t.Run("unmarshal rejects missing source id", func(t *testing.T) {
jsonData := `{"transform": "identity", "direction": "asc", "null-order": "nulls-first"}`
var field table.SortField
err := json.Unmarshal([]byte(jsonData), &field)
require.Error(t, err)
assert.ErrorIs(t, err, table.ErrInvalidSortSourceID)
})

t.Run("unmarshal rejects zero source-id", func(t *testing.T) {
jsonData := `{"source-id": 0, "transform": "identity", "direction": "asc", "null-order": "nulls-first"}`
var field table.SortField
err := json.Unmarshal([]byte(jsonData), &field)
require.Error(t, err)
assert.ErrorIs(t, err, table.ErrInvalidSortSourceID)
})

t.Run("unmarshal rejects negative source-id", func(t *testing.T) {
jsonData := `{"source-id": -1, "transform": "identity", "direction": "asc", "null-order": "nulls-first"}`
var field table.SortField
err := json.Unmarshal([]byte(jsonData), &field)
require.Error(t, err)
assert.ErrorIs(t, err, table.ErrInvalidSortSourceID)
})

t.Run("unmarshal rejects empty source-ids", func(t *testing.T) {
jsonData := `{"source-ids": [], "transform": "identity", "direction": "asc", "null-order": "nulls-first"}`
var field table.SortField
err := json.Unmarshal([]byte(jsonData), &field)
require.Error(t, err)
assert.ErrorIs(t, err, table.ErrInvalidSortSourceID)
})

t.Run("unmarshal rejects non-positive source-ids member", func(t *testing.T) {
jsonData := `{"source-ids": [1, 0], "transform": "identity", "direction": "asc", "null-order": "nulls-first"}`
var field table.SortField
err := json.Unmarshal([]byte(jsonData), &field)
require.Error(t, err)
assert.ErrorIs(t, err, table.ErrInvalidSortSourceID)
})

t.Run("marshal multi-arg round-trip", func(t *testing.T) {
field := table.SortField{
SourceIDs: []int{2, 3},
Expand Down
Loading