From e9028a1f7193a40bbe5193417bf0f2da977ba984 Mon Sep 17 00:00:00 2001 From: Michael Welles Date: Thu, 4 Jun 2026 13:06:57 -0400 Subject: [PATCH] feat(typed): generic, type-safe client and query builder Add a generic typed layer over modusgraph.Client: typed.Client[T] with CRUD and iterators; a fluent Query[T] builder (filters, ordering, paging, edge traversal, IterNodes); MultiQuery for N homogeneous blocks in one round-trip; functional options; a filter DSL (typed/filter); and ordered result merging (typed/search). A small no-op-by-default Tracer seam (typed.SetTracer) lets a host plug in tracing without the typed package depending on any telemetry library. Self-contained: builds and tests against the current client with no other changes. --- typed/client.go | 87 +++ typed/client_test.go | 209 ++++++ typed/filter/filter.go | 118 +++ typed/filter/filter_test.go | 118 +++ typed/filter/fulltext.go | 21 + typed/filter/fulltext_test.go | 41 ++ typed/multi_query.go | 191 +++++ typed/multi_query_test.go | 127 ++++ typed/option.go | 17 + typed/option_test.go | 37 + typed/query.go | 565 ++++++++++++++ typed/query_test.go | 1294 +++++++++++++++++++++++++++++++++ typed/search/merge.go | 27 + typed/search/merge_test.go | 86 +++ typed/tracing.go | 58 ++ typed/tracing_test.go | 47 ++ 16 files changed, 3043 insertions(+) create mode 100644 typed/client.go create mode 100644 typed/client_test.go create mode 100644 typed/filter/filter.go create mode 100644 typed/filter/filter_test.go create mode 100644 typed/filter/fulltext.go create mode 100644 typed/filter/fulltext_test.go create mode 100644 typed/multi_query.go create mode 100644 typed/multi_query_test.go create mode 100644 typed/option.go create mode 100644 typed/option_test.go create mode 100644 typed/query.go create mode 100644 typed/query_test.go create mode 100644 typed/search/merge.go create mode 100644 typed/search/merge_test.go create mode 100644 typed/tracing.go create mode 100644 typed/tracing_test.go diff --git a/typed/client.go b/typed/client.go new file mode 100644 index 0000000..c540f89 --- /dev/null +++ b/typed/client.go @@ -0,0 +1,87 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +// Package typed binds a Go type to the otherwise any-typed modusgraph.Client, +// providing generic, type-safe CRUD and query operations without per-entity +// code generation. It is the handwritten substrate that modusgraph-gen's +// generated clients compose over. +package typed + +import ( + "context" + "iter" + + "github.com/matthewmcneely/modusgraph" +) + +// Client provides type-safe CRUD and query operations over records of type T. +// T is the schema struct (for example schema.Actor); modusgraph reflects over +// the struct's dgraph/json tags, so T needs no constraint. +type Client[T any] struct { + conn modusgraph.Client +} + +// NewClient binds a Client[T] to conn. +func NewClient[T any](conn modusgraph.Client) *Client[T] { + return &Client[T]{conn: conn} +} + +// Get loads the T with the given UID. +func (c *Client[T]) Get(ctx context.Context, uid string) (rec *T, err error) { + ctx, span := tracer.StartSpan(ctx, "get", entityName[T]()) + defer func() { span.End(err) }() + var out T + if err = c.conn.Get(ctx, &out, uid); err != nil { + return nil, err + } + return &out, nil +} + +// Add inserts a new T. modusgraph writes the assigned UID back into rec. +func (c *Client[T]) Add(ctx context.Context, rec *T) (err error) { + ctx, span := tracer.StartSpan(ctx, "add", entityName[T]()) + defer func() { span.End(err) }() + return c.conn.Insert(ctx, rec) +} + +// Update modifies an existing T (must have its UID set). +func (c *Client[T]) Update(ctx context.Context, rec *T) (err error) { + ctx, span := tracer.StartSpan(ctx, "update", entityName[T]()) + defer func() { span.End(err) }() + return c.conn.Update(ctx, rec) +} + +// Upsert inserts or updates rec, matching against predicates. With no +// predicates, the first field tagged dgraph:"upsert" is used. +func (c *Client[T]) Upsert(ctx context.Context, rec *T, predicates ...string) (err error) { + ctx, span := tracer.StartSpan(ctx, "upsert", entityName[T]()) + defer func() { span.End(err) }() + return c.conn.Upsert(ctx, rec, predicates...) +} + +// Delete removes the T with the given UID. +func (c *Client[T]) Delete(ctx context.Context, uid string) (err error) { + ctx, span := tracer.StartSpan(ctx, "delete", entityName[T]()) + defer func() { span.End(err) }() + return c.conn.Delete(ctx, []string{uid}) +} + +// Query returns a typed query builder for T. conn and ctx are carried so the +// builder can run a WhereEdge pre-pass (see Query.WhereEdge) if one is needed. +func (c *Client[T]) Query(ctx context.Context) *Query[T] { + var z T + return &Query[T]{q: c.conn.Query(ctx, &z), conn: c.conn, ctx: ctx} +} + +// defaultPageSize is the page size IterNodes uses to page through results. +const defaultPageSize = 50 + +// Iter returns an iterator over every T, paging transparently so large result +// sets are not materialized at once. It yields each record in turn; on error +// it yields a final (nil, err) and stops. All pages execute against one +// read-only transaction, so the iteration reads a single consistent snapshot. +func (c *Client[T]) Iter(ctx context.Context) iter.Seq2[*T, error] { + return c.Query(ctx).IterNodes() +} diff --git a/typed/client_test.go b/typed/client_test.go new file mode 100644 index 0000000..6fa2b1d --- /dev/null +++ b/typed/client_test.go @@ -0,0 +1,209 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package typed_test + +import ( + "context" + "testing" + + "github.com/matthewmcneely/modusgraph" + "github.com/matthewmcneely/modusgraph/typed" +) + +// widget is a minimal schema struct used to exercise the typed package. +type widget struct { + UID string `json:"uid,omitempty"` + DType []string `json:"dgraph.type,omitempty"` + Name string `json:"name,omitempty" dgraph:"index=exact"` + Qty int `json:"qty,omitempty" dgraph:"index=int"` +} + +// owner and pet exercise Query.WhereEdge: owner has an outbound "pets" edge to +// pet, and pet's Name carries an index so eq(name, ...) resolves inside an edge +// filter. The pair is the typed-package analogue of the Person/Dog example in +// docs/specs/2026-05-21-query-edge-filter-design.md. +type owner struct { + UID string `json:"uid,omitempty"` + DType []string `json:"dgraph.type,omitempty"` + Name string `json:"name,omitempty" dgraph:"index=exact"` + Pets []*pet `json:"pets,omitempty"` +} + +type pet struct { + UID string `json:"uid,omitempty"` + DType []string `json:"dgraph.type,omitempty"` + Name string `json:"name,omitempty" dgraph:"index=exact"` +} + +// newConn builds a local file-backed modusgraph client for a test. +func newConn(t *testing.T) modusgraph.Client { + t.Helper() + conn, err := modusgraph.NewClient("file://"+t.TempDir(), modusgraph.WithAutoSchema(true)) + if err != nil { + t.Fatalf("modusgraph.NewClient: %v", err) + } + t.Cleanup(conn.Close) + return conn +} + +func TestClient_AddPopulatesUIDAndGetReadsBack(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + w := &widget{Name: "sprocket", Qty: 3} + if err := c.Add(ctx, w); err != nil { + t.Fatalf("Add: %v", err) + } + if w.UID == "" { + t.Fatal("Add did not populate UID on the passed struct") + } + + got, err := c.Get(ctx, w.UID) + if err != nil { + t.Fatalf("Get: %v", err) + } + if got.Name != "sprocket" || got.Qty != 3 { + t.Fatalf("Get returned %+v, want Name=sprocket Qty=3", got) + } +} + +func TestClient_Update(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + w := &widget{Name: "gear", Qty: 1} + if err := c.Add(ctx, w); err != nil { + t.Fatalf("Add: %v", err) + } + w.Qty = 99 + if err := c.Update(ctx, w); err != nil { + t.Fatalf("Update: %v", err) + } + + got, err := c.Get(ctx, w.UID) + if err != nil { + t.Fatalf("Get: %v", err) + } + if got.Qty != 99 { + t.Fatalf("Update did not persist; Qty = %d, want 99", got.Qty) + } +} + +func TestClient_Delete(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + w := &widget{Name: "bolt"} + if err := c.Add(ctx, w); err != nil { + t.Fatalf("Add: %v", err) + } + if err := c.Delete(ctx, w.UID); err != nil { + t.Fatalf("Delete: %v", err) + } + if _, err := c.Get(ctx, w.UID); err == nil { + t.Fatal("Get after Delete returned no error; expected not-found") + } +} + +func TestClient_IterPagesThroughAllRecords(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + // 125 is deliberately larger than the package's 50-record page size, so + // a correct Iter must fetch more than one page. + const n = 125 + for i := range n { + if err := c.Add(ctx, &widget{Name: "w", Qty: i}); err != nil { + t.Fatalf("Add %d: %v", i, err) + } + } + + seen := 0 + for w, err := range c.Iter(ctx) { + if err != nil { + t.Fatalf("Iter yielded error: %v", err) + } + if w == nil { + t.Fatal("Iter yielded a nil widget") + } + seen++ + } + if seen != n { + t.Fatalf("Iter yielded %d records, want %d", seen, n) + } +} + +// gadget is a dedicated upsert struct. It must not be the shared widget, because +// widget is used in tests that insert many records with duplicate Name values; +// adding a "upsert" directive to widget.Name would cause those inserts to +// collide and break unrelated tests. +type gadget struct { + UID string `json:"uid,omitempty"` + DType []string `json:"dgraph.type,omitempty"` + Label string `json:"label,omitempty" dgraph:"index=exact upsert"` + Stock int `json:"stock,omitempty" dgraph:"index=int"` +} + +func TestClient_Upsert(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[gadget](newConn(t)) + + // First call — creates the record. + g := &gadget{Label: "sprocket", Stock: 10} + if err := c.Upsert(ctx, g, "label"); err != nil { + t.Fatalf("Upsert (create): %v", err) + } + if g.UID == "" { + t.Fatal("Upsert (create) did not populate UID") + } + + // Second call — same Label value, different Stock. Must UPDATE, not insert. + g2 := &gadget{Label: "sprocket", Stock: 99} + if err := c.Upsert(ctx, g2, "label"); err != nil { + t.Fatalf("Upsert (update): %v", err) + } + + // Exactly one record must exist and it must carry the updated Stock. + nodes, err := c.Query(ctx).Nodes() + if err != nil { + t.Fatalf("Query after Upsert: %v", err) + } + if len(nodes) != 1 { + t.Fatalf("got %d gadgets after two upserts on the same label, want 1", len(nodes)) + } + if nodes[0].Stock != 99 { + t.Fatalf("upserted gadget Stock = %d, want 99", nodes[0].Stock) + } +} + +func TestClient_IterStopsOnConsumerBreak(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + const n = 125 + for i := range n { + if err := c.Add(ctx, &widget{Name: "w", Qty: i}); err != nil { + t.Fatalf("Add %d: %v", i, err) + } + } + + seen := 0 + for w, err := range c.Iter(ctx) { + if err != nil { + t.Fatalf("Iter yielded error: %v", err) + } + if w == nil { + t.Fatal("Iter yielded a nil widget") + } + seen++ + if seen == 10 { + break + } + } + if seen != 10 { + t.Fatalf("Iter yielded %d records after break at 10, want 10", seen) + } +} diff --git a/typed/filter/filter.go b/typed/filter/filter.go new file mode 100644 index 0000000..d67f118 --- /dev/null +++ b/typed/filter/filter.go @@ -0,0 +1,118 @@ +// Package filter provides typed values and a parameterised expression builder +// for composing dgraph @filter clauses on generated Query types. +// +// Generated By methods accept []UUID or []String and feed them into +// Builder.EqGroupUUID / Builder.EqGroupString. Consumers can also build +// custom expressions directly with Builder for cases the generator does not +// cover (multi-predicate joins, non-equality operators, domain defaults). +package filter + +import ( + "fmt" + "strings" +) + +// UUID is one UUID-valued filter term, optionally negated. A leading "!" in +// the parsed source negates the term ("!abc" becomes {Negated: true, Value: "abc"}). +type UUID struct { + Negated bool + Value string +} + +// String is one string-valued filter term, optionally negated. +type String struct { + Negated bool + Value string +} + +// ParseUUID parses "value" or "!value" into a UUID. +func ParseUUID(s string) UUID { + neg, v := parseNegation(s) + return UUID{Negated: neg, Value: v} +} + +// ParseString parses "value" or "!value" into a String. +func ParseString(s string) String { + neg, v := parseNegation(s) + return String{Negated: neg, Value: v} +} + +func parseNegation(s string) (bool, string) { + if strings.HasPrefix(s, "!") { + return true, s[1:] + } + return false, s +} + +// term is one predicate-agnostic value used by Builder. +type term struct { + value string + negated bool +} + +// Builder composes parameterised DQL @filter expressions. Terms within an +// EqGroup join with OR; groups join with AND. Required terms become their own +// single-term group. The output is the (expression, positional params) pair +// that typed.Query[T].Filter consumes. +type Builder struct { + groups []string + params []any +} + +func (b *Builder) param(v any) string { + b.params = append(b.params, v) + return fmt.Sprintf("$%d", len(b.params)) +} + +// EqGroupUUID adds an OR-group of eq(predicate, value) terms for one +// UUID-typed predicate. An empty terms slice is a no-op. +func (b *Builder) EqGroupUUID(predicate string, terms []UUID) { + if len(terms) == 0 { + return + } + tg := make([]term, 0, len(terms)) + for _, t := range terms { + tg = append(tg, term{value: t.Value, negated: t.Negated}) + } + b.addEqGroup(predicate, tg) +} + +// EqGroupString adds an OR-group of eq(predicate, value) terms for one +// string-typed predicate. +func (b *Builder) EqGroupString(predicate string, terms []String) { + if len(terms) == 0 { + return + } + tg := make([]term, 0, len(terms)) + for _, t := range terms { + tg = append(tg, term{value: t.Value, negated: t.Negated}) + } + b.addEqGroup(predicate, tg) +} + +func (b *Builder) addEqGroup(predicate string, terms []term) { + parts := make([]string, 0, len(terms)) + for _, t := range terms { + eq := fmt.Sprintf("eq(%s, %s)", predicate, b.param(t.value)) + if t.negated { + eq = "NOT " + eq + } + parts = append(parts, eq) + } + b.groups = append(b.groups, "("+strings.Join(parts, " OR ")+")") +} + +// RequiredEq adds a single mandatory eq(predicate, value) term (its own group). +func (b *Builder) RequiredEq(predicate, value string) { + b.groups = append(b.groups, fmt.Sprintf("eq(%s, %s)", predicate, b.param(value))) +} + +// Build returns the combined DQL filter expression and its parameters. When +// no groups were added it returns ("", nil) — callers should skip the +// .Filter() call entirely in that case. +func (b *Builder) Build() (string, []any) { + if len(b.groups) == 0 { + return "", nil + } + return strings.Join(b.groups, " AND "), b.params +} diff --git a/typed/filter/filter_test.go b/typed/filter/filter_test.go new file mode 100644 index 0000000..864a554 --- /dev/null +++ b/typed/filter/filter_test.go @@ -0,0 +1,118 @@ +package filter_test + +import ( + "strings" + "testing" + + "github.com/matthewmcneely/modusgraph/typed/filter" +) + +func TestParseUUID(t *testing.T) { + tests := []struct { + name string + in string + want filter.UUID + }{ + {"plain", "abc", filter.UUID{Value: "abc"}}, + {"negated", "!abc", filter.UUID{Negated: true, Value: "abc"}}, + {"empty", "", filter.UUID{}}, + {"just bang", "!", filter.UUID{Negated: true}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := filter.ParseUUID(tt.in) + if got != tt.want { + t.Errorf("ParseUUID(%q) = %+v, want %+v", tt.in, got, tt.want) + } + }) + } +} + +func TestParseString(t *testing.T) { + got := filter.ParseString("!hello") + want := filter.String{Negated: true, Value: "hello"} + if got != want { + t.Errorf("ParseString = %+v, want %+v", got, want) + } +} + +func TestBuilder_Empty(t *testing.T) { + var b filter.Builder + expr, params := b.Build() + if expr != "" || params != nil { + t.Errorf("empty Build = (%q, %v), want (\"\", nil)", expr, params) + } +} + +func TestBuilder_EqGroupUUID_SingleTerm(t *testing.T) { + var b filter.Builder + b.EqGroupUUID("id", []filter.UUID{{Value: "u1"}}) + expr, params := b.Build() + want := "(eq(id, $1))" + if expr != want { + t.Errorf("expr = %q, want %q", expr, want) + } + if len(params) != 1 || params[0] != "u1" { + t.Errorf("params = %v, want [u1]", params) + } +} + +func TestBuilder_EqGroupUUID_MultipleTermsJoinWithOR(t *testing.T) { + var b filter.Builder + b.EqGroupUUID("id", []filter.UUID{{Value: "u1"}, {Value: "u2"}, {Negated: true, Value: "u3"}}) + expr, params := b.Build() + want := "(eq(id, $1) OR eq(id, $2) OR NOT eq(id, $3))" + if expr != want { + t.Errorf("expr = %q, want %q", expr, want) + } + if len(params) != 3 { + t.Errorf("len(params) = %d, want 3", len(params)) + } +} + +func TestBuilder_EqGroupString_NoTermsIsNoop(t *testing.T) { + var b filter.Builder + b.EqGroupString("name", nil) + expr, _ := b.Build() + if expr != "" { + t.Errorf("empty EqGroupString should be no-op, got expr=%q", expr) + } +} + +func TestBuilder_MultipleGroupsJoinWithAND(t *testing.T) { + var b filter.Builder + b.EqGroupUUID("id", []filter.UUID{{Value: "u1"}}) + b.EqGroupString("name", []filter.String{{Value: "Alice"}}) + expr, params := b.Build() + want := "(eq(id, $1)) AND (eq(name, $2))" + if expr != want { + t.Errorf("expr = %q, want %q", expr, want) + } + if len(params) != 2 || params[0] != "u1" || params[1] != "Alice" { + t.Errorf("params = %v, want [u1 Alice]", params) + } +} + +func TestBuilder_RequiredEqIsOwnGroup(t *testing.T) { + var b filter.Builder + b.RequiredEq("archiveStatus", "Active") + b.EqGroupUUID("id", []filter.UUID{{Value: "u1"}}) + expr, params := b.Build() + want := "eq(archiveStatus, $1) AND (eq(id, $2))" + if expr != want { + t.Errorf("expr = %q, want %q", expr, want) + } + if len(params) != 2 { + t.Errorf("len(params) = %d, want 2", len(params)) + } +} + +func TestBuilder_PositionalParamsAreSequential(t *testing.T) { + var b filter.Builder + b.EqGroupUUID("id", []filter.UUID{{Value: "a"}, {Value: "b"}}) + b.EqGroupString("name", []filter.String{{Value: "c"}}) + expr, _ := b.Build() + if !strings.Contains(expr, "$1") || !strings.Contains(expr, "$2") || !strings.Contains(expr, "$3") { + t.Errorf("expected $1, $2, $3 in expr; got %q", expr) + } +} diff --git a/typed/filter/fulltext.go b/typed/filter/fulltext.go new file mode 100644 index 0000000..a025ef0 --- /dev/null +++ b/typed/filter/fulltext.go @@ -0,0 +1,21 @@ +package filter + +import "fmt" + +// AnyOfText adds a fulltext OR-match group: anyoftext(predicate, term). +// An empty term is a no-op. +func (b *Builder) AnyOfText(predicate, term string) { + if term == "" { + return + } + b.groups = append(b.groups, fmt.Sprintf("anyoftext(%s, %s)", predicate, b.param(term))) +} + +// AllOfText adds a fulltext AND-match group: alloftext(predicate, term). +// An empty term is a no-op. +func (b *Builder) AllOfText(predicate, term string) { + if term == "" { + return + } + b.groups = append(b.groups, fmt.Sprintf("alloftext(%s, %s)", predicate, b.param(term))) +} diff --git a/typed/filter/fulltext_test.go b/typed/filter/fulltext_test.go new file mode 100644 index 0000000..1d71e0b --- /dev/null +++ b/typed/filter/fulltext_test.go @@ -0,0 +1,41 @@ +package filter_test + +import ( + "strings" + "testing" + + "github.com/matthewmcneely/modusgraph/typed/filter" +) + +func TestAnyOfTextEmitsFilterAndBindsParam(t *testing.T) { + b := &filter.Builder{} + b.AnyOfText("resourceName", "honda civic") + expr, params := b.Build() + if !strings.Contains(expr, "anyoftext(resourceName, $1)") { + t.Fatalf("expected anyoftext(resourceName, $1) in expr, got %q", expr) + } + if len(params) != 1 || params[0] != "honda civic" { + t.Fatalf("expected params [\"honda civic\"], got %v", params) + } +} + +func TestAllOfTextEmitsFilterAndBindsParam(t *testing.T) { + b := &filter.Builder{} + b.AllOfText("description", "engine block") + expr, params := b.Build() + if !strings.Contains(expr, "alloftext(description, $1)") { + t.Fatalf("expected alloftext(description, $1) in expr, got %q", expr) + } + if len(params) != 1 || params[0] != "engine block" { + t.Fatalf("expected params [\"engine block\"], got %v", params) + } +} + +func TestAnyOfTextEmptyTermIsNoop(t *testing.T) { + b := &filter.Builder{} + b.AnyOfText("resourceName", "") + expr, params := b.Build() + if expr != "" || params != nil { + t.Fatalf("expected empty expr/params for empty term, got %q / %v", expr, params) + } +} diff --git a/typed/multi_query.go b/typed/multi_query.go new file mode 100644 index 0000000..98409c6 --- /dev/null +++ b/typed/multi_query.go @@ -0,0 +1,191 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package typed + +import ( + "context" + "encoding/json" + "fmt" + "reflect" + "strings" + + dg "github.com/dolan-in/dgman/v2" + "github.com/matthewmcneely/modusgraph" +) + +// MultiQuery batches N homogeneous-type Query[T] blocks into a single +// Dgraph multi-block request. All blocks return rows of the same T; the +// per-block result is keyed by the block name supplied at Add. +// +// Dgraph executes the blocks concurrently on the server side; the entire +// batch costs one gRPC round-trip. +type MultiQuery[T any] struct { + conn modusgraph.Client + names []string + blocks map[string]*Query[T] +} + +// NewMultiQuery constructs a MultiQuery bound to conn. +func NewMultiQuery[T any](conn modusgraph.Client) *MultiQuery[T] { + return &MultiQuery[T]{ + conn: conn, + blocks: make(map[string]*Query[T]), + } +} + +// Add registers a named block. Names must be unique within one MultiQuery. +// Panics on duplicate name — the call site is a programming error, not a +// runtime condition. +func (mq *MultiQuery[T]) Add(name string, q *Query[T]) *MultiQuery[T] { + if _, exists := mq.blocks[name]; exists { + panic(fmt.Sprintf("multi_query: duplicate block name %q", name)) + } + mq.names = append(mq.names, name) + mq.blocks[name] = q + return mq +} + +// BlockNames returns the registered block names in insertion order. +func (mq *MultiQuery[T]) BlockNames() []string { + out := make([]string, len(mq.names)) + copy(out, mq.names) + return out +} + +// Execute runs every registered block in a single Dgraph round-trip and +// returns the per-block results, keyed by the block name supplied at Add. +// A block that matched no rows appears as an empty (non-nil) slice in the +// result map; the key is always present. +// +// Execute rejects blocks that carry WhereEdge constraints — those require a +// runtime pre-pass that cannot be folded into the multi-block batch. Run such +// queries individually with Query.Nodes. +// +// Dgraph keys response JSON by predicate name (e.g. resourceName), but Go +// structs typically use their json tag (e.g. name). Execute remaps the keys +// per T's tags before decoding so a schema that uses `dgraph:"predicate=..."` +// with a divergent `json:"..."` decodes correctly — matching the behavior of +// dgman's QueryBlock.Scan path. +func (mq *MultiQuery[T]) Execute(ctx context.Context) (map[string][]T, error) { + if len(mq.names) == 0 { + return map[string][]T{}, nil + } + + rawBlocks := make([]*dg.Query, 0, len(mq.names)) + for _, name := range mq.names { + block := mq.blocks[name] + if len(block.edges) != 0 { + return nil, fmt.Errorf("multi_query: block %q carries WhereEdge constraints; MultiQuery cannot batch edge-filtered blocks", name) + } + // Name the underlying dgman query so blocks do not collide on the + // default "data" name and so the response JSON keys are predictable. + block.q.Name(name) + rawBlocks = append(rawBlocks, block.q) + } + + dql := dg.NewQueryBlock(rawBlocks...).String() + raw, err := mq.conn.QueryRaw(ctx, dql, nil) + if err != nil { + return nil, fmt.Errorf("multi_query: dgraph: %w", err) + } + + var perBlockRaw map[string]json.RawMessage + if err := json.Unmarshal(raw, &perBlockRaw); err != nil { + return nil, fmt.Errorf("multi_query: decoding response: %w", err) + } + + var zero T + predMap := buildPredicateToJSONMap(reflect.TypeOf(zero)) + + out := make(map[string][]T, len(mq.names)) + for _, name := range mq.names { + body, ok := perBlockRaw[name] + if !ok { + out[name] = []T{} + continue + } + if len(predMap) > 0 { + remapped, err := remapArrayKeys(body, predMap) + if err == nil { + body = remapped + } + } + var rows []T + if err := json.Unmarshal(body, &rows); err != nil { + return nil, fmt.Errorf("multi_query: decoding block %q: %w", name, err) + } + if rows == nil { + rows = []T{} + } + out[name] = rows + } + return out, nil +} + +// buildPredicateToJSONMap returns a map from dgraph predicate name → JSON tag +// name for fields on T where the two differ. Mirrors dgman's unexported helper +// of the same name; we need our own because the multi-block response from +// QueryRaw bypasses dgman's scan path. +func buildPredicateToJSONMap(t reflect.Type) map[string]string { + for t != nil && t.Kind() == reflect.Ptr { + t = t.Elem() + } + if t == nil || t.Kind() != reflect.Struct { + return nil + } + result := make(map[string]string) + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + jsonTag := field.Tag.Get("json") + if jsonTag == "" || jsonTag == "-" { + continue + } + jsonName := strings.Split(jsonTag, ",")[0] + if jsonName == "" { + continue + } + dgraphTag := field.Tag.Get("dgraph") + if dgraphTag == "" { + continue + } + var predName string + for _, part := range strings.Fields(dgraphTag) { + if strings.HasPrefix(part, "predicate=") { + predName = strings.TrimPrefix(part, "predicate=") + break + } + } + if predName == "" || predName == jsonName { + continue + } + if predName == "uid" || predName == "dgraph.type" { + continue + } + result[predName] = jsonName + } + return result +} + +// remapArrayKeys rewrites top-level keys in each object of a JSON array using +// the predicate → JSON-tag map. Nested objects are left untouched (search +// callers iterate scalar predicates of the root type; edge fields are +// hydrated lazily, not in the multi-block response). +func remapArrayKeys(data json.RawMessage, predMap map[string]string) (json.RawMessage, error) { + var rows []map[string]json.RawMessage + if err := json.Unmarshal(data, &rows); err != nil { + return data, err + } + for i, row := range rows { + for k, v := range row { + if newK, ok := predMap[k]; ok && newK != k { + delete(row, k) + row[newK] = v + } + } + rows[i] = row + } + return json.Marshal(rows) +} diff --git a/typed/multi_query_test.go b/typed/multi_query_test.go new file mode 100644 index 0000000..98f1ae4 --- /dev/null +++ b/typed/multi_query_test.go @@ -0,0 +1,127 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package typed_test + +import ( + "context" + "testing" + + "github.com/matthewmcneely/modusgraph/typed" +) + +func TestMultiQueryAddAccumulatesBlocks(t *testing.T) { + conn := newConn(t) + mq := typed.NewMultiQuery[widget](conn) + q1 := typed.NewClient[widget](conn).Query(context.Background()) + q2 := typed.NewClient[widget](conn).Query(context.Background()) + mq.Add("byName", q1) + mq.Add("byQty", q2) + got := mq.BlockNames() + if len(got) != 2 || got[0] != "byName" || got[1] != "byQty" { + t.Fatalf("BlockNames = %v, want [byName, byQty]", got) + } +} + +func TestMultiQueryAddRejectsDuplicateName(t *testing.T) { + conn := newConn(t) + mq := typed.NewMultiQuery[widget](conn) + q := typed.NewClient[widget](conn).Query(context.Background()) + mq.Add("byName", q) + defer func() { + if r := recover(); r == nil { + t.Fatal("expected panic on duplicate block name") + } + }() + mq.Add("byName", q) +} + +func TestMultiQueryExecuteReturnsPerBlockResults(t *testing.T) { + ctx := context.Background() + conn := newConn(t) + c := typed.NewClient[widget](conn) + + for _, w := range []*widget{ + {Name: "sprocket", Qty: 1}, + {Name: "gear", Qty: 5}, + {Name: "bolt", Qty: 10}, + } { + if err := c.Add(ctx, w); err != nil { + t.Fatalf("Add %s: %v", w.Name, err) + } + } + + mq := typed.NewMultiQuery[widget](conn) + mq.Add("all", c.Query(ctx)) + mq.Add("filtered", c.Query(ctx).Filter("eq(name, $1)", "gear")) + + results, err := mq.Execute(ctx) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if got := len(results["all"]); got != 3 { + t.Fatalf("results[all] has %d rows, want 3", got) + } + if got := len(results["filtered"]); got != 1 { + t.Fatalf("results[filtered] has %d rows, want 1", got) + } + if results["filtered"][0].Name != "gear" { + t.Fatalf("results[filtered][0].Name = %q, want gear", results["filtered"][0].Name) + } +} + +func TestMultiQueryExecuteEmptyReturnsEmptyMap(t *testing.T) { + conn := newConn(t) + mq := typed.NewMultiQuery[widget](conn) + results, err := mq.Execute(context.Background()) + if err != nil { + t.Fatalf("Execute on empty MultiQuery: %v", err) + } + if len(results) != 0 { + t.Fatalf("expected empty map, got %v", results) + } +} + +// renamed exercises the predicate-vs-json-tag remap. Dgraph returns the +// "thingName" key (the predicate name) but the struct's JSON tag is +// "name"; MultiQuery.Execute must remap before unmarshaling so Name +// populates. +type renamed struct { + UID string `json:"uid,omitempty"` + DType []string `json:"dgraph.type,omitempty"` + Name string `json:"name,omitempty" dgraph:"predicate=thingName index=hash,fulltext"` + Qty int `json:"qty,omitempty" dgraph:"index=int"` +} + +func TestMultiQueryExecuteRemapsPredicateKeys(t *testing.T) { + ctx := context.Background() + conn := newConn(t) + c := typed.NewClient[renamed](conn) + + for _, w := range []*renamed{ + {Name: "alpha", Qty: 1}, + {Name: "beta", Qty: 2}, + } { + if err := c.Add(ctx, w); err != nil { + t.Fatalf("Add %s: %v", w.Name, err) + } + } + + mq := typed.NewMultiQuery[renamed](conn) + mq.Add("all", c.Query(ctx)) + results, err := mq.Execute(ctx) + if err != nil { + t.Fatalf("Execute: %v", err) + } + rows := results["all"] + if len(rows) != 2 { + t.Fatalf("rows = %d, want 2", len(rows)) + } + for _, r := range rows { + if r.Name == "" { + t.Fatalf("Name not populated; multi-block response was not remapped from predicate key: %+v", r) + } + } +} diff --git a/typed/option.go b/typed/option.go new file mode 100644 index 0000000..d944483 --- /dev/null +++ b/typed/option.go @@ -0,0 +1,17 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package typed + +// Option configures a *T. Generated With constructors return an Option; +// generated New/Wrap constructors apply them via Apply. +type Option[T any] func(*T) + +// Apply applies opts to target in declaration order. +func Apply[T any](target *T, opts ...Option[T]) { + for _, opt := range opts { + opt(target) + } +} diff --git a/typed/option_test.go b/typed/option_test.go new file mode 100644 index 0000000..7c1f378 --- /dev/null +++ b/typed/option_test.go @@ -0,0 +1,37 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package typed_test + +import ( + "strings" + "testing" + + "github.com/matthewmcneely/modusgraph/typed" +) + +func TestApply_RunsOptionsInOrder(t *testing.T) { + type rec struct{ trail []string } + r := &rec{} + + typed.Apply(r, + func(x *rec) { x.trail = append(x.trail, "a") }, + func(x *rec) { x.trail = append(x.trail, "b") }, + func(x *rec) { x.trail = append(x.trail, "c") }, + ) + + if got := strings.Join(r.trail, ""); got != "abc" { + t.Fatalf("Apply ran options as %q, want %q", got, "abc") + } +} + +func TestApply_NoOptionsIsNoop(t *testing.T) { + type rec struct{ n int } + r := &rec{n: 7} + typed.Apply(r) + if r.n != 7 { + t.Fatalf("Apply with no options mutated target: n = %d, want 7", r.n) + } +} diff --git a/typed/query.go b/typed/query.go new file mode 100644 index 0000000..e4b2199 --- /dev/null +++ b/typed/query.go @@ -0,0 +1,565 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package typed + +import ( + "context" + "fmt" + "iter" + "strconv" + "strings" + + dg "github.com/dolan-in/dgman/v2" + "github.com/matthewmcneely/modusgraph" +) + +// Query is a fluent, type-safe query builder over records of type T. Builder +// methods return *Query[T] for chaining, except As, Var, and GroupBy, which +// change the result shape and transition to *RawQuery; terminal methods +// (Nodes, First, IterNodes) execute the query and decode typed results. +// +// A Query is single-use. Builder methods mutate the underlying query in place +// and return the same *Query, so a Query value should be built as one chain +// and handed to a single terminal. It is not safe to save a Query to a +// variable and branch it into independent queries: every branch shares — and +// keeps mutating — the same underlying query. +// +// Repeated builder calls do not all behave the same way. Limit, Offset, After, +// Cascade, Name, RootFunc, and Vars overwrite: the last call wins. Filter, +// OrderAsc, OrderDesc, and WhereEdge accumulate: each call adds to the query. +// Accumulated Filter fragments AND together (see CombinedFilter, OrGroup). +// +// Limit and Offset additionally record the bounds that IterNodes pages +// within — a Limit caps the rows it streams, an Offset is its start. +type Query[T any] struct { + q *dg.Query + conn modusgraph.Client // runs the WhereEdge pre-pass; set by Client.Query + ctx context.Context // carried for the WhereEdge pre-pass query + limit int // caller-set row cap; 0 = unbounded + offset int // caller-set starting offset; 0 = none + edges []edgeFilter // accumulated WhereEdge constraints; empty = none + filters []filterFrag // accumulated @filter fragments, ANDed; empty = none +} + +// edgeFilter is one accumulated WhereEdge constraint: a dgraph @filter +// expression scoped to an outbound edge predicate of T. +type edgeFilter struct { + predicate string + filter string + params []any +} + +// filterFrag is one accumulated @filter fragment. Fragments join with AND. +type filterFrag struct { + expr string + params []any +} + +// NewDetachedQuery returns a Query[T] with no connection, used only to +// accumulate a filter expression: its By/Filter calls record fragments +// that CombinedFilter reads back. It must not be executed (it has no terminal +// path) and exists as the capture target behind the generated Or and +// WhereBy combinators. +func NewDetachedQuery[T any]() *Query[T] { + return &Query[T]{} +} + +// Filter adds a dgraph @filter expression. params bind to placeholders. +// Repeated calls accumulate: every fragment ANDs together. +func (qb *Query[T]) Filter(filter string, params ...any) *Query[T] { + qb.addFilter(filter, params) + return qb +} + +// addFilter accumulates one @filter fragment. Fragments AND together: the +// effective filter is every fragment joined with AND, each fragment's $N +// placeholders shifted to stay bound to its own params. dgman's own Filter is +// last-write-wins, so the full combined expression is re-pushed on every call. +// A detached query (nil q — used to capture a sub-scope's filter for OrGroup or +// WhereBy) accumulates with no dgman query to push to; CombinedFilter +// reads the fragments back. +func (qb *Query[T]) addFilter(expr string, params []any) { + if expr == "" { + return + } + qb.filters = append(qb.filters, filterFrag{expr: expr, params: params}) + if qb.q != nil { + combined, cp := combineAnd(qb.filters) + qb.q.Filter(combined, cp...) + } +} + +// combineAnd joins fragments with AND, renumbering each fragment's ordinal +// placeholders against the concatenated params slice. +func combineAnd(frags []filterFrag) (string, []any) { + parts := make([]string, 0, len(frags)) + var params []any + for _, f := range frags { + if f.expr == "" { + continue + } + parts = append(parts, shiftPlaceholders(f.expr, len(params))) + params = append(params, f.params...) + } + if len(parts) == 0 { + return "", nil + } + return strings.Join(parts, " AND "), params +} + +// CombinedFilter returns the AND-combined accumulated @filter expression and +// its params, or ("", nil) when no filter was set. It is the substrate behind +// the generated Or and WhereBy combinators: they run a sub-scope's +// By/Filter calls against a detached query, then fold the captured +// expression into a parent OR group or edge constraint. +func (qb *Query[T]) CombinedFilter() (string, []any) { + return combineAnd(qb.filters) +} + +// OrGroup adds one @filter group that ORs the combined filter of each sub. +// Each sub is a detached Query[T] whose By/Filter calls have been +// accumulated; their combined (AND) expressions are parenthesized, joined with +// OR, and the whole OR group ANDs with the receiver's other filters. Subs with +// an empty filter are skipped; an all-empty OrGroup is a no-op. It is the +// substrate behind the generated Query.Or combinator. +func (qb *Query[T]) OrGroup(subs ...*Query[T]) *Query[T] { + parts := make([]string, 0, len(subs)) + var params []any + for _, s := range subs { + e, p := s.CombinedFilter() + if e == "" { + continue + } + parts = append(parts, "("+shiftPlaceholders(e, len(params))+")") + params = append(params, p...) + } + if len(parts) == 0 { + return qb + } + qb.addFilter("("+strings.Join(parts, " OR ")+")", params) + return qb +} + +// OrderAsc orders results ascending by clause. +func (qb *Query[T]) OrderAsc(clause string) *Query[T] { + qb.q.OrderAsc(clause) + return qb +} + +// OrderDesc orders results descending by clause. +func (qb *Query[T]) OrderDesc(clause string) *Query[T] { + qb.q.OrderDesc(clause) + return qb +} + +// Limit caps the number of results. dgman names this First; it is renamed +// here so it does not collide with the First terminal. +func (qb *Query[T]) Limit(n int) *Query[T] { + qb.limit = n + qb.q.First(n) + return qb +} + +// Offset skips the first n results. +func (qb *Query[T]) Offset(n int) *Query[T] { + qb.offset = n + qb.q.Offset(n) + return qb +} + +// After returns results with UID greater than uid (cursor pagination). +func (qb *Query[T]) After(uid string) *Query[T] { + qb.q.After(uid) + return qb +} + +// Cascade drops nodes missing any of the given predicates (all, if none given). +func (qb *Query[T]) Cascade(predicates ...string) *Query[T] { + qb.q.Cascade(predicates...) + return qb +} + +// RootFunc overrides the query root function. dgman's default root function +// is type(); RootFunc replaces it with an expression such as +// eq(name, "Alice") or has(email). Repeated calls overwrite. +func (qb *Query[T]) RootFunc(rootFunc string) *Query[T] { + qb.q.RootFunc(rootFunc) + return qb +} + +// Name sets the query block name. It defaults to "data"; dgman uses the name +// to both generate and decode the query, so a renamed block still decodes +// into []T. Repeated calls overwrite. +func (qb *Query[T]) Name(queryName string) *Query[T] { + qb.q.Name(queryName) + return qb +} + +// Vars supplies GraphQL variables for a parameterized query: funcDef is the +// query function definition (for example "getByName($n: string)") and vars +// binds each variable. The query then executes via dgraph's QueryWithVars +// path. Repeated calls overwrite. +func (qb *Query[T]) Vars(funcDef string, vars map[string]string) *Query[T] { + qb.q.Vars(funcDef, vars) + return qb +} + +// WhereEdge constrains results to records that have at least one `predicate` +// edge whose target node satisfies the dgraph @filter expression. params bind +// to $N placeholders within filter, exactly as Filter binds them. +// +// Where Filter constrains T's own scalar predicates, WhereEdge constrains a +// neighbouring node reached over an edge. dgraph's root @filter cannot express +// that, so a query carrying WhereEdge constraints executes in two steps: a +// pre-pass resolves the UIDs of roots that satisfy every constraint, then the +// main query runs against uid(...) — keeping ordering, pagination, and result +// projection on the normal path. See +// docs/specs/2026-05-21-query-edge-filter-design.md. +// +// WhereEdge accumulates: multiple calls AND together (a record must satisfy +// every edge constraint). It is the substrate behind the generated +// Query.Where methods. +func (qb *Query[T]) WhereEdge(predicate, filter string, params ...any) *Query[T] { + qb.edges = append(qb.edges, edgeFilter{predicate: predicate, filter: filter, params: params}) + return qb +} + +// WhereAnyOfText adds an @filter(anyoftext(predicate, $1)) clause. It +// accumulates and ANDs with other filters like Filter. +func (qb *Query[T]) WhereAnyOfText(predicate, term string) *Query[T] { + qb.addFilter(fmt.Sprintf("anyoftext(%s, $1)", predicate), []any{term}) + return qb +} + +// WhereAllOfText adds an @filter(alloftext(predicate, $1)) clause. It +// accumulates and ANDs with other filters like Filter. +func (qb *Query[T]) WhereAllOfText(predicate, term string) *Query[T] { + qb.addFilter(fmt.Sprintf("alloftext(%s, $1)", predicate), []any{term}) + return qb +} + +// As names the query block as a dgraph query variable. dgraph requires such a +// variable be consumed by another block, which a single-block typed query +// cannot do, so As transitions out of the typed query: it returns a *RawQuery, +// which exposes no node terminal. +func (qb *Query[T]) As(varName string) *RawQuery { + qb.q.As(varName) + return &RawQuery{q: qb.q} +} + +// Var marks the query block as a dgraph var block. A var block computes query +// variables and returns no data of its own, so Var transitions out of the +// typed query: it returns a *RawQuery, which exposes no node terminal. +func (qb *Query[T]) Var() *RawQuery { + qb.q.Var() + return &RawQuery{q: qb.q} +} + +// GroupBy adds an @groupby(predicate) aggregation. A grouped query returns +// aggregation groups rather than a slice of T, so GroupBy transitions out of +// the typed query: it returns a *RawQuery, which exposes no node terminal. +func (qb *Query[T]) GroupBy(predicate string) *RawQuery { + qb.q.GroupBy(predicate) + return &RawQuery{q: qb.q} +} + +// Nodes executes the query and returns all matching records. +func (qb *Query[T]) Nodes() (out []T, err error) { + _, span := tracer.StartSpan(qb.ctx, "query", entityName[T]()) + defer func() { span.End(err) }() + matched, err := qb.resolveRoots() + if err != nil { + return nil, err + } + if !matched { + return nil, nil + } + if err = qb.q.Nodes(&out); err != nil { + return nil, err + } + return out, nil +} + +// First executes the query with an implicit Limit(1) and returns the first +// record, or (nil, nil) if the query matched no rows. +func (qb *Query[T]) First() (rec *T, err error) { + _, span := tracer.StartSpan(qb.ctx, "query", entityName[T]()) + defer func() { span.End(err) }() + matched, err := qb.resolveRoots() + if err != nil { + return nil, err + } + if !matched { + return nil, nil + } + var out []T + if err = qb.q.First(1).Nodes(&out); err != nil { + return nil, err + } + if len(out) == 0 { + return nil, nil + } + return &out[0], nil +} + +// IterNodes executes the query and returns an iterator over matching records, +// paging transparently so a large result set is never materialized at once. +// +// IterNodes is a terminal operation: it drives Offset/Limit internally as it +// pages and leaves the builder spent — do not call another terminal on the +// same Query afterward. A Limit set on the query caps the total number of +// rows streamed; an Offset is the starting point. +// +// All pages execute against one read-only transaction, so the iteration reads +// a single consistent snapshot: a concurrent writer cannot make it skip or +// repeat rows. A WhereEdge pre-pass, when present, runs once before paging +// begins, in its own transaction. On error it yields a final (nil, err) and +// stops. +func (qb *Query[T]) IterNodes() iter.Seq2[*T, error] { + return func(yield func(*T, error) bool) { + _, span := tracer.StartSpan(qb.ctx, "query", entityName[T]()) + var ferr error + defer func() { span.End(ferr) }() + matched, err := qb.resolveRoots() + if err != nil { + ferr = err + yield(nil, err) + return + } + if !matched { + return // edge constraints present, but no root matched + } + remaining := qb.limit // 0 = unbounded + for off := qb.offset; ; off += defaultPageSize { + size := defaultPageSize + if remaining > 0 && remaining < size { + size = remaining // shrink the last page so it can't overshoot the cap + } + var page []T + if err := qb.q.Offset(off).First(size).Nodes(&page); err != nil { + ferr = err + yield(nil, err) + return + } + for i := range page { + if !yield(&page[i], nil) { + return // consumer broke out + } + } + if remaining > 0 { + if remaining -= len(page); remaining <= 0 { + return // hit the caller's Limit + } + } + if len(page) < size { + return // result set exhausted + } + } + } +} + +// Raw returns the underlying dgman query for operations Query does not wrap +// (for example the raw-selection Query method). Raw does not carry WhereEdge +// constraints — those are resolved only when a terminal runs. +func (qb *Query[T]) Raw() *dg.Query { + return qb.q +} + +// UID roots the query at a specific node UID. Results still decode into []T. +func (qb *Query[T]) UID(uid string) *Query[T] { + qb.q.UID(uid) + return qb +} + +// All sets the edge-traversal depth for this query, overriding the client's +// default maxEdgeTraversal. Use a small depth to stay under Dgraph's 4MB gRPC +// limit on highly-connected entities. +func (qb *Query[T]) All(depth int) *Query[T] { + qb.q.All(depth) + return qb +} + +// NodesAndCount executes the query and returns the matching records together +// with the total count (useful for pagination totals). Like Nodes, it runs the +// WhereEdge pre-pass first when edge constraints are present. +func (qb *Query[T]) NodesAndCount() ([]T, int, error) { + matched, err := qb.resolveRoots() + if err != nil { + return nil, 0, err + } + if !matched { + return nil, 0, nil + } + var out []T + count, err := qb.q.NodesAndCount(&out) + if err != nil { + return nil, 0, err + } + return out, count, nil +} + +// String renders the generated DQL without executing it. WhereEdge constraints +// are not reflected — they are resolved only when a terminal runs. +func (qb *Query[T]) String() string { + return qb.q.String() +} + +// FormatBlock renders the query as a single DQL block named name, without +// executing it. The returned text is suitable for inclusion inside a wrapping +// "{ ... }" multi-block request — it does not include outer braces. +// +// FormatBlock is the substrate behind MultiQuery; external callers can use it +// to compose typed queries into larger hand-written DQL requests. +// +// Filter parameters are inlined at Filter-call time (dgman renders $N +// placeholders into the filter string immediately), so the returned block +// carries no unresolved variables. WhereEdge constraints are not formatted — +// they require a runtime pre-pass and would produce no useful output here. +func (qb *Query[T]) FormatBlock(name string) (string, error) { + if len(qb.edges) != 0 { + return "", fmt.Errorf("typed: FormatBlock cannot render a Query carrying WhereEdge constraints") + } + qb.q.Name(name) + wrapped := dg.NewQueryBlock(qb.q).String() + // QueryBlock.String() wraps the block in "{\n ... }" — strip the wrapper so + // the caller can compose blocks inside their own braces. + inner := strings.TrimPrefix(wrapped, "{\n") + inner = strings.TrimSuffix(inner, "}") + return inner, nil +} + +// RawQuery is a query whose result is not a slice of T — produced by the +// shape-changing builders Query.As, Query.Var, and Query.GroupBy. A RawQuery +// deliberately exposes no typed node terminal: its result must be decoded by +// the caller through the underlying dgman query, obtained via Raw. +type RawQuery struct { + q *dg.Query +} + +// Raw returns the underlying dgman query, for the caller to execute and decode. +func (r *RawQuery) Raw() *dg.Query { + return r.q +} + +// String returns the generated DQL. +func (r *RawQuery) String() string { + return r.q.String() +} + +// As names the block as a dgraph query variable. See Query.As. +func (r *RawQuery) As(varName string) *RawQuery { + r.q.As(varName) + return r +} + +// Var marks the block as a dgraph var block. See Query.Var. +func (r *RawQuery) Var() *RawQuery { + r.q.Var() + return r +} + +// GroupBy adds an @groupby(predicate) aggregation. See Query.GroupBy. +func (r *RawQuery) GroupBy(predicate string) *RawQuery { + r.q.GroupBy(predicate) + return r +} + +// resolveRoots runs the WhereEdge pre-pass when the query carries edge +// constraints, rewriting the main query's root function to the matching UIDs. +// It returns matched=false when constraints are present but no root satisfied +// them — callers then return an empty result without running the main query. +// With no edge constraints it is a no-op returning matched=true. +func (qb *Query[T]) resolveRoots() (matched bool, err error) { + if len(qb.edges) == 0 { + return true, nil + } + uids, err := qb.matchedUIDs() + if err != nil { + return false, err + } + if len(uids) == 0 { + return false, nil + } + qb.q.RootFunc("uid(" + strings.Join(uids, ", ") + ")") + return true, nil +} + +// matchedUIDs runs the pre-pass: an @cascade query over T that keeps only +// nodes whose every WhereEdge predicate has a target matching its filter, and +// returns those nodes' UIDs. +func (qb *Query[T]) matchedUIDs() ([]string, error) { + var z T + pre := qb.conn.Query(qb.ctx, &z) + body, params := qb.edgeMatchBody() + pre.Cascade().Query(body, params...) + + var rows []struct { + UID string `json:"uid"` + } + if err := pre.Nodes(&rows); err != nil { + return nil, err + } + uids := make([]string, len(rows)) + for i := range rows { + uids[i] = rows[i].UID + } + return uids, nil +} + +// edgeMatchBody renders the selection set for the pre-pass: uid plus one +// aliased, filtered block per WhereEdge constraint. The caller adds a bare +// @cascade, which then drops any node with an empty block — so a survivor +// satisfies every constraint. Blocks are aliased mg_e0, mg_e1, ... so two +// constraints on the same predicate do not collide as duplicate fields. Each +// fragment's $N placeholders are shifted to stay bound to its own params once +// every fragment's params are concatenated into one slice. +func (qb *Query[T]) edgeMatchBody() (body string, params []any) { + var b strings.Builder + b.WriteString("{\n\tuid\n") + for i, e := range qb.edges { + b.WriteString("\tmg_e") + b.WriteString(strconv.Itoa(i)) + b.WriteString(" : ") + b.WriteString(e.predicate) + b.WriteString(" @filter(") + b.WriteString(shiftPlaceholders(e.filter, len(params))) + b.WriteString(") { uid }\n") + params = append(params, e.params...) + } + b.WriteString("}") + return b.String(), params +} + +// shiftPlaceholders rewrites dgman ordinal placeholders ($1, $2, ...) in expr, +// adding delta to each index. WhereEdge filters are written independently, each +// numbering its params from $1; concatenating them into one pre-pass body +// needs every fragment renumbered against the combined params slice. A '$' not +// followed by a digit is left as-is, matching dgman's parseQueryWithParams. +func shiftPlaceholders(expr string, delta int) string { + if delta == 0 || !strings.ContainsRune(expr, '$') { + return expr + } + var b strings.Builder + for i := 0; i < len(expr); i++ { + if expr[i] != '$' { + b.WriteByte(expr[i]) + continue + } + j := i + 1 + for j < len(expr) && expr[j] >= '0' && expr[j] <= '9' { + j++ + } + if j == i+1 { // '$' not followed by digits — leave verbatim + b.WriteByte('$') + continue + } + n, _ := strconv.Atoi(expr[i+1 : j]) + b.WriteByte('$') + b.WriteString(strconv.Itoa(n + delta)) + i = j - 1 + } + return b.String() +} diff --git a/typed/query_test.go b/typed/query_test.go new file mode 100644 index 0000000..588bf6b --- /dev/null +++ b/typed/query_test.go @@ -0,0 +1,1294 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package typed_test + +import ( + "context" + "strings" + "testing" + + dg "github.com/dolan-in/dgman/v2" + "github.com/go-logr/logr/funcr" + "github.com/matthewmcneely/modusgraph" + "github.com/matthewmcneely/modusgraph/typed" +) + +// newCountingConn builds a file-backed modusgraph client exactly like newConn, +// but wires in a logr.Logger that counts dgman query executions. dgman logs +// every executed query at verbosity 3 with the message "execute query"; the +// returned *int is incremented once per such log line. +// +// dgman's logger is process-global, and modusgraph allows only one live +// file-backed engine per process (see modusgraph.ErrSingletonOnly). Each call +// uses a fresh t.TempDir() URI for data isolation. Tests that use +// newCountingConn must NOT call t.Parallel(): a second live client would hit +// the engine singleton, and parallel tests would also corrupt the shared +// query count. +func newCountingConn(t *testing.T, count *int) modusgraph.Client { + t.Helper() + logger := funcr.New(func(_, args string) { + // funcr renders the message into args as `"msg"="execute query"`. + // Match that exact pair so unrelated dgman/pool log lines (which log + // other messages, e.g. "executeQuery" for query blocks) are ignored. + if strings.Contains(args, `"msg"="execute query"`) { + *count++ + } + }, funcr.Options{Verbosity: 3}) + conn, err := modusgraph.NewClient("file://"+t.TempDir(), + modusgraph.WithAutoSchema(true), modusgraph.WithLogger(logger)) + if err != nil { + t.Fatalf("modusgraph.NewClient: %v", err) + } + t.Cleanup(conn.Close) + return conn +} + +func TestQuery_NodesReturnsAll(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + for _, n := range []string{"a", "b", "c"} { + if err := c.Add(ctx, &widget{Name: n}); err != nil { + t.Fatalf("Add %s: %v", n, err) + } + } + + got, err := c.Query(ctx).Nodes() + if err != nil { + t.Fatalf("Nodes: %v", err) + } + if len(got) != 3 { + t.Fatalf("Nodes returned %d records, want 3", len(got)) + } +} + +func TestQuery_LimitCapsResults(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + for i := range 5 { + if err := c.Add(ctx, &widget{Name: "w", Qty: i}); err != nil { + t.Fatalf("Add %d: %v", i, err) + } + } + + got, err := c.Query(ctx).Limit(2).Nodes() + if err != nil { + t.Fatalf("Nodes: %v", err) + } + if len(got) != 2 { + t.Fatalf("Limit(2) returned %d records, want 2", len(got)) + } +} + +func TestQuery_FirstReturnsAMatch(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + if err := c.Add(ctx, &widget{Name: "only", Qty: 7}); err != nil { + t.Fatalf("Add: %v", err) + } + + got, err := c.Query(ctx).First() + if err != nil { + t.Fatalf("First: %v", err) + } + if got == nil || got.Name != "only" { + t.Fatalf("First returned %+v, want Name=only", got) + } +} + +func TestQuery_FirstNoMatchReturnsNilNil(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + got, err := c.Query(ctx).First() + if err != nil { + t.Fatalf("First on empty: unexpected error %v", err) + } + if got != nil { + t.Fatalf("First on empty returned %+v, want nil", got) + } +} + +func TestQuery_BuilderChainCompilesAndRuns(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + if err := c.Add(ctx, &widget{Name: "x", Qty: 1}); err != nil { + t.Fatalf("Add: %v", err) + } + + // Every builder method must return *Query[widget] so the chain stays typed. + _, err := c.Query(ctx). + OrderAsc("qty"). + Offset(0). + Limit(10). + Cascade(). + Nodes() + if err != nil { + t.Fatalf("builder chain Nodes: %v", err) + } +} + +func TestQuery_RawExposesUnderlyingBuilder(t *testing.T) { + c := typed.NewClient[widget](newConn(t)) + if c.Query(context.Background()).Raw() == nil { + t.Fatal("Raw() returned nil; expected the underlying *dg.Query") + } +} + +func TestQuery_Filter(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + // Insert three widgets with distinct names. + for _, name := range []string{"alpha", "beta", "gamma"} { + if err := c.Add(ctx, &widget{Name: name}); err != nil { + t.Fatalf("Add %s: %v", name, err) + } + } + + // Filter to exactly those whose name equals "beta" (index=exact allows eq()). + got, err := c.Query(ctx).Filter(`eq(name, "beta")`).Nodes() + if err != nil { + t.Fatalf("Filter Nodes: %v", err) + } + if len(got) != 1 { + t.Fatalf("Filter returned %d records, want 1", len(got)) + } + if got[0].Name != "beta" { + t.Fatalf("Filter returned Name=%q, want beta", got[0].Name) + } +} + +func TestQuery_FilterAccumulatesWithAnd(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + // Three widgets; only "beta"/9 satisfies BOTH name=="beta" and qty>=5. + for _, w := range []widget{ + {Name: "alpha", Qty: 9}, + {Name: "beta", Qty: 9}, + {Name: "beta", Qty: 1}, + } { + if err := c.Add(ctx, &w); err != nil { + t.Fatalf("Add %+v: %v", w, err) + } + } + + // Two Filter calls must AND together, not overwrite. With last-write-wins + // only ge(qty, 5) survives and this returns the two qty>=5 rows instead of + // the single AND match. + got, err := c.Query(ctx). + Filter(`eq(name, "beta")`). + Filter(`ge(qty, "5")`). + Nodes() + if err != nil { + t.Fatalf("Filter Nodes: %v", err) + } + if len(got) != 1 { + t.Fatalf("two ANDed Filters returned %d records, want 1 (name==beta AND qty>=5)", len(got)) + } + if got[0].Name != "beta" || got[0].Qty != 9 { + t.Fatalf("got %+v, want Name=beta Qty=9", got[0]) + } +} + +func TestQuery_CombinedFilterShiftsPlaceholders(t *testing.T) { + q := typed.NewDetachedQuery[widget]() + if expr, params := q.CombinedFilter(); expr != "" || params != nil { + t.Fatalf("empty CombinedFilter = (%q, %v), want (\"\", nil)", expr, params) + } + q.Filter("eq(name, $1)", "a") + q.Filter("eq(qty, $1)", 7) + expr, params := q.CombinedFilter() + const want = "eq(name, $1) AND eq(qty, $2)" + if expr != want { + t.Fatalf("CombinedFilter expr = %q, want %q", expr, want) + } + if len(params) != 2 || params[0] != "a" || params[1] != 7 { + t.Fatalf("CombinedFilter params = %v, want [a 7]", params) + } +} + +func TestQuery_OrGroup(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + for _, w := range []widget{ + {Name: "alpha", Qty: 9}, + {Name: "beta", Qty: 9}, + {Name: "gamma", Qty: 1}, + } { + if err := c.Add(ctx, &w); err != nil { + t.Fatalf("Add %+v: %v", w, err) + } + } + + // name == "alpha" OR name == "gamma": two of three rows. + got, err := c.Query(ctx).OrGroup( + typed.NewDetachedQuery[widget]().Filter(`eq(name, "alpha")`), + typed.NewDetachedQuery[widget]().Filter(`eq(name, "gamma")`), + ).Nodes() + if err != nil { + t.Fatalf("OrGroup Nodes: %v", err) + } + if len(got) != 2 { + t.Fatalf("OrGroup(alpha, gamma) returned %d rows, want 2", len(got)) + } + + // AND-of-OR: qty>=5 AND (name==alpha OR name==gamma) → only alpha/9. + got, err = c.Query(ctx). + Filter(`ge(qty, "5")`). + OrGroup( + typed.NewDetachedQuery[widget]().Filter(`eq(name, "alpha")`), + typed.NewDetachedQuery[widget]().Filter(`eq(name, "gamma")`), + ).Nodes() + if err != nil { + t.Fatalf("AND-of-OR Nodes: %v", err) + } + if len(got) != 1 || got[0].Name != "alpha" { + t.Fatalf("qty>=5 AND (alpha OR gamma) returned %+v, want [alpha/9]", got) + } +} + +func TestQuery_OrderAscDesc(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + // Insert widgets with distinct Qty values in non-sorted order so a + // stable natural ordering cannot hide a missing sort. + qtys := []int{30, 10, 50, 20, 40} + for i, q := range qtys { + if err := c.Add(ctx, &widget{Name: "w", Qty: q}); err != nil { + t.Fatalf("Add widget[%d]: %v", i, err) + } + } + + // Ascending. + asc, err := c.Query(ctx).OrderAsc("qty").Nodes() + if err != nil { + t.Fatalf("OrderAsc Nodes: %v", err) + } + if len(asc) != len(qtys) { + t.Fatalf("OrderAsc returned %d records, want %d", len(asc), len(qtys)) + } + for i := range len(asc) - 1 { + if asc[i].Qty > asc[i+1].Qty { + t.Fatalf("OrderAsc: asc[%d].Qty=%d > asc[%d].Qty=%d; not ascending", + i, asc[i].Qty, i+1, asc[i+1].Qty) + } + } + + // Descending. + desc, err := c.Query(ctx).OrderDesc("qty").Nodes() + if err != nil { + t.Fatalf("OrderDesc Nodes: %v", err) + } + if len(desc) != len(qtys) { + t.Fatalf("OrderDesc returned %d records, want %d", len(desc), len(qtys)) + } + for i := range len(desc) - 1 { + if desc[i].Qty < desc[i+1].Qty { + t.Fatalf("OrderDesc: desc[%d].Qty=%d < desc[%d].Qty=%d; not descending", + i, desc[i].Qty, i+1, desc[i+1].Qty) + } + } +} + +func TestQuery_OffsetSkipsResults(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + // Five widgets with distinct, deliberately unsorted Qty values. + qtys := []int{40, 10, 50, 20, 30} + for i, q := range qtys { + if err := c.Add(ctx, &widget{Name: "w", Qty: q}); err != nil { + t.Fatalf("Add widget[%d]: %v", i, err) + } + } + + // Ordering ascending by qty gives 10,20,30,40,50; Offset(2) drops the + // first two, so 3 rows remain and the first is the 3rd-smallest (30). + got, err := c.Query(ctx).OrderAsc("qty").Offset(2).Nodes() + if err != nil { + t.Fatalf("Offset Nodes: %v", err) + } + if len(got) != 3 { + t.Fatalf("OrderAsc.Offset(2) returned %d records, want 3", len(got)) + } + if got[0].Qty != 30 { + t.Fatalf("first row after Offset(2) has Qty=%d, want 30 (3rd-smallest)", got[0].Qty) + } +} + +func TestQuery_AfterCursor(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + for i := range 5 { + if err := c.Add(ctx, &widget{Name: "w", Qty: i}); err != nil { + t.Fatalf("Add %d: %v", i, err) + } + } + + // First pass: grab all rows so we can pick a non-last cursor UID. + all, err := c.Query(ctx).Nodes() + if err != nil { + t.Fatalf("first Nodes: %v", err) + } + if len(all) < 3 { + t.Fatalf("expected at least 3 widgets, got %d", len(all)) + } + cursor := all[1].UID // a non-last row + + // After(cursor) uses default UID ordering to skip past the cursor node. + got, err := c.Query(ctx).After(cursor).Nodes() + if err != nil { + t.Fatalf("After Nodes: %v", err) + } + if len(got) == 0 { + t.Fatal("After(cursor) returned no rows; expected the rows past the cursor") + } + for _, w := range got { + if w.UID <= cursor { + t.Fatalf("After(%s) returned UID %s, which is not strictly greater than the cursor", + cursor, w.UID) + } + } +} + +func TestQuery_CascadeDropsIncompleteNodes(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + // Widgets with Qty > 0 carry a qty predicate. Widgets with Qty left 0 + // have it omitted entirely (json tag is omitempty), so they have no qty + // predicate at all. + withQty := []int{5, 9, 13} + for _, q := range withQty { + if err := c.Add(ctx, &widget{Name: "has-qty", Qty: q}); err != nil { + t.Fatalf("Add qty=%d: %v", q, err) + } + } + for i := range 4 { + if err := c.Add(ctx, &widget{Name: "no-qty"}); err != nil { + t.Fatalf("Add no-qty[%d]: %v", i, err) + } + } + + // @cascade(qty) drops any node that lacks the qty predicate. + got, err := c.Query(ctx).Cascade("qty").Nodes() + if err != nil { + t.Fatalf("Cascade Nodes: %v", err) + } + if len(got) != len(withQty) { + t.Fatalf("Cascade(qty) returned %d records, want %d (only the qty-bearing widgets)", + len(got), len(withQty)) + } + for _, w := range got { + if w.Qty == 0 { + t.Fatalf("Cascade(qty) returned a widget with Qty=0 (no qty predicate): %+v", w) + } + } +} + +func TestQuery_FilterOrderLimitOffsetCombined(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + // A known set: five "keep" widgets plus a "drop" widget the filter excludes. + for _, q := range []int{50, 20, 40, 10, 30} { + if err := c.Add(ctx, &widget{Name: "keep", Qty: q}); err != nil { + t.Fatalf("Add keep qty=%d: %v", q, err) + } + } + if err := c.Add(ctx, &widget{Name: "drop", Qty: 99}); err != nil { + t.Fatalf("Add drop: %v", err) + } + + // Filter to name=keep -> qtys {10,20,30,40,50}; OrderAsc -> sorted; + // Offset(1) drops 10; Limit(2) keeps {20,30}. + got, err := c.Query(ctx). + Filter(`eq(name, "keep")`). + OrderAsc("qty"). + Offset(1). + Limit(2). + Nodes() + if err != nil { + t.Fatalf("combined chain Nodes: %v", err) + } + if len(got) != 2 { + t.Fatalf("combined chain returned %d records, want 2", len(got)) + } + if got[0].Qty != 20 || got[1].Qty != 30 { + t.Fatalf("combined chain window = [%d, %d], want [20, 30]", got[0].Qty, got[1].Qty) + } +} + +func TestQuery_FirstOnMultipleRows(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + for _, q := range []int{30, 10, 20} { + if err := c.Add(ctx, &widget{Name: "w", Qty: q}); err != nil { + t.Fatalf("Add qty=%d: %v", q, err) + } + } + + // First on an ascending-by-qty query yields exactly the smallest row. + got, err := c.Query(ctx).OrderAsc("qty").First() + if err != nil { + t.Fatalf("First: %v", err) + } + if got == nil { + t.Fatal("First returned nil on a non-empty result set") + } + if got.Qty != 10 { + t.Fatalf("First on OrderAsc(qty) returned Qty=%d, want 10 (smallest)", got.Qty) + } +} + +func TestQuery_NodesEmptyResult(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) // fresh client, no inserts + + got, err := c.Query(ctx).Nodes() + if err != nil { + t.Fatalf("Nodes on empty client: unexpected error %v", err) + } + if len(got) != 0 { + t.Fatalf("Nodes on empty client returned %d records, want 0", len(got)) + } +} + +func TestQuery_OrderAccumulates(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + // OrderAsc and OrderDesc accumulate: both clauses must survive on the + // same query. dgman renders them as "orderasc:"/"orderdesc:" in the + // generated query string. + q := c.Query(ctx).OrderAsc("name").OrderDesc("qty") + s := q.Raw().String() + if !strings.Contains(s, "orderasc: name") { + t.Fatalf("query string missing ascending name order; got:\n%s", s) + } + if !strings.Contains(s, "orderdesc: qty") { + t.Fatalf("query string missing descending qty order; got:\n%s", s) + } +} + +func TestQuery_CascadeOverwrites(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + // Cascade overwrites: the second call wins, the first predicate is gone. + // dgman renders predicates as @cascade(pred1,pred2,...) with no spaces. + q := c.Query(ctx).Cascade("name").Cascade("qty") + s := q.Raw().String() + if !strings.Contains(s, "@cascade(qty)") { + t.Fatalf("second Cascade(qty) not rendered in query string; got:\n%s", s) + } + if strings.Contains(s, "@cascade(name)") { + t.Fatalf("first Cascade(name) still present after overwrite; got:\n%s", s) + } +} + +func TestQuery_TerminalRunsTwice(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + for _, n := range []string{"a", "b", "c"} { + if err := c.Add(ctx, &widget{Name: n}); err != nil { + t.Fatalf("Add %s: %v", n, err) + } + } + + // A terminal is re-runnable: calling Nodes twice on the same builder + // succeeds both times and yields equal-length results. + q := c.Query(ctx) + first, err := q.Nodes() + if err != nil { + t.Fatalf("first Nodes: %v", err) + } + second, err := q.Nodes() + if err != nil { + t.Fatalf("second Nodes: %v", err) + } + if len(first) != len(second) { + t.Fatalf("Nodes run twice returned %d then %d records; want equal lengths", + len(first), len(second)) + } +} + +func TestQuery_BuilderAliasesAndAccumulates(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + // (i) Filter accumulates: after two Filter calls both survive, ANDed. + q := c.Query(ctx) + q.Filter(`eq(name, "alpha")`) + q.Filter(`eq(name, "beta")`) + s := q.Raw().String() + if !strings.Contains(s, `eq(name, "alpha")`) { + t.Fatalf("Filter A dropped; want both fragments present in:\n%s", s) + } + if !strings.Contains(s, `eq(name, "beta")`) { + t.Fatalf("Filter B dropped; want both fragments present in:\n%s", s) + } + if !strings.Contains(s, " AND ") { + t.Fatalf("accumulated filters not ANDed; got:\n%s", s) + } + + // (ii) The builder aliases: a saved reference and further mutation observe + // the same underlying query. ref and q point at the same *Query, so a + // mutation through one is visible through the other. This documents the + // single-use footgun: you cannot branch a saved builder. + ref := q + if ref != q { + t.Fatal("builder reference is not identical to the original *Query") + } + q.OrderAsc("name") + if ref.Raw().String() != q.Raw().String() { + t.Fatal("mutating q did not affect ref; builder is expected to alias a shared query") + } + if !strings.Contains(ref.Raw().String(), "orderasc: name") { + t.Fatalf("order applied via q not visible through ref; got:\n%s", ref.Raw().String()) + } +} + +func TestQuery_RawRoundTrips(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + if err := c.Add(ctx, &widget{Name: "raw-target", Qty: 7}); err != nil { + t.Fatalf("Add: %v", err) + } + + // Take the raw *dg.Query, apply a dgman-only builder method directly, + // then execute via the raw query's own Nodes(&dst). + var raw *dg.Query = c.Query(ctx).Raw() + raw.OrderAsc("qty") + + var dst []widget + if err := raw.Nodes(&dst); err != nil { + t.Fatalf("raw query Nodes: %v", err) + } + if len(dst) != 1 { + t.Fatalf("raw query returned %d records, want 1", len(dst)) + } + if dst[0].Name != "raw-target" || dst[0].Qty != 7 { + t.Fatalf("raw query returned %+v, want Name=raw-target Qty=7", dst[0]) + } +} + +func TestQuery_SingleQueryPerTerminal(t *testing.T) { + // Uses the global dgman logger; must not run in parallel. + ctx := context.Background() + // queriesExecuted is incremented by newCountingConn's logger each time + // dgman runs a query, so it reflects real database round-trips. + var queriesExecuted int + c := typed.NewClient[widget](newCountingConn(t, &queriesExecuted)) + + for i := range 2 { + if err := c.Add(ctx, &widget{Name: "w", Qty: i}); err != nil { + t.Fatalf("Add %d: %v", i, err) + } + } + + // Building the chain runs no queries: builder methods only mutate the AST. + before := queriesExecuted + q := c.Query(ctx).Filter(`eq(name, "w")`).OrderAsc("qty").Limit(10) + if queriesExecuted != before { + t.Fatalf("builder methods executed %d queries, want 0", queriesExecuted-before) + } + + // The Nodes terminal runs exactly one query. + if _, err := q.Nodes(); err != nil { + t.Fatalf("Nodes: %v", err) + } + if got := queriesExecuted - before; got != 1 { + t.Fatalf("Nodes executed %d queries, want exactly 1", got) + } + + // A fresh builder's First terminal also runs exactly one query. + before = queriesExecuted + if _, err := c.Query(ctx).First(); err != nil { + t.Fatalf("First: %v", err) + } + if got := queriesExecuted - before; got != 1 { + t.Fatalf("First executed %d queries, want exactly 1", got) + } +} + +func TestIterNodes_StreamsAll(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + const n = 125 // > defaultPageSize (50): forces multiple pages + for i := range n { + if err := c.Add(ctx, &widget{Name: "w", Qty: i}); err != nil { + t.Fatalf("Add %d: %v", i, err) + } + } + seen := 0 + for w, err := range c.Query(ctx).IterNodes() { + if err != nil { + t.Fatalf("IterNodes yielded error: %v", err) + } + if w == nil { + t.Fatal("IterNodes yielded a nil widget") + } + seen++ + } + if seen != n { + t.Fatalf("IterNodes streamed %d records, want %d", seen, n) + } +} + +func TestIterNodes_StopsOnConsumerBreak(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + const n = 125 + for i := range n { + if err := c.Add(ctx, &widget{Name: "w", Qty: i}); err != nil { + t.Fatalf("Add %d: %v", i, err) + } + } + seen := 0 + for _, err := range c.Query(ctx).IterNodes() { + if err != nil { + t.Fatalf("IterNodes yielded error: %v", err) + } + seen++ + if seen == 10 { + break + } + } + if seen != 10 { + t.Fatalf("IterNodes yielded %d records after break at 10, want 10", seen) + } +} + +func TestIterNodes_EmptyResult(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + seen := 0 + for _, err := range c.Query(ctx).IterNodes() { + if err != nil { + t.Fatalf("IterNodes over empty set yielded error: %v", err) + } + seen++ + } + if seen != 0 { + t.Fatalf("IterNodes over empty set yielded %d records, want 0", seen) + } +} + +func TestIterNodes_RespectsLimit(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + const n = 100 + for i := range n { + if err := c.Add(ctx, &widget{Name: "w", Qty: i}); err != nil { + t.Fatalf("Add %d: %v", i, err) + } + } + seen := 0 + for _, err := range c.Query(ctx).Limit(30).IterNodes() { + if err != nil { + t.Fatalf("IterNodes yielded error: %v", err) + } + seen++ + } + if seen != 30 { + t.Fatalf("Limit(30).IterNodes() streamed %d records, want 30", seen) + } +} + +func TestIterNodes_LimitExceedsResultSet(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + const n = 30 + for i := range n { + if err := c.Add(ctx, &widget{Name: "w", Qty: i}); err != nil { + t.Fatalf("Add %d: %v", i, err) + } + } + seen := 0 + for _, err := range c.Query(ctx).Limit(500).IterNodes() { + if err != nil { + t.Fatalf("IterNodes yielded error: %v", err) + } + seen++ + } + if seen != n { + t.Fatalf("Limit(500).IterNodes() over %d records streamed %d, want %d", n, seen, n) + } +} + +func TestIterNodes_RespectsOffset(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + // Qty values start at 1 (not 0) so omitempty never suppresses the field, + // keeping OrderAsc("qty") a true total order over all records. + const n = 10 + for i := range n { + if err := c.Add(ctx, &widget{Name: "w", Qty: i + 1}); err != nil { + t.Fatalf("Add %d: %v", i, err) + } + } + var got []int + for w, err := range c.Query(ctx).OrderAsc("qty").Offset(3).IterNodes() { + if err != nil { + t.Fatalf("IterNodes yielded error: %v", err) + } + got = append(got, w.Qty) + } + if len(got) != 7 { + t.Fatalf("Offset(3).IterNodes() streamed %d records, want 7", len(got)) + } + for i, q := range got { + if q != i+4 { // Qty=1..10; offset 3 skips 1,2,3 → starts at 4 + t.Fatalf("Offset(3).IterNodes()[%d] Qty = %d, want %d", i, q, i+4) + } + } +} + +func TestIterNodes_RespectsOffsetAndLimit(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + // Qty values start at 1 so omitempty never suppresses the field and + // OrderAsc("qty") is a strict total order across all 200 records. + const n = 200 + for i := range n { + if err := c.Add(ctx, &widget{Name: "w", Qty: i + 1}); err != nil { + t.Fatalf("Add %d: %v", i, err) + } + } + var got []int + for w, err := range c.Query(ctx).OrderAsc("qty").Offset(60).Limit(120).IterNodes() { + if err != nil { + t.Fatalf("IterNodes yielded error: %v", err) + } + got = append(got, w.Qty) + } + if len(got) != 120 { + t.Fatalf("Offset(60).Limit(120).IterNodes() streamed %d records, want 120", len(got)) + } + for i, q := range got { + if q != i+61 { // Qty=1..200; offset 60 skips 1..60 → starts at 61 + t.Fatalf("result[%d] Qty = %d, want %d", i, q, i+61) + } + } +} + +func TestIterNodes_OneQueryPerPage(t *testing.T) { + ctx := context.Background() + var queriesExecuted int + c := typed.NewClient[widget](newCountingConn(t, &queriesExecuted)) + const n = 125 // ceil(125/50) = 3 page queries + for i := range n { + if err := c.Add(ctx, &widget{Name: "w", Qty: i}); err != nil { + t.Fatalf("Add %d: %v", i, err) + } + } + // Obtaining the iterator runs no query — IterNodes is lazy. + seq := c.Query(ctx).IterNodes() + if queriesExecuted != 0 { + t.Fatalf("building the IterNodes iterator executed %d queries, want 0", queriesExecuted) + } + seen := 0 + for _, err := range seq { + if err != nil { + t.Fatalf("IterNodes yielded error: %v", err) + } + seen++ + } + if seen != n { + t.Fatalf("IterNodes streamed %d records, want %d", seen, n) + } + if queriesExecuted != 3 { + t.Fatalf("IterNodes over %d records ran %d queries, want 3", n, queriesExecuted) + } +} + +func TestIterNodes_YieldsErrorAndStops(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + if err := c.Add(ctx, &widget{Name: "w", Qty: 1}); err != nil { + t.Fatalf("Add: %v", err) + } + // A syntactically invalid @filter (unbalanced parenthesis) makes the page + // query fail at execution; IterNodes must yield one (nil, err) and stop. + gotErr := false + for w, err := range c.Query(ctx).Filter(`eq(name, "w"`).IterNodes() { + if err != nil { + gotErr = true + if w != nil { + t.Fatalf("error yield carried a non-nil widget: %+v", w) + } + break + } + t.Fatal("IterNodes over a malformed query yielded a record before erroring") + } + if !gotErr { + t.Fatal("IterNodes over a malformed query did not yield an error") + } +} + +func TestQuery_LimitOffsetStillDriveNodes(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + // Qty values start at 1 so omitempty never suppresses the field and + // OrderAsc("qty") is a strict total order across all records. + const n = 10 + for i := range n { + if err := c.Add(ctx, &widget{Name: "w", Qty: i + 1}); err != nil { + t.Fatalf("Add %d: %v", i, err) + } + } + // Regression: Limit/Offset now also set Query struct fields; confirm they + // still drive the Nodes terminal. + got, err := c.Query(ctx).OrderAsc("qty").Offset(2).Limit(3).Nodes() + if err != nil { + t.Fatalf("Nodes: %v", err) + } + if len(got) != 3 { + t.Fatalf("Offset(2).Limit(3).Nodes() returned %d records, want 3", len(got)) + } + for i, w := range got { + if w.Qty != i+3 { // Qty=1..10; offset 2 skips 1,2 → starts at 3 + t.Fatalf("Nodes()[%d] Qty = %d, want %d", i, w.Qty, i+3) + } + } +} + +func TestQuery_RootFuncOverridesRoot(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + for _, n := range []string{"a", "b", "c"} { + if err := c.Add(ctx, &widget{Name: n}); err != nil { + t.Fatalf("Add %s: %v", n, err) + } + } + // RootFunc replaces the default type(widget) root with an eq() lookup; + // the query still decodes into []widget. + got, err := c.Query(ctx).RootFunc(`eq(name, "b")`).Nodes() + if err != nil { + t.Fatalf("Nodes: %v", err) + } + if len(got) != 1 { + t.Fatalf(`RootFunc(eq(name,"b")).Nodes() returned %d records, want 1`, len(got)) + } + if got[0].Name != "b" { + t.Fatalf("RootFunc lookup returned %q, want \"b\"", got[0].Name) + } +} + +func TestQuery_RootFuncRendersAndOverwrites(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + // RootFunc renders into the (func: ...) position and overwrites: the + // second call wins. + q := c.Query(ctx).RootFunc(`eq(name, "x")`).RootFunc(`eq(name, "y")`) + s := q.Raw().String() + if !strings.Contains(s, `func: eq(name, "y")`) { + t.Fatalf("second RootFunc not rendered; got:\n%s", s) + } + if strings.Contains(s, `eq(name, "x")`) { + t.Fatalf("first RootFunc still present after overwrite; got:\n%s", s) + } +} + +func TestQuery_NameDecodesAfterRename(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + for _, n := range []string{"a", "b", "c"} { + if err := c.Add(ctx, &widget{Name: n}); err != nil { + t.Fatalf("Add %s: %v", n, err) + } + } + // Name renames the query block. dgman uses the name symmetrically to + // generate and decode, so a renamed block still decodes into []widget. + got, err := c.Query(ctx).Name("widgets").Nodes() + if err != nil { + t.Fatalf("Nodes: %v", err) + } + if len(got) != 3 { + t.Fatalf(`Name("widgets").Nodes() returned %d records, want 3`, len(got)) + } +} + +func TestQuery_NameRendersAndOverwrites(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + // Name renders as the block name and overwrites: the second call wins. + q := c.Query(ctx).Name("first").Name("second") + s := q.Raw().String() + if !strings.Contains(s, "second(func:") { + t.Fatalf("second Name not rendered as block name; got:\n%s", s) + } + if strings.Contains(s, "first(func:") { + t.Fatalf("first Name still present after overwrite; got:\n%s", s) + } +} + +func TestQuery_AsRendersAndOverwrites(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + // As transitions to *RawQuery, prefixes the block with " as ", + // and overwrites: the second call wins. + q := c.Query(ctx).As("first").As("second") + if q == nil { + t.Fatal("As() returned nil *RawQuery") + } + s := q.String() + if !strings.Contains(s, "second as ") { + t.Fatalf("second As not rendered; got:\n%s", s) + } + if strings.Contains(s, "first as ") { + t.Fatalf("first As still present after overwrite; got:\n%s", s) + } +} + +func TestQuery_VarsRendersQueryPrefix(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + // Vars renders a "query " prefix on the generated DQL. + q := c.Query(ctx).Vars("getByName($n: string)", map[string]string{"$n": "b"}) + s := q.Raw().String() + if !strings.Contains(s, "query getByName($n: string)") { + t.Fatalf("Vars did not render the query-definition prefix; got:\n%s", s) + } +} + +func TestQuery_VarsParameterizedQuery(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + for _, n := range []string{"a", "b", "c"} { + if err := c.Add(ctx, &widget{Name: n}); err != nil { + t.Fatalf("Add %s: %v", n, err) + } + } + // Vars supplies a GraphQL variable bound into the root function; the + // query executes via dgraph's QueryWithVars path. + got, err := c.Query(ctx). + Vars("getByName($n: string)", map[string]string{"$n": "b"}). + RootFunc("eq(name, $n)"). + Nodes() + if err != nil { + t.Fatalf("Vars query Nodes: %v", err) + } + if len(got) != 1 || got[0].Name != "b" { + t.Fatalf(`Vars parameterized query returned %+v, want one widget named "b"`, got) + } +} + +func TestQuery_VarReturnsRawQuery(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + // Var transitions to *RawQuery and emits a var block: dgman renders the + // block name as "var". + rq := c.Query(ctx).Var() + if rq == nil { + t.Fatal("Var() returned nil *RawQuery") + } + s := rq.String() + if !strings.Contains(s, "var(func:") { + t.Fatalf("Var() did not render a var block; got:\n%s", s) + } +} + +func TestQuery_GroupByReturnsRawQuery(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + // GroupBy transitions to *RawQuery and emits an @groupby clause. + rq := c.Query(ctx).GroupBy("name") + if rq == nil { + t.Fatal("GroupBy() returned nil *RawQuery") + } + s := rq.String() + if !strings.Contains(s, "@groupby(name)") { + t.Fatalf(`GroupBy("name") did not render an @groupby clause; got:\n%s`, s) + } +} + +func TestRawQuery_RawExposesUnderlyingQuery(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + rq := c.Query(ctx).Var() + // Raw returns the underlying *dg.Query; String mirrors Raw().String(). + var raw *dg.Query = rq.Raw() + if raw == nil { + t.Fatal("RawQuery.Raw() returned nil") + } + if rq.String() != raw.String() { + t.Fatalf("RawQuery.String() and Raw().String() differ:\n%s\n---\n%s", + rq.String(), raw.String()) + } +} + +func TestRawQuery_GroupByThenVarChains(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + // RawQuery re-exposes Var and GroupBy so the canonical .GroupBy(...).Var() + // composition still chains; both clauses survive. + s := c.Query(ctx).GroupBy("name").Var().String() + if !strings.Contains(s, "@groupby(name)") { + t.Fatalf("@groupby clause missing after GroupBy().Var(); got:\n%s", s) + } + if !strings.Contains(s, "var(func:") { + t.Fatalf("var block missing after GroupBy().Var(); got:\n%s", s) + } +} + +func TestRawQuery_CarriesEarlierBuilders(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + // Builders applied on *Query[T] before the GroupBy transition survive + // into the *RawQuery — the two share one underlying *dg.Query. + s := c.Query(ctx).Filter(`eq(name, "z")`).GroupBy("name").String() + if !strings.Contains(s, `eq(name, "z")`) { + t.Fatalf("Filter set before GroupBy did not survive the transition; got:\n%s", s) + } + if !strings.Contains(s, "@groupby(name)") { + t.Fatalf("@groupby clause missing; got:\n%s", s) + } +} + +// seedOwners inserts owner/pet pairs over conn for the WhereEdge tests. Each +// map entry is one owner owning one pet of the given name; the pet is inserted +// first so the owner's edge links an already-persisted node. It returns an +// owner client bound to conn. +func seedOwners(ctx context.Context, t *testing.T, conn modusgraph.Client, ownerToPet map[string]string) *typed.Client[owner] { + t.Helper() + pets := typed.NewClient[pet](conn) + owners := typed.NewClient[owner](conn) + for ownerName, petName := range ownerToPet { + p := &pet{Name: petName} + if err := pets.Add(ctx, p); err != nil { + t.Fatalf("Add pet %q: %v", petName, err) + } + if err := owners.Add(ctx, &owner{Name: ownerName, Pets: []*pet{p}}); err != nil { + t.Fatalf("Add owner %q: %v", ownerName, err) + } + } + return owners +} + +func TestQuery_WhereEdgeFiltersByEdgeTarget(t *testing.T) { + ctx := context.Background() + owners := seedOwners(ctx, t, newConn(t), map[string]string{ + "Alice": "Fido", + "Bob": "Rex", + "Carol": "Fido", + }) + + // WhereEdge constrains owners by a scalar of the pet reached over the + // "pets" edge — something a root Filter cannot express. + got, err := owners.Query(ctx).WhereEdge("pets", `eq(name, "Fido")`).Nodes() + if err != nil { + t.Fatalf("WhereEdge Nodes: %v", err) + } + if len(got) != 2 { + t.Fatalf("WhereEdge(pets, name=Fido) returned %d owners, want 2 (Alice, Carol)", len(got)) + } + for _, o := range got { + if o.Name != "Alice" && o.Name != "Carol" { + t.Fatalf("WhereEdge returned %q, want only Fido owners (Alice, Carol)", o.Name) + } + } +} + +func TestQuery_WhereEdgeNoMatchReturnsEmpty(t *testing.T) { + ctx := context.Background() + owners := seedOwners(ctx, t, newConn(t), map[string]string{"Alice": "Fido", "Bob": "Rex"}) + + // No pet is named Nemo: the pre-pass matches zero roots, so Nodes returns + // an empty result — not an error — and never runs the main query. + got, err := owners.Query(ctx).WhereEdge("pets", `eq(name, "Nemo")`).Nodes() + if err != nil { + t.Fatalf("WhereEdge Nodes: unexpected error %v", err) + } + if len(got) != 0 { + t.Fatalf("WhereEdge for an unowned pet name returned %d owners, want 0", len(got)) + } +} + +func TestQuery_WhereEdgeBindsParams(t *testing.T) { + ctx := context.Background() + owners := seedOwners(ctx, t, newConn(t), map[string]string{"Alice": "Fido", "Bob": "Rex"}) + + // The $1 placeholder in a WhereEdge filter binds exactly as it does for Filter. + got, err := owners.Query(ctx).WhereEdge("pets", "eq(name, $1)", "Rex").Nodes() + if err != nil { + t.Fatalf("WhereEdge Nodes: %v", err) + } + if len(got) != 1 || got[0].Name != "Bob" { + t.Fatalf("WhereEdge(pets, name=$1, Rex) returned %+v, want [Bob]", got) + } +} + +func TestQuery_WhereEdgeCombinesWithFilter(t *testing.T) { + ctx := context.Background() + // Alice and Carol both own a Fido; a root Filter on name narrows to Alice. + owners := seedOwners(ctx, t, newConn(t), map[string]string{ + "Alice": "Fido", + "Bob": "Rex", + "Carol": "Fido", + }) + + got, err := owners.Query(ctx). + Filter(`eq(name, "Alice")`). + WhereEdge("pets", `eq(name, "Fido")`). + Nodes() + if err != nil { + t.Fatalf("Filter+WhereEdge Nodes: %v", err) + } + if len(got) != 1 || got[0].Name != "Alice" { + t.Fatalf("Filter(name=Alice)+WhereEdge(pets,name=Fido) returned %+v, want [Alice]", got) + } +} + +func TestQuery_WhereEdgeMultipleConstraintsAnd(t *testing.T) { + ctx := context.Background() + conn := newConn(t) + pets := typed.NewClient[pet](conn) + owners := typed.NewClient[owner](conn) + + // Alice owns both Fido and Rex; Bob owns only Fido. + fido, rex := &pet{Name: "Fido"}, &pet{Name: "Rex"} + for _, p := range []*pet{fido, rex} { + if err := pets.Add(ctx, p); err != nil { + t.Fatalf("Add pet %q: %v", p.Name, err) + } + } + if err := owners.Add(ctx, &owner{Name: "Alice", Pets: []*pet{fido, rex}}); err != nil { + t.Fatalf("Add Alice: %v", err) + } + if err := owners.Add(ctx, &owner{Name: "Bob", Pets: []*pet{fido}}); err != nil { + t.Fatalf("Add Bob: %v", err) + } + + // Two WhereEdge calls AND together: only an owner of BOTH pets survives. + got, err := owners.Query(ctx). + WhereEdge("pets", `eq(name, "Fido")`). + WhereEdge("pets", `eq(name, "Rex")`). + Nodes() + if err != nil { + t.Fatalf("two-WhereEdge Nodes: %v", err) + } + if len(got) != 1 || got[0].Name != "Alice" { + t.Fatalf("WhereEdge(Fido) AND WhereEdge(Rex) returned %+v, want [Alice]", got) + } +} + +func TestQuery_WhereEdgeFirst(t *testing.T) { + ctx := context.Background() + owners := seedOwners(ctx, t, newConn(t), map[string]string{"Alice": "Fido", "Bob": "Rex"}) + + // First runs the pre-pass too: it returns the Rex owner, never a Fido one. + got, err := owners.Query(ctx).WhereEdge("pets", `eq(name, "Rex")`).First() + if err != nil { + t.Fatalf("WhereEdge First: %v", err) + } + if got == nil || got.Name != "Bob" { + t.Fatalf("WhereEdge(pets,name=Rex).First() = %+v, want Bob", got) + } + + // First with an edge constraint nothing satisfies is (nil, nil). + none, err := owners.Query(ctx).WhereEdge("pets", `eq(name, "Nemo")`).First() + if err != nil { + t.Fatalf("WhereEdge First no-match: unexpected error %v", err) + } + if none != nil { + t.Fatalf("WhereEdge First with no match = %+v, want nil", none) + } +} + +func TestQuery_WhereEdgeIterNodes(t *testing.T) { + ctx := context.Background() + owners := seedOwners(ctx, t, newConn(t), map[string]string{ + "Alice": "Fido", + "Bob": "Rex", + "Carol": "Fido", + }) + + seen := 0 + for o, err := range owners.Query(ctx).WhereEdge("pets", `eq(name, "Fido")`).IterNodes() { + if err != nil { + t.Fatalf("WhereEdge IterNodes yielded error: %v", err) + } + if o.Name != "Alice" && o.Name != "Carol" { + t.Fatalf("WhereEdge IterNodes yielded %q, want a Fido owner", o.Name) + } + seen++ + } + if seen != 2 { + t.Fatalf("WhereEdge IterNodes streamed %d owners, want 2", seen) + } +} + +func TestQuery_UIDRootsAtNode(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + w := &widget{Name: "sprocket", Qty: 3} + if err := c.Add(ctx, w); err != nil { + t.Fatalf("Add: %v", err) + } + + got, err := c.Query(ctx).UID(w.UID).Nodes() + if err != nil { + t.Fatalf("Nodes: %v", err) + } + if len(got) != 1 || got[0].Name != "sprocket" { + t.Fatalf("UID query returned %+v, want one widget named sprocket", got) + } +} + +func TestQuery_NodesAndCountReturnsTotal(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + for i := 0; i < 3; i++ { + if err := c.Add(ctx, &widget{Name: "w", Qty: i}); err != nil { + t.Fatalf("Add: %v", err) + } + } + + nodes, count, err := c.Query(ctx).NodesAndCount() + if err != nil { + t.Fatalf("NodesAndCount: %v", err) + } + if count != 3 || len(nodes) != 3 { + t.Fatalf("got count=%d len=%d, want 3 and 3", count, len(nodes)) + } +} + +func TestQuery_AllSetsTraversalDepth(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + if err := c.Add(ctx, &widget{Name: "deep", Qty: 1}); err != nil { + t.Fatalf("Add: %v", err) + } + + // All(1) overrides the default traversal depth for this query; the call + // must chain and the query must still execute and decode. + got, err := c.Query(ctx).All(1).Nodes() + if err != nil { + t.Fatalf("Nodes with All(1): %v", err) + } + if len(got) != 1 { + t.Fatalf("got %d widgets, want 1", len(got)) + } +} + +func TestQuery_StringRendersDQL(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + dql := c.Query(ctx).Filter("eq(name, $1)", "sprocket").String() + if !strings.Contains(dql, "widget") { + t.Fatalf("String() = %q, want it to mention the widget type", dql) + } +} diff --git a/typed/search/merge.go b/typed/search/merge.go new file mode 100644 index 0000000..2546274 --- /dev/null +++ b/typed/search/merge.go @@ -0,0 +1,27 @@ +// Package search provides helpers for assembling fulltext / ranked search +// results across multiple typed query blocks. +package search + +// MergeByID concatenates inputs into a single slice while preserving +// first-seen order and dropping any subsequent occurrence of an ID already +// emitted. The id function extracts a comparable identifier from each row. +// +// MergeByID is intended for use after typed.MultiQuery.Execute, when +// consumers want a single ranked slice from N per-field result sets: +// inputs[0] takes priority, inputs[1] fills in next, etc. A nil result +// indicates no rows survived (the inputs were all empty). +func MergeByID[T any](id func(T) string, inputs ...[]T) []T { + seen := make(map[string]struct{}) + var out []T + for _, in := range inputs { + for _, row := range in { + k := id(row) + if _, dup := seen[k]; dup { + continue + } + seen[k] = struct{}{} + out = append(out, row) + } + } + return out +} diff --git a/typed/search/merge_test.go b/typed/search/merge_test.go new file mode 100644 index 0000000..e4e8583 --- /dev/null +++ b/typed/search/merge_test.go @@ -0,0 +1,86 @@ +package search_test + +import ( + "reflect" + "testing" + + "github.com/matthewmcneely/modusgraph/typed/search" +) + +type rec struct { + ID string + Tag string +} + +func id(r rec) string { return r.ID } + +func TestMergeByID(t *testing.T) { + cases := []struct { + name string + inputs [][]rec + want []rec + }{ + { + name: "empty inputs returns nil", + inputs: nil, + want: nil, + }, + { + name: "single empty slice returns nil", + inputs: [][]rec{{}}, + want: nil, + }, + { + name: "single slice returns it as-is", + inputs: [][]rec{{ + {ID: "a", Tag: "1"}, + {ID: "b", Tag: "1"}, + }}, + want: []rec{ + {ID: "a", Tag: "1"}, + {ID: "b", Tag: "1"}, + }, + }, + { + name: "two slices merge in priority order", + inputs: [][]rec{ + {{ID: "a", Tag: "name"}}, + {{ID: "b", Tag: "desc"}}, + }, + want: []rec{ + {ID: "a", Tag: "name"}, + {ID: "b", Tag: "desc"}, + }, + }, + { + name: "duplicate ID keeps first-seen entry", + inputs: [][]rec{ + {{ID: "a", Tag: "name"}}, + {{ID: "a", Tag: "desc"}, {ID: "b", Tag: "desc"}}, + }, + want: []rec{ + {ID: "a", Tag: "name"}, + {ID: "b", Tag: "desc"}, + }, + }, + { + name: "intra-slice duplicates dedup too", + inputs: [][]rec{ + {{ID: "a", Tag: "1"}, {ID: "a", Tag: "2"}, {ID: "b", Tag: "1"}}, + }, + want: []rec{ + {ID: "a", Tag: "1"}, + {ID: "b", Tag: "1"}, + }, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + got := search.MergeByID(id, c.inputs...) + if !reflect.DeepEqual(got, c.want) { + t.Fatalf("got %v, want %v", got, c.want) + } + }) + } +} diff --git a/typed/tracing.go b/typed/tracing.go new file mode 100644 index 0000000..8a456ed --- /dev/null +++ b/typed/tracing.go @@ -0,0 +1,58 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package typed + +import ( + "context" + "reflect" +) + +// Span is a tracing span for a single database operation. End is called once, +// with the operation's final error (nil on success). +type Span interface { + End(err error) +} + +// Tracer starts a Span around a typed-layer database operation. The typed +// client calls the installed Tracer for every DB call; the default is a no-op, +// so the typed package itself carries no tracing dependency. Install a real +// tracer — for example github.com/mlwelles/modusgraph-telemetry's OpenTelemetry +// tracer — with SetTracer. +type Tracer interface { + // StartSpan begins a span for operation op (for example "get") on the named + // collection, returning a context carrying the span and the Span itself. + StartSpan(ctx context.Context, op, collection string) (context.Context, Span) +} + +type noopSpan struct{} + +func (noopSpan) End(error) {} + +type noopTracer struct{} + +func (noopTracer) StartSpan(ctx context.Context, _, _ string) (context.Context, Span) { + return ctx, noopSpan{} +} + +// tracer is the process-wide tracer the typed package uses. It is a no-op until +// a host installs one via SetTracer. +var tracer Tracer = noopTracer{} + +// SetTracer installs the process-wide tracer for typed-layer DB spans. Passing +// nil restores the no-op tracer. Install once during startup; it is not safe to +// call concurrently with active queries. +func SetTracer(t Tracer) { + if t == nil { + t = noopTracer{} + } + tracer = t +} + +// entityName returns the unqualified Go type name of T (for example "Resource"), +// used as the db.collection.name span attribute. +func entityName[T any]() string { + return reflect.TypeFor[T]().Name() +} diff --git a/typed/tracing_test.go b/typed/tracing_test.go new file mode 100644 index 0000000..d9aab78 --- /dev/null +++ b/typed/tracing_test.go @@ -0,0 +1,47 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package typed + +import ( + "context" + "testing" +) + +func TestSetTracer_InstallsAndResets(t *testing.T) { + t.Cleanup(func() { SetTracer(nil) }) + + rec := &recordingTracer{} + SetTracer(rec) + + _, span := tracer.StartSpan(context.Background(), "get", "Widget") + span.End(nil) + + if rec.op != "get" || rec.collection != "Widget" { + t.Fatalf("installed tracer not invoked: %+v", rec) + } + if !rec.ended { + t.Fatal("span.End was not called") + } + + // nil restores the no-op tracer, which must not panic. + SetTracer(nil) + _, span = tracer.StartSpan(context.Background(), "x", "Y") + span.End(nil) +} + +type recordingTracer struct { + op, collection string + ended bool +} + +func (r *recordingTracer) StartSpan(ctx context.Context, op, collection string) (context.Context, Span) { + r.op, r.collection = op, collection + return ctx, &recordingSpan{r} +} + +type recordingSpan struct{ r *recordingTracer } + +func (s *recordingSpan) End(error) { s.r.ended = true }