From 67c52fc3a6e831ca062e722601c0785474bc3728 Mon Sep 17 00:00:00 2001 From: Minh Vu Date: Wed, 24 Jun 2026 23:28:02 +0200 Subject: [PATCH 1/2] validate sort field source ids --- table/sorting.go | 45 +++++++++++++++++++++--- table/sorting_test.go | 79 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 119 insertions(+), 5 deletions(-) diff --git a/table/sorting.go b/table/sorting.go index 09c198f10..cb4c82fd9 100644 --- a/table/sorting.go +++ b/table/sorting.go @@ -44,6 +44,7 @@ const ( var ( ErrInvalidSortOrderID = errors.New("invalid sort order ID") + ErrInvalidSortSourceID = errors.New("invalid sort source ID") ErrInvalidTransform = errors.New("invalid transform, must be a valid transform string or a transform object") ErrInvalidSortDirection = errors.New("invalid sort direction, must be 'asc' or 'desc'") ErrInvalidNullOrder = errors.New("invalid null order, must be 'nulls-first' or 'nulls-last'") @@ -133,10 +134,13 @@ func (s *SortField) UnmarshalJSON(b []byte) error { return fmt.Errorf("%w: failed to unmarshal sort field", err) } - if _, ok := raw["source-id"]; ok { - if _, ok := raw["source-ids"]; ok { - return errors.New("sort field cannot contain both source-id and source-ids") - } + _, hasSourceID := raw["source-id"] + _, hasSourceIDs := raw["source-ids"] + if hasSourceID && hasSourceIDs { + return errors.New("sort field cannot contain both source-id and source-ids") + } + if !hasSourceID && !hasSourceIDs { + return fmt.Errorf("%w: sort field must contain source-id or source-ids", ErrInvalidSortSourceID) } aux := struct { @@ -154,9 +158,15 @@ func (s *SortField) UnmarshalJSON(b []byte) error { s.Direction = aux.Direction s.NullOrder = aux.NullOrder - if len(aux.SourceIDs) > 0 { + if hasSourceIDs { + if err := validateSortSourceIDs(aux.SourceIDs); err != nil { + return err + } s.SourceIDs = aux.SourceIDs } else { + if err := validateSortSourceID(aux.SourceID); err != nil { + return err + } s.SourceIDs = []int{aux.SourceID} } @@ -180,6 +190,28 @@ func (s *SortField) UnmarshalJSON(b []byte) error { return nil } +func validateSortSourceID(id int) error { + if id <= 0 { + return fmt.Errorf("%w: source ID must be positive: %d", ErrInvalidSortSourceID, id) + } + + return nil +} + +func validateSortSourceIDs(ids []int) error { + if len(ids) == 0 { + return fmt.Errorf("%w: source-ids must not be empty", ErrInvalidSortSourceID) + } + + for _, id := range ids { + if err := validateSortSourceID(id); err != nil { + return err + } + } + + return nil +} + const ( InitialSortOrderID = 1 UnsortedSortOrderID = 0 @@ -271,6 +303,9 @@ func NewSortOrder(orderID int, fields []SortField) (SortOrder, error) { fields = []SortField{} } for idx, field := range fields { + if err := validateSortSourceIDs(field.SourceIDs); err != nil { + return SortOrder{}, fmt.Errorf("%w: sort field at index %d", err, idx) + } if field.Transform == nil { return SortOrder{}, fmt.Errorf("%w: sort field at index %d has no transform", ErrInvalidTransform, idx) } diff --git a/table/sorting_test.go b/table/sorting_test.go index baeae0936..202dbd1c9 100644 --- a/table/sorting_test.go +++ b/table/sorting_test.go @@ -66,6 +66,45 @@ func TestNewSortOrderRejectsNilTransform(t *testing.T) { assert.Contains(t, err.Error(), "has no transform") } +func TestNewSortOrderRejectsInvalidSourceIDs(t *testing.T) { + for _, tt := range []struct { + name string + sourceIDs []int + }{ + { + name: "missing", + sourceIDs: nil, + }, + { + name: "empty", + sourceIDs: []int{}, + }, + { + name: "zero", + sourceIDs: []int{0}, + }, + { + name: "negative", + sourceIDs: []int{-1}, + }, + { + name: "multi arg with zero", + sourceIDs: []int{1, 0}, + }, + } { + t.Run(tt.name, func(t *testing.T) { + _, err := table.NewSortOrder(1, []table.SortField{{ + SourceIDs: tt.sourceIDs, + Transform: iceberg.IdentityTransform{}, + NullOrder: table.NullsFirst, + Direction: table.SortASC, + }}) + require.Error(t, err) + assert.ErrorIs(t, err, table.ErrInvalidSortSourceID) + }) + } +} + func TestNewSortOrderAcceptsValidTransform(t *testing.T) { sortOrder, err := table.NewSortOrder(1, []table.SortField{{ SourceIDs: []int{19}, @@ -169,6 +208,46 @@ func TestSortFieldMultiArgSourceIDs(t *testing.T) { assert.Contains(t, err.Error(), "cannot contain both source-id and source-ids") }) + t.Run("unmarshal rejects missing source id", func(t *testing.T) { + jsonData := `{"transform": "identity", "direction": "asc", "null-order": "nulls-first"}` + var field table.SortField + err := json.Unmarshal([]byte(jsonData), &field) + require.Error(t, err) + assert.ErrorIs(t, err, table.ErrInvalidSortSourceID) + }) + + t.Run("unmarshal rejects zero source-id", func(t *testing.T) { + jsonData := `{"source-id": 0, "transform": "identity", "direction": "asc", "null-order": "nulls-first"}` + var field table.SortField + err := json.Unmarshal([]byte(jsonData), &field) + require.Error(t, err) + assert.ErrorIs(t, err, table.ErrInvalidSortSourceID) + }) + + t.Run("unmarshal rejects negative source-id", func(t *testing.T) { + jsonData := `{"source-id": -1, "transform": "identity", "direction": "asc", "null-order": "nulls-first"}` + var field table.SortField + err := json.Unmarshal([]byte(jsonData), &field) + require.Error(t, err) + assert.ErrorIs(t, err, table.ErrInvalidSortSourceID) + }) + + t.Run("unmarshal rejects empty source-ids", func(t *testing.T) { + jsonData := `{"source-ids": [], "transform": "identity", "direction": "asc", "null-order": "nulls-first"}` + var field table.SortField + err := json.Unmarshal([]byte(jsonData), &field) + require.Error(t, err) + assert.ErrorIs(t, err, table.ErrInvalidSortSourceID) + }) + + t.Run("unmarshal rejects non-positive source-ids member", func(t *testing.T) { + jsonData := `{"source-ids": [1, 0], "transform": "identity", "direction": "asc", "null-order": "nulls-first"}` + var field table.SortField + err := json.Unmarshal([]byte(jsonData), &field) + require.Error(t, err) + assert.ErrorIs(t, err, table.ErrInvalidSortSourceID) + }) + t.Run("marshal multi-arg round-trip", func(t *testing.T) { field := table.SortField{ SourceIDs: []int{2, 3}, From bb945dc270be95667070dbf305a28d46ceb7334a Mon Sep 17 00:00:00 2001 From: Minh Vu Date: Fri, 26 Jun 2026 23:49:27 +0200 Subject: [PATCH 2/2] fix(table): validate multi-arg sort source ids --- table/sorting.go | 27 +++++++++++++++++---------- table/sorting_test.go | 17 +++++++++++++++++ 2 files changed, 34 insertions(+), 10 deletions(-) diff --git a/table/sorting.go b/table/sorting.go index cb4c82fd9..6cebc18a1 100644 --- a/table/sorting.go +++ b/table/sorting.go @@ -330,21 +330,28 @@ func (s *SortOrder) CheckCompatibility(schema *iceberg.Schema) error { } for _, field := range s.fields { - f, ok := schema.FindFieldByID(field.SourceID()) - if !ok { - return fmt.Errorf("sort field with source id %d not found in schema", field.SourceID()) + if field.Transform == nil { + return fmt.Errorf("%w: sort field with source id %d has no transform", ErrInvalidTransform, field.SourceID()) } - if _, ok := f.Type.(iceberg.PrimitiveType); !ok { - return fmt.Errorf("cannot sort by non-primitive source field: %s", f.Type.Type()) - } + var firstField iceberg.NestedField + for idx, sourceID := range field.SourceIDs { + f, ok := schema.FindFieldByID(sourceID) + if !ok { + return fmt.Errorf("sort field with source id %d not found in schema", sourceID) + } - if field.Transform == nil { - return fmt.Errorf("%w: sort field with source id %d has no transform", ErrInvalidTransform, field.SourceID()) + if _, ok := f.Type.(iceberg.PrimitiveType); !ok { + return fmt.Errorf("cannot sort by non-primitive source field: %s", f.Type.Type()) + } + + if idx == 0 { + firstField = f + } } - if !field.Transform.CanTransform(f.Type) { - return fmt.Errorf("invalid source type %s for transform %s", f.Type.Type(), field.Transform) + if !field.Transform.CanTransform(firstField.Type) { + return fmt.Errorf("invalid source type %s for transform %s", firstField.Type.Type(), field.Transform) } } diff --git a/table/sorting_test.go b/table/sorting_test.go index 202dbd1c9..bfc05c59d 100644 --- a/table/sorting_test.go +++ b/table/sorting_test.go @@ -131,6 +131,23 @@ func TestSortOrderCheckCompatibilityWithValidTransform(t *testing.T) { require.NoError(t, sortOrder.CheckCompatibility(schema)) } +func TestSortOrderCheckCompatibilityRejectsInvalidMultiArgSourceID(t *testing.T) { + schema := iceberg.NewSchema(0, + iceberg.NestedField{ID: 19, Name: "id", Type: iceberg.PrimitiveTypes.Int64, Required: true}, + ) + sortOrder, err := table.NewSortOrder(1, []table.SortField{{ + SourceIDs: []int{19, 999}, + Transform: iceberg.IdentityTransform{}, + NullOrder: table.NullsFirst, + Direction: table.SortASC, + }}) + require.NoError(t, err) + + err = sortOrder.CheckCompatibility(schema) + require.Error(t, err) + assert.ErrorContains(t, err, "sort field with source id 999 not found in schema") +} + func TestUnmarshalSortOrderDefaults(t *testing.T) { var order table.SortOrder require.NoError(t, json.Unmarshal([]byte(`{"fields": []}`), &order))