Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,7 @@ func (c client) validateStruct(ctx context.Context, obj any) error {
// Insert implements inserting an object or slice of objects in the database.
// Passed object must be a pointer to a struct with appropriate dgraph tags.
func (c client) Insert(ctx context.Context, obj any) error {
obj = UnwrapSchema(obj)
// Validate struct before insertion
if err := c.validateStruct(ctx, obj); err != nil {
return err
Expand All @@ -503,6 +504,7 @@ func (c client) Insert(ctx context.Context, obj any) error {
//
// Deprecated: InsertRaw is now identical to Insert. Use Insert instead.
func (c client) InsertRaw(ctx context.Context, obj any) error {
obj = UnwrapSchema(obj)
// Validate struct before insertion
if err := c.validateStruct(ctx, obj); err != nil {
return err
Expand All @@ -518,6 +520,7 @@ func (c client) InsertRaw(ctx context.Context, obj any) error {
// to be used for upserting. If none are specified, the first predicate with the `upsert` tag
// will be used.
func (c client) Upsert(ctx context.Context, obj any, predicates ...string) error {
obj = UnwrapSchema(obj)
// Validate struct before upsert
if err := c.validateStruct(ctx, obj); err != nil {
return err
Expand All @@ -531,6 +534,7 @@ func (c client) Upsert(ctx context.Context, obj any, predicates ...string) error
// Update implements updating an existing object in the database.
// Passed object must be a pointer to a struct.
func (c client) Update(ctx context.Context, obj any) error {
obj = UnwrapSchema(obj)
// Validate struct before update
if err := c.validateStruct(ctx, obj); err != nil {
return err
Expand All @@ -557,6 +561,7 @@ func (c client) Delete(ctx context.Context, uids []string) error {
// Get implements retrieving a single object by its UID.
// Passed object must be a pointer to a struct.
func (c client) Get(ctx context.Context, obj any, uid string) error {
obj = UnwrapSchema(obj)
err := checkPointer(obj)
if err != nil {
return err
Expand All @@ -575,6 +580,7 @@ func (c client) Get(ctx context.Context, obj any, uid string) error {
// Returns a *dg.Query that can be further refined with filters, pagination, etc.
// The returned query will be limited to the maximum number of edges specified in the options.
func (c client) Query(ctx context.Context, model any) *dg.Query {
model = UnwrapSchema(model)
client, err := c.pool.get()
if err != nil {
return nil
Expand All @@ -590,6 +596,9 @@ func (c client) Query(ctx context.Context, model any) *dg.Query {
// If any object contains SimString fields tagged `dgraph:"embedding"`, the
// corresponding shadow float32vector predicates (<field>__vec) are also registered.
func (c client) UpdateSchema(ctx context.Context, obj ...any) error {
for i := range obj {
obj[i] = UnwrapSchema(obj[i])
}
dgClient, err := c.pool.get()
if err != nil {
c.logger.Error(err, "Failed to get client from pool")
Expand Down
58 changes: 58 additions & 0 deletions record.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package modusgraph

import "reflect"

// Schema identifies a value as a record of a generated schema-defining type.
// modusgraph-gen-emitted schema structs implement this via a generated
// SchemaTypeName() method that returns the canonical entity name
// (e.g. "Studio"). The interface is intentionally minimal — a single method
// returning a useful piece of metadata.
//
// Plain user structs (not emitted by modusgraph-gen) do not implement Schema
// and are unaffected by the modusgraph.Client routing it enables; they pass
// through to the existing reflection-based dgman pipeline exactly as before.
type Schema interface {
SchemaTypeName() string
}

// UnwrapSchema returns the schema-defining record contained in obj. If obj
// is nil, it is returned as-is. If obj is already a Schema, it is returned
// as-is. If obj exposes an Unwrap() method whose return value satisfies
// Schema, that return is substituted. Otherwise obj is returned unchanged.
//
// This is the bridge between modusgraph-gen-emitted wrapper types and the
// rest of modusgraph.Client. It is purely additive: types that don't
// implement Schema and don't have an Unwrap() method (i.e. existing
// modusgraph users' plain structs) pass through untouched.
//
// Note on errors.Unwrap overlap: Go's errors package uses Unwrap() error
// as the standard "give me the wrapped thing" method. UnwrapSchema's
// secondary check (the returned value must itself implement Schema) means
// an error wrapper is not mistaken for a modusgraph wrapper — the
// reflection probe finds Unwrap(), calls it, gets an error, fails the
// Schema check, and returns the original obj.
func UnwrapSchema(obj any) any {
if obj == nil {
return obj
}
if _, ok := obj.(Schema); ok {
return obj
}
v := reflect.ValueOf(obj)
if !v.IsValid() {
return obj
}
m := v.MethodByName("Unwrap")

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2: UnwrapSchema misses pointer-receiver Unwrap() methods when wrapper is passed by value

Prompt for AI agents
Check if this issue is valid — if so, understand the root cause and fix it. At record.go, line 45:

<comment>UnwrapSchema misses pointer-receiver Unwrap() methods when wrapper is passed by value</comment>

<file context>
@@ -0,0 +1,58 @@
+	if !v.IsValid() {
+		return obj
+	}
+	m := v.MethodByName("Unwrap")
+	if !m.IsValid() {
+		return obj
</file context>
Suggested change
m := v.MethodByName("Unwrap")
m := v.MethodByName("Unwrap")
if !m.IsValid() && v.Kind() != reflect.Ptr && v.CanAddr() {
m = v.Addr().MethodByName("Unwrap")
}
if !m.IsValid() {
return obj
}

if !m.IsValid() {
return obj
}
mt := m.Type()
if mt.NumIn() != 0 || mt.NumOut() != 1 {
return obj
}
inner := m.Call(nil)[0].Interface()
if _, ok := inner.(Schema); ok {
Comment on lines +53 to +54

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1: Potential panic when invoking reflected Unwrap() on typed nil pointer wrappers

Prompt for AI agents
Check if this issue is valid — if so, understand the root cause and fix it. At record.go, line 53:

<comment>Potential panic when invoking reflected Unwrap() on typed nil pointer wrappers</comment>

<file context>
@@ -0,0 +1,58 @@
+	if mt.NumIn() != 0 || mt.NumOut() != 1 {
+		return obj
+	}
+	inner := m.Call(nil)[0].Interface()
+	if _, ok := inner.(Schema); ok {
+		return inner
</file context>
Suggested change
inner := m.Call(nil)[0].Interface()
if _, ok := inner.(Schema); ok {
// Guard against typed-nil pointer wrappers
if v.Kind() == reflect.Ptr && v.IsNil() {
return obj
}
m := v.MethodByName("Unwrap")
if !m.IsValid() {
return obj
}

return inner
}
return obj
}
117 changes: 117 additions & 0 deletions record_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package modusgraph

import (
"errors"
"testing"
)

type fakeRecord struct{ name string }

func (f *fakeRecord) SchemaTypeName() string { return f.name }

type fakeWrapper struct{ inner *fakeRecord }

func (w *fakeWrapper) Unwrap() *fakeRecord { return w.inner }

type fakeNonSchema struct{ X string }

func TestUnwrapSchema_PassthroughForPlainStruct(t *testing.T) {
in := &fakeNonSchema{X: "hi"}
out := UnwrapSchema(in)
if out != any(in) {
t.Fatalf("expected passthrough, got %T", out)
}
}

func TestUnwrapSchema_PassthroughForSchemaStruct(t *testing.T) {
in := &fakeRecord{name: "Studio"}
out := UnwrapSchema(in)
if out != any(in) {
t.Fatalf("expected passthrough for direct Schema, got %T", out)
}
}

func TestUnwrapSchema_UnwrapsWrapper(t *testing.T) {
inner := &fakeRecord{name: "Studio"}
w := &fakeWrapper{inner: inner}
out := UnwrapSchema(w)
if out != any(inner) {
t.Fatalf("expected unwrapped inner, got %T (%v)", out, out)
}
}

func TestUnwrapSchema_IgnoresErrorsUnwrap(t *testing.T) {
// errors.New("x") has no Unwrap; wrap one to get something with Unwrap() error.
inner := errors.New("inner")
outer := &wrappedErr{err: inner}
out := UnwrapSchema(outer)
if out != any(outer) {
t.Fatalf("expected passthrough for error wrapper, got %T", out)
}
}

type wrappedErr struct{ err error }

func (w *wrappedErr) Error() string { return w.err.Error() }
func (w *wrappedErr) Unwrap() error { return w.err }

func TestUnwrapSchema_NilInput(t *testing.T) {
if out := UnwrapSchema(nil); out != nil {
t.Fatalf("expected nil for nil input, got %v", out)
}
}

// recordingClient is the minimal surface needed to verify that wrappers
// passed to the Client interface get unwrapped before reaching internal
// reflection. It records whatever it received and returns nil. Each method
// applies obj = UnwrapSchema(obj) at the top, mirroring the patch landing
// in this task.
type recordingClient struct {
seen []any
}

func (c *recordingClient) capture(obj any) any {

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2: Tests use a local recordingClient mock to simulate unwrapping instead of exercising real client methods, so regressions in the actual integration points (7 client mutation/query methods) may go undetected.

Prompt for AI agents
Check if this issue is valid — if so, understand the root cause and fix it. At record_test.go, line 73:

<comment>Tests use a local `recordingClient` mock to simulate unwrapping instead of exercising real client methods, so regressions in the actual integration points (7 client mutation/query methods) may go undetected.</comment>

<file context>
@@ -0,0 +1,117 @@
+	seen []any
+}
+
+func (c *recordingClient) capture(obj any) any {
+	obj = UnwrapSchema(obj)
+	c.seen = append(c.seen, obj)
</file context>

obj = UnwrapSchema(obj)
c.seen = append(c.seen, obj)
return obj
}

func TestUnwrapSchema_CaptureForwardsInner(t *testing.T) {
inner := &fakeRecord{name: "Studio"}
w := &fakeWrapper{inner: inner}
c := &recordingClient{}
got := c.capture(w)
if got != any(inner) {
t.Fatalf("expected inner record, got %T (%v)", got, got)
}
if len(c.seen) != 1 || c.seen[0] != any(inner) {
t.Fatalf("expected recording to hold inner record, got %v", c.seen)
}
}

func TestUnwrapSchema_CapturePassthroughForPlain(t *testing.T) {
plain := &fakeNonSchema{X: "y"}
c := &recordingClient{}
got := c.capture(plain)
if got != any(plain) {
t.Fatalf("expected plain struct passthrough, got %T", got)
}
}

func TestUnwrapSchema_VariadicUnwrapsEachElement(t *testing.T) {
innerA := &fakeRecord{name: "Studio"}
innerB := &fakeRecord{name: "Film"}
templates := []any{
&fakeWrapper{inner: innerA},
innerB, // already a Schema; passthrough
}
for i, obj := range templates {
templates[i] = UnwrapSchema(obj)
}
if templates[0] != any(innerA) {
t.Fatalf("template[0]: expected innerA, got %T", templates[0])
}
if templates[1] != any(innerB) {
t.Fatalf("template[1]: expected innerB (passthrough), got %T", templates[1])
}
}
Loading