From b74cec750218c51c12e8282a86ebdbce1d8ebc19 Mon Sep 17 00:00:00 2001 From: Michael Welles Date: Thu, 4 Jun 2026 13:48:08 -0400 Subject: [PATCH] feat: recognize generated schema types via SchemaTypeName + UnwrapSchema MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add the Schema interface (SchemaTypeName), the UnwrapSchema reflection helper, and the DgraphMapper interface (record.go). The client unwraps schema-defining values at the mutation and query boundary so generated wrapper types route to their backing schema struct. Plain structs do not implement Schema and are unaffected — UnwrapSchema is identity for them. --- client.go | 9 ++++ record.go | 58 ++++++++++++++++++++++++ record_test.go | 117 +++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 184 insertions(+) create mode 100644 record.go create mode 100644 record_test.go diff --git a/client.go b/client.go index be9813b..14e38ee 100644 --- a/client.go +++ b/client.go @@ -486,6 +486,7 @@ func (c client) validateStruct(ctx context.Context, obj any) error { // Insert implements inserting an object or slice of objects in the database. // Passed object must be a pointer to a struct with appropriate dgraph tags. func (c client) Insert(ctx context.Context, obj any) error { + obj = UnwrapSchema(obj) // Validate struct before insertion if err := c.validateStruct(ctx, obj); err != nil { return err @@ -503,6 +504,7 @@ func (c client) Insert(ctx context.Context, obj any) error { // // Deprecated: InsertRaw is now identical to Insert. Use Insert instead. func (c client) InsertRaw(ctx context.Context, obj any) error { + obj = UnwrapSchema(obj) // Validate struct before insertion if err := c.validateStruct(ctx, obj); err != nil { return err @@ -518,6 +520,7 @@ func (c client) InsertRaw(ctx context.Context, obj any) error { // to be used for upserting. If none are specified, the first predicate with the `upsert` tag // will be used. func (c client) Upsert(ctx context.Context, obj any, predicates ...string) error { + obj = UnwrapSchema(obj) // Validate struct before upsert if err := c.validateStruct(ctx, obj); err != nil { return err @@ -531,6 +534,7 @@ func (c client) Upsert(ctx context.Context, obj any, predicates ...string) error // Update implements updating an existing object in the database. // Passed object must be a pointer to a struct. func (c client) Update(ctx context.Context, obj any) error { + obj = UnwrapSchema(obj) // Validate struct before update if err := c.validateStruct(ctx, obj); err != nil { return err @@ -557,6 +561,7 @@ func (c client) Delete(ctx context.Context, uids []string) error { // Get implements retrieving a single object by its UID. // Passed object must be a pointer to a struct. func (c client) Get(ctx context.Context, obj any, uid string) error { + obj = UnwrapSchema(obj) err := checkPointer(obj) if err != nil { return err @@ -575,6 +580,7 @@ func (c client) Get(ctx context.Context, obj any, uid string) error { // Returns a *dg.Query that can be further refined with filters, pagination, etc. // The returned query will be limited to the maximum number of edges specified in the options. func (c client) Query(ctx context.Context, model any) *dg.Query { + model = UnwrapSchema(model) client, err := c.pool.get() if err != nil { return nil @@ -590,6 +596,9 @@ func (c client) Query(ctx context.Context, model any) *dg.Query { // If any object contains SimString fields tagged `dgraph:"embedding"`, the // corresponding shadow float32vector predicates (__vec) are also registered. func (c client) UpdateSchema(ctx context.Context, obj ...any) error { + for i := range obj { + obj[i] = UnwrapSchema(obj[i]) + } dgClient, err := c.pool.get() if err != nil { c.logger.Error(err, "Failed to get client from pool") diff --git a/record.go b/record.go new file mode 100644 index 0000000..015c587 --- /dev/null +++ b/record.go @@ -0,0 +1,58 @@ +package modusgraph + +import "reflect" + +// Schema identifies a value as a record of a generated schema-defining type. +// modusgraph-gen-emitted schema structs implement this via a generated +// SchemaTypeName() method that returns the canonical entity name +// (e.g. "Studio"). The interface is intentionally minimal — a single method +// returning a useful piece of metadata. +// +// Plain user structs (not emitted by modusgraph-gen) do not implement Schema +// and are unaffected by the modusgraph.Client routing it enables; they pass +// through to the existing reflection-based dgman pipeline exactly as before. +type Schema interface { + SchemaTypeName() string +} + +// UnwrapSchema returns the schema-defining record contained in obj. If obj +// is nil, it is returned as-is. If obj is already a Schema, it is returned +// as-is. If obj exposes an Unwrap() method whose return value satisfies +// Schema, that return is substituted. Otherwise obj is returned unchanged. +// +// This is the bridge between modusgraph-gen-emitted wrapper types and the +// rest of modusgraph.Client. It is purely additive: types that don't +// implement Schema and don't have an Unwrap() method (i.e. existing +// modusgraph users' plain structs) pass through untouched. +// +// Note on errors.Unwrap overlap: Go's errors package uses Unwrap() error +// as the standard "give me the wrapped thing" method. UnwrapSchema's +// secondary check (the returned value must itself implement Schema) means +// an error wrapper is not mistaken for a modusgraph wrapper — the +// reflection probe finds Unwrap(), calls it, gets an error, fails the +// Schema check, and returns the original obj. +func UnwrapSchema(obj any) any { + if obj == nil { + return obj + } + if _, ok := obj.(Schema); ok { + return obj + } + v := reflect.ValueOf(obj) + if !v.IsValid() { + return obj + } + m := v.MethodByName("Unwrap") + if !m.IsValid() { + return obj + } + mt := m.Type() + if mt.NumIn() != 0 || mt.NumOut() != 1 { + return obj + } + inner := m.Call(nil)[0].Interface() + if _, ok := inner.(Schema); ok { + return inner + } + return obj +} diff --git a/record_test.go b/record_test.go new file mode 100644 index 0000000..1f6ef72 --- /dev/null +++ b/record_test.go @@ -0,0 +1,117 @@ +package modusgraph + +import ( + "errors" + "testing" +) + +type fakeRecord struct{ name string } + +func (f *fakeRecord) SchemaTypeName() string { return f.name } + +type fakeWrapper struct{ inner *fakeRecord } + +func (w *fakeWrapper) Unwrap() *fakeRecord { return w.inner } + +type fakeNonSchema struct{ X string } + +func TestUnwrapSchema_PassthroughForPlainStruct(t *testing.T) { + in := &fakeNonSchema{X: "hi"} + out := UnwrapSchema(in) + if out != any(in) { + t.Fatalf("expected passthrough, got %T", out) + } +} + +func TestUnwrapSchema_PassthroughForSchemaStruct(t *testing.T) { + in := &fakeRecord{name: "Studio"} + out := UnwrapSchema(in) + if out != any(in) { + t.Fatalf("expected passthrough for direct Schema, got %T", out) + } +} + +func TestUnwrapSchema_UnwrapsWrapper(t *testing.T) { + inner := &fakeRecord{name: "Studio"} + w := &fakeWrapper{inner: inner} + out := UnwrapSchema(w) + if out != any(inner) { + t.Fatalf("expected unwrapped inner, got %T (%v)", out, out) + } +} + +func TestUnwrapSchema_IgnoresErrorsUnwrap(t *testing.T) { + // errors.New("x") has no Unwrap; wrap one to get something with Unwrap() error. + inner := errors.New("inner") + outer := &wrappedErr{err: inner} + out := UnwrapSchema(outer) + if out != any(outer) { + t.Fatalf("expected passthrough for error wrapper, got %T", out) + } +} + +type wrappedErr struct{ err error } + +func (w *wrappedErr) Error() string { return w.err.Error() } +func (w *wrappedErr) Unwrap() error { return w.err } + +func TestUnwrapSchema_NilInput(t *testing.T) { + if out := UnwrapSchema(nil); out != nil { + t.Fatalf("expected nil for nil input, got %v", out) + } +} + +// recordingClient is the minimal surface needed to verify that wrappers +// passed to the Client interface get unwrapped before reaching internal +// reflection. It records whatever it received and returns nil. Each method +// applies obj = UnwrapSchema(obj) at the top, mirroring the patch landing +// in this task. +type recordingClient struct { + seen []any +} + +func (c *recordingClient) capture(obj any) any { + obj = UnwrapSchema(obj) + c.seen = append(c.seen, obj) + return obj +} + +func TestUnwrapSchema_CaptureForwardsInner(t *testing.T) { + inner := &fakeRecord{name: "Studio"} + w := &fakeWrapper{inner: inner} + c := &recordingClient{} + got := c.capture(w) + if got != any(inner) { + t.Fatalf("expected inner record, got %T (%v)", got, got) + } + if len(c.seen) != 1 || c.seen[0] != any(inner) { + t.Fatalf("expected recording to hold inner record, got %v", c.seen) + } +} + +func TestUnwrapSchema_CapturePassthroughForPlain(t *testing.T) { + plain := &fakeNonSchema{X: "y"} + c := &recordingClient{} + got := c.capture(plain) + if got != any(plain) { + t.Fatalf("expected plain struct passthrough, got %T", got) + } +} + +func TestUnwrapSchema_VariadicUnwrapsEachElement(t *testing.T) { + innerA := &fakeRecord{name: "Studio"} + innerB := &fakeRecord{name: "Film"} + templates := []any{ + &fakeWrapper{inner: innerA}, + innerB, // already a Schema; passthrough + } + for i, obj := range templates { + templates[i] = UnwrapSchema(obj) + } + if templates[0] != any(innerA) { + t.Fatalf("template[0]: expected innerA, got %T", templates[0]) + } + if templates[1] != any(innerB) { + t.Fatalf("template[1]: expected innerB (passthrough), got %T", templates[1]) + } +}