diff --git a/client.go b/client.go index be9813b..7542ceb 100644 --- a/client.go +++ b/client.go @@ -109,6 +109,20 @@ type StructValidator interface { StructCtx(ctx context.Context, s interface{}) error } +// SelfValidator lets a type drive its own validation. When a value passed to +// Insert, Upsert, or Update implements SelfValidator, the client calls +// ValidateWith instead of handing the value straight to the configured +// StructValidator. +// +// This is the seam for validation that struct tags cannot express on their own: +// cross-field rules (one field constrained by another), conditional rules, +// checks on computed or setter-derived values, and broader business rules. +// ValidateWith receives the configured StructValidator, so an implementation can +// still run ordinary tag-based validation and then layer custom logic on top. +type SelfValidator interface { + ValidateWith(ctx context.Context, v StructValidator) error +} + // clientOptions holds configuration options for the client. // // autoSchema: whether to automatically manage the schema. @@ -472,17 +486,33 @@ func (c client) validateStruct(ctx context.Context, obj any) error { } elem = elem.Elem() } - if err := c.options.validator.StructCtx(ctx, elem.Interface()); err != nil { + if err := c.validateOne(ctx, elem); err != nil { return err } } } else { - return c.options.validator.StructCtx(ctx, obj) + return c.validateOne(ctx, val) } return nil } +// validateOne validates a single struct value. If the value (or its address) +// implements SelfValidator, validation is delegated to ValidateWith so the type +// can apply custom rules — cross-field, conditional, computed-value, or other +// logic beyond struct tags. Otherwise the value is validated by the configured +// StructValidator as usual. +func (c client) validateOne(ctx context.Context, val reflect.Value) error { + iface := val.Interface() + if val.CanAddr() { + iface = val.Addr().Interface() + } + if sv, ok := iface.(SelfValidator); ok { + return sv.ValidateWith(ctx, c.options.validator) + } + return c.options.validator.StructCtx(ctx, iface) +} + // Insert implements inserting an object or slice of objects in the database. // Passed object must be a pointer to a struct with appropriate dgraph tags. func (c client) Insert(ctx context.Context, obj any) error { diff --git a/self_validator_test.go b/self_validator_test.go new file mode 100644 index 0000000..bf5b5b4 --- /dev/null +++ b/self_validator_test.go @@ -0,0 +1,91 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package modusgraph + +import ( + "context" + "errors" + "fmt" + "testing" +) + +// recordingValidator counts StructCtx calls so tests can assert which path ran. +type recordingValidator struct{ calls int } + +func (r *recordingValidator) StructCtx(_ context.Context, _ interface{}) error { + r.calls++ + return nil +} + +var errSelfValidated = errors.New("self-validated") + +type selfValidatingEntity struct{ Name string } + +func (s *selfValidatingEntity) ValidateWith(_ context.Context, _ StructValidator) error { + return errSelfValidated +} + +type plainEntity struct{ Name string } + +func TestValidateRoutesToSelfValidator(t *testing.T) { + rv := &recordingValidator{} + c := client{options: clientOptions{validator: rv}} + + err := c.validateStruct(context.Background(), &selfValidatingEntity{Name: "x"}) + if !errors.Is(err, errSelfValidated) { + t.Fatalf("expected the SelfValidator path, got %v", err) + } + if rv.calls != 0 { + t.Fatalf("StructCtx must not run for a SelfValidator, got %d calls", rv.calls) + } +} + +func TestValidateFallsBackToStructCtx(t *testing.T) { + rv := &recordingValidator{} + c := client{options: clientOptions{validator: rv}} + + if err := c.validateStruct(context.Background(), &plainEntity{Name: "x"}); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if rv.calls != 1 { + t.Fatalf("expected StructCtx to run once, got %d", rv.calls) + } +} + +func TestValidateSelfValidatorInSlice(t *testing.T) { + rv := &recordingValidator{} + c := client{options: clientOptions{validator: rv}} + + err := c.validateStruct(context.Background(), []*selfValidatingEntity{{Name: "a"}}) + if !errors.Is(err, errSelfValidated) { + t.Fatalf("expected the SelfValidator path for slice elements, got %v", err) + } +} + +// dateRange validates a relationship between two fields — a cross-field rule +// that struct tags alone cannot express. +type dateRange struct { + Start int + End int +} + +func (d *dateRange) ValidateWith(_ context.Context, _ StructValidator) error { + if d.End < d.Start { + return fmt.Errorf("End (%d) must be >= Start (%d)", d.End, d.Start) + } + return nil +} + +func TestSelfValidatorCustomCrossFieldRule(t *testing.T) { + c := client{options: clientOptions{validator: &recordingValidator{}}} + + if err := c.validateStruct(context.Background(), &dateRange{Start: 1, End: 5}); err != nil { + t.Fatalf("a valid range should pass the cross-field rule: %v", err) + } + if err := c.validateStruct(context.Background(), &dateRange{Start: 5, End: 1}); err == nil { + t.Fatal("End < Start must fail the custom cross-field rule") + } +}