From 67c0b3153560caf599976078a756f9bceb445da8 Mon Sep 17 00:00:00 2001 From: Michael Welles Date: Thu, 4 Jun 2026 13:11:17 -0400 Subject: [PATCH 1/2] feat: aborted-transaction retry policy and runner Add RetryPolicy / DefaultRetryPolicy and a runner that re-executes a function on aborted Dgraph transactions with exponential backoff (retry.go). Self-contained; a follow-up wires retry into the client via a WithRetry method. --- retry.go | 96 ++++++++++++++++++++++++++++++++++++++++++ retry_internal_test.go | 68 ++++++++++++++++++++++++++++++ 2 files changed, 164 insertions(+) create mode 100644 retry.go create mode 100644 retry_internal_test.go diff --git a/retry.go b/retry.go new file mode 100644 index 0000000..9b49fda --- /dev/null +++ b/retry.go @@ -0,0 +1,96 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package modusgraph + +import ( + "context" + "errors" + "math/rand/v2" + "time" + + "github.com/dgraph-io/dgo/v250" +) + +// RetryPolicy controls how WithRetry handles aborted transactions. +// Modeled after dgraph4j's RetryPolicy: exponential backoff with jitter. +type RetryPolicy struct { + // MaxRetries is the maximum number of retry attempts after the initial try. + MaxRetries int + + // BaseDelay is the initial delay before the first retry. + // Subsequent delays grow exponentially: BaseDelay * 2^attempt. + BaseDelay time.Duration + + // MaxDelay caps the backoff duration. No single delay exceeds this. + MaxDelay time.Duration + + // Jitter adds randomness to each delay to prevent thundering herd. + // Expressed as a fraction of the computed delay (e.g. 0.1 = 10%). + Jitter float64 +} + +// DefaultRetryPolicy mirrors dgraph4j's defaults: +// 5 retries, 100ms base delay, 5s max delay, 10% jitter. +var DefaultRetryPolicy = RetryPolicy{ + MaxRetries: 10, + BaseDelay: 100 * time.Millisecond, + MaxDelay: 5 * time.Second, + Jitter: 0.1, +} + +// delay computes the backoff duration for a given attempt (0-indexed). +// Formula: min(BaseDelay * 2^attempt, MaxDelay) + random(0, delay * Jitter) +func (p RetryPolicy) delay(attempt int) time.Duration { + d := p.BaseDelay * time.Duration(1< p.MaxDelay { + d = p.MaxDelay + } + if p.Jitter > 0 { + d += time.Duration(float64(d) * p.Jitter * rand.Float64()) + } + return d +} + +// WithRetry executes fn, retrying on aborted transactions according to policy. +// +// This is an opt-in mechanism modeled after dgraph4j's client.withRetry(). +// The caller wraps their mutation logic in fn; WithRetry handles creating +// fresh attempts with exponential backoff when Dgraph returns a transaction +// abort due to concurrent conflicts. +// +// fn is called at least once. On each aborted-transaction error, WithRetry +// waits according to the policy's backoff schedule and calls fn again, up to +// policy.MaxRetries additional times. Non-abort errors are returned immediately. +// +// The context is checked between retries; if cancelled during a backoff sleep, +// the context error is returned. +// +// Usage: +// +// err := client.WithRetry(ctx, modusgraph.DefaultRetryPolicy, func() error { +// return client.Insert(ctx, &entity) +// }) +func (c client) WithRetry(ctx context.Context, policy RetryPolicy, fn func() error) error { + for attempt := range policy.MaxRetries + 1 { + err := fn() + if err == nil { + return nil + } + if !errors.Is(err, dgo.ErrAborted) || attempt >= policy.MaxRetries { + return err + } + d := policy.delay(attempt) + c.logger.V(1).Info("Transaction aborted, retrying", + "attempt", attempt+1, "maxRetries", policy.MaxRetries, "delay", d) + select { + case <-time.After(d): + case <-ctx.Done(): + return ctx.Err() + } + } + // Unreachable: the loop runs MaxRetries+1 times and returns on every path. + panic("unreachable") +} diff --git a/retry_internal_test.go b/retry_internal_test.go new file mode 100644 index 0000000..ce6bd2b --- /dev/null +++ b/retry_internal_test.go @@ -0,0 +1,68 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package modusgraph + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestRetryPolicyDelayExponentialGrowth(t *testing.T) { + p := RetryPolicy{ + BaseDelay: 100 * time.Millisecond, + MaxDelay: 10 * time.Second, + Jitter: 0, + } + + assert.Equal(t, 100*time.Millisecond, p.delay(0)) + assert.Equal(t, 200*time.Millisecond, p.delay(1)) + assert.Equal(t, 400*time.Millisecond, p.delay(2)) + assert.Equal(t, 800*time.Millisecond, p.delay(3)) + assert.Equal(t, 1600*time.Millisecond, p.delay(4)) +} + +func TestRetryPolicyDelayMaxCap(t *testing.T) { + p := RetryPolicy{ + BaseDelay: 1 * time.Second, + MaxDelay: 3 * time.Second, + Jitter: 0, + } + + assert.Equal(t, 1*time.Second, p.delay(0)) + assert.Equal(t, 2*time.Second, p.delay(1)) + assert.Equal(t, 3*time.Second, p.delay(2)) + assert.Equal(t, 3*time.Second, p.delay(3)) + assert.Equal(t, 3*time.Second, p.delay(10)) +} + +func TestRetryPolicyDelayWithJitter(t *testing.T) { + p := RetryPolicy{ + BaseDelay: 100 * time.Millisecond, + MaxDelay: 10 * time.Second, + Jitter: 0.5, + } + + for range 100 { + d := p.delay(0) + assert.GreaterOrEqual(t, d, 100*time.Millisecond, "delay should be at least base") + assert.LessOrEqual(t, d, 150*time.Millisecond, "delay should not exceed base + 50% jitter") + } +} + +func TestRetryPolicyDelayZeroJitter(t *testing.T) { + p := RetryPolicy{ + BaseDelay: 100 * time.Millisecond, + MaxDelay: 10 * time.Second, + Jitter: 0, + } + + for range 10 { + assert.Equal(t, 100*time.Millisecond, p.delay(0)) + assert.Equal(t, 200*time.Millisecond, p.delay(1)) + } +} From 971559cf00711c59b6408c2a8da21d9248c24a06 Mon Sep 17 00:00:00 2001 From: Michael Welles Date: Thu, 4 Jun 2026 13:21:56 -0400 Subject: [PATCH 2/2] feat: configurable bulk data loader Add a bulk loader for RDF/JSON files: - load package: BatchSize, MutationWorkers, Schema, file match/sort options. - loaddata.go (embedded/Namespace path) + loaddata_grpc.go (gRPC path). - exposed as client.LoadData(ctx, dataDir, opts...). Replaces the previous live-loader (live.go). Uses RetryPolicy from the retry layer. --- client.go | 6 + live.go | 253 ----------------------- load/options.go | 200 ++++++++++++++++++ load/options_test.go | 352 ++++++++++++++++++++++++++++++++ loaddata.go | 323 +++++++++++++++++++++++++++++ loaddata_grpc.go | 160 +++++++++++++++ loaddata_test.go | 472 +++++++++++++++++++++++++++++++++++++++++++ 7 files changed, 1513 insertions(+), 253 deletions(-) delete mode 100644 live.go create mode 100644 load/options.go create mode 100644 load/options_test.go create mode 100644 loaddata.go create mode 100644 loaddata_grpc.go create mode 100644 loaddata_test.go diff --git a/client.go b/client.go index be9813b..7609e9c 100644 --- a/client.go +++ b/client.go @@ -23,6 +23,8 @@ import ( "github.com/go-playground/validator/v10" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" + + "github.com/matthewmcneely/modusgraph/load" ) // Client provides an interface for ModusGraph operations @@ -87,6 +89,10 @@ type Client interface { // DgraphClient returns a gRPC Dgraph client from the connection pool and a cleanup function. // The cleanup function must be called when finished with the client to return it to the pool. DgraphClient() (*dgo.Dgraph, func(), error) + + // LoadData bulk-loads RDF or JSON data files from dataDir into the database, + // configured by the given options. It replaces the previous live-loader. + LoadData(ctx context.Context, dataDir string, opts ...load.Option) error } const ( diff --git a/live.go b/live.go deleted file mode 100644 index 0c89711..0000000 --- a/live.go +++ /dev/null @@ -1,253 +0,0 @@ -/* - * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. - * SPDX-License-Identifier: Apache-2.0 - */ - -package modusgraph - -import ( - "context" - "fmt" - "io" - "os" - "sync" - "time" - - "github.com/dgraph-io/dgo/v250/protos/api" - "github.com/dgraph-io/dgraph/v25/chunker" - "github.com/dgraph-io/dgraph/v25/filestore" - "github.com/dgraph-io/dgraph/v25/x" - "github.com/pkg/errors" - "golang.org/x/sync/errgroup" -) - -const ( - maxRoutines = 4 - batchSize = 1000 - numBatchesInBuf = 100 - progressFrequency = 5 * time.Second -) - -type liveLoader struct { - n *Namespace - blankNodes map[string]string - mutex sync.RWMutex -} - -func (n *Namespace) Load(ctx context.Context, schemaPath, dataPath string) error { - schemaData, err := os.ReadFile(schemaPath) - if err != nil { - return fmt.Errorf("error reading schema file [%v]: %w", schemaPath, err) - } - if err := n.AlterSchema(ctx, string(schemaData)); err != nil { - return fmt.Errorf("error altering schema: %w", err) - } - - if err := n.LoadData(ctx, dataPath); err != nil { - return fmt.Errorf("error loading data: %w", err) - } - return nil -} - -// TODO: Add support for CSV file -func (n *Namespace) LoadData(inCtx context.Context, dataDir string) error { - fs := filestore.NewFileStore(dataDir) - files := fs.FindDataFiles(dataDir, []string{".rdf", ".rdf.gz", ".json", ".json.gz"}) - if len(files) == 0 { - return errors.Errorf("no data files found in [%v]", dataDir) - } - n.engine.logger.Info("Found data files to process", "count", len(files)) - - // Here, we build a context tree so that we can wait for the goroutines towards the - // end. This also ensures that we can cancel the context tree if there is an error. - rootG, rootCtx := errgroup.WithContext(inCtx) - procG, procCtx := errgroup.WithContext(rootCtx) - procG.SetLimit(maxRoutines) - - // start a goroutine to do the mutations - start := time.Now() - nqudsProcessed := 0 - nqch := make(chan *api.Mutation, 10000) - rootG.Go(func() error { - ticker := time.NewTicker(progressFrequency) - defer ticker.Stop() - - last := nqudsProcessed - for { - select { - case <-rootCtx.Done(): - return rootCtx.Err() - - case <-ticker.C: - elapsed := time.Since(start).Round(time.Second) - rate := float64(nqudsProcessed-last) / progressFrequency.Seconds() - n.engine.logger.Info("Data loading progress", "elapsed", x.FixedDuration(elapsed), - "nquadsProcessed", nqudsProcessed, - "writesPerSecond", fmt.Sprintf("%5.0f", rate)) - last = nqudsProcessed - - case nqs, ok := <-nqch: - if !ok { - return nil - } - uids, err := n.Mutate(rootCtx, []*api.Mutation{nqs}) - if err != nil { - return fmt.Errorf("error applying mutations: %w", err) - } - x.AssertTruef(len(uids) == 0, "no UIDs should be returned for live loader") - nqudsProcessed += len(nqs.Set) - } - } - }) - - ll := &liveLoader{n: n, blankNodes: make(map[string]string)} - for _, datafile := range files { - procG.Go(func() error { - return ll.processFile(procCtx, fs, datafile, nqch) - }) - } - - // Wait until all the files are processed - if errProcG := procG.Wait(); errProcG != nil { - rootG.Go(func() error { - return errProcG - }) - } - - // close the channel and wait for the mutations to finish - close(nqch) - return rootG.Wait() -} - -func (l *liveLoader) processFile(inCtx context.Context, fs filestore.FileStore, - filename string, nqch chan *api.Mutation) error { - - l.n.engine.logger.Info("Processing data file", "filename", filename) - - rd, cleanup := fs.ChunkReader(filename, nil) - defer cleanup() - - loadType := chunker.DataFormat(filename, "") - if loadType == chunker.UnknownFormat { - if isJson, err := chunker.IsJSONData(rd); err == nil { - if isJson { - loadType = chunker.JsonFormat - } else { - return errors.Errorf("unable to figure out data format for [%v]", filename) - } - } - } - - g, ctx := errgroup.WithContext(inCtx) - ck := chunker.NewChunker(loadType, batchSize) - nqbuf := ck.NQuads() - - g.Go(func() error { - buffer := make([]*api.NQuad, 0, numBatchesInBuf*batchSize) - - drain := func() { - for len(buffer) > 0 { - sz := batchSize - if len(buffer) < batchSize { - sz = len(buffer) - } - nqch <- &api.Mutation{Set: buffer[:sz]} - buffer = buffer[sz:] - } - } - - loop := true - for loop { - select { - case <-ctx.Done(): - return ctx.Err() - - case nqs, ok := <-nqbuf.Ch(): - if !ok { - loop = false - break - } - if len(nqs) == 0 { - continue - } - - var err error - for _, nq := range nqs { - nq.Subject, err = l.uid(nq.Namespace, nq.Subject) - if err != nil { - return fmt.Errorf("error getting UID for subject: %w", err) - } - if len(nq.ObjectId) > 0 { - nq.ObjectId, err = l.uid(nq.Namespace, nq.ObjectId) - if err != nil { - return fmt.Errorf("error getting UID for object: %w", err) - } - } - } - - buffer = append(buffer, nqs...) - if len(buffer) < numBatchesInBuf*batchSize { - continue - } - drain() - } - } - drain() - return nil - }) - - g.Go(func() error { - for { - select { - case <-ctx.Done(): - return ctx.Err() - default: - } - - chunkBuf, errChunk := ck.Chunk(rd) - if errChunk != nil && errChunk != io.EOF { - return fmt.Errorf("error chunking data: %w", errChunk) - } - if err := ck.Parse(chunkBuf); err != nil { - return fmt.Errorf("error parsing chunk: %w", err) - } - // We do this here in case of io.EOF, so that we can flush the last batch. - if errChunk != nil { - break - } - } - - nqbuf.Flush() - return nil - }) - - return g.Wait() -} - -func (l *liveLoader) uid(ns uint64, val string) (string, error) { - key := x.NamespaceAttr(ns, val) - - l.mutex.RLock() - uid, ok := l.blankNodes[key] - l.mutex.RUnlock() - if ok { - return uid, nil - } - - l.mutex.Lock() - defer l.mutex.Unlock() - - uid, ok = l.blankNodes[key] - if ok { - return uid, nil - } - - asUID, err := l.n.engine.LeaseUIDs(1) - if err != nil { - return "", fmt.Errorf("error allocating UID: %w", err) - } - - uid = fmt.Sprintf("%#x", asUID.StartId) - l.blankNodes[key] = uid - return uid, nil -} diff --git a/load/options.go b/load/options.go new file mode 100644 index 0000000..4a8a090 --- /dev/null +++ b/load/options.go @@ -0,0 +1,200 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +// Package load provides the Loader and option types for configuring +// modusgraph.Client.LoadData calls. +// +// Usage: +// +// client.LoadData(ctx, dataDir, +// load.WithBatchSize(10000), +// load.WithMutationWorkers(8), +// load.WithSchema("schema.dgraph"), +// ) +package load + +import "strings" + +// Option configures a LoadData call. +type Option func(*Options) + +// FileMatch selects which files in the data directory to load. +type FileMatch interface { + Match(path string) bool +} + +// FileMatchFunc adapts a plain function to the FileMatch interface. +type FileMatchFunc func(path string) bool + +// Match implements FileMatch. +func (f FileMatchFunc) Match(path string) bool { return f(path) } + +// FileSort reorders the list of data files before processing. +type FileSort func([]string) []string + +// defaultExtensions is the set of file extensions loaded by LoadData when +// no FileMatch is configured. +var defaultExtensions = []string{".rdf", ".rdf.gz", ".json", ".json.gz"} + +// DefaultExtensions returns the default file extensions loaded by LoadData. +func DefaultExtensions() []string { + out := make([]string, len(defaultExtensions)) + copy(out, defaultExtensions) + return out +} + +// extensionMatch matches files by suffix. +type extensionMatch struct { + exts []string +} + +func (m extensionMatch) Match(path string) bool { + for _, ext := range m.exts { + if strings.HasSuffix(path, ext) { + return true + } + } + return false +} + +// NewExtensionMatch returns a FileMatch that accepts files whose path ends +// with any of the given suffixes. +func NewExtensionMatch(exts ...string) FileMatch { + return extensionMatch{exts: exts} +} + +// Options control the behavior of LoadData. +// Zero values use defaults (BatchSize=1000, MutationWorkers=1). +// +// File processing pipeline: walk directory → FilterFiles (applies FileMatch) → SortFiles → process. +type Options struct { + // SchemaPath is the path to a Dgraph schema file applied before loading. + // Empty means the schema must already exist. + SchemaPath string + + // BatchSize is the number of NQuads per mutation batch. + // Larger batches reduce gRPC round-trips but increase per-transaction + // memory on the server. Default is 1000. + BatchSize int + + // MutationWorkers is the number of concurrent goroutines submitting + // mutations. Higher values increase throughput but put more load on the + // Dgraph cluster. Default is 1 (sequential). + MutationWorkers int + + // FileMatch, if set, selects which individual files to include. + // A nil FileMatch matches all files. + FileMatch FileMatch + + // SortFiles, if set, reorders data files before processing. + // Applied after FilterFiles. By default files are in the lexicographic + // order returned by filepath.Walk. + SortFiles FileSort +} + +// DefaultBatchSize is the default number of NQuads per mutation batch. +const DefaultBatchSize = 1000 + +// DefaultMutationWorkers is the default number of concurrent mutation goroutines. +const DefaultMutationWorkers = 1 + +// GetBatchSize returns the effective batch size, defaulting to DefaultBatchSize. +func (o Options) GetBatchSize() int { + if o.BatchSize <= 0 { + return DefaultBatchSize + } + return o.BatchSize +} + +// GetMutationWorkers returns the effective worker count, defaulting to DefaultMutationWorkers. +func (o Options) GetMutationWorkers() int { + if o.MutationWorkers <= 0 { + return DefaultMutationWorkers + } + return o.MutationWorkers +} + +// MatchFile reports whether an individual path should be included. +// If FileMatch is nil, all files match. +func (o Options) MatchFile(path string) bool { + if o.FileMatch == nil { + return true + } + return o.FileMatch.Match(path) +} + +// FilterFiles returns only the files from the input that pass MatchFile. +func (o Options) FilterFiles(files []string) []string { + if o.FileMatch == nil { + return files + } + var out []string + for _, f := range files { + if o.FileMatch.Match(f) { + out = append(out, f) + } + } + return out +} + +// WithSchema applies the given Dgraph schema file before loading data. +func WithSchema(path string) Option { + return func(o *Options) { + o.SchemaPath = path + } +} + +// WithBatchSize sets the number of NQuads per mutation batch. +// Default is 1000. +func WithBatchSize(n int) Option { + return func(o *Options) { + o.BatchSize = n + } +} + +// WithMutationWorkers sets the number of concurrent mutation goroutines. +// Default is 1 (sequential). +func WithMutationWorkers(n int) Option { + return func(o *Options) { + o.MutationWorkers = n + } +} + +// WithFileMatch sets a FileMatch for per-file matching during directory walking. +// A nil filter matches all files. +func WithFileMatch(f FileMatch) Option { + return func(o *Options) { + o.FileMatch = f + } +} + +// WithFileSort sets a function that reorders data files before processing. +// Called after FilterFiles. +func WithFileSort(fn FileSort) Option { + return func(o *Options) { + o.SortFiles = fn + } +} + +// WithOptions applies all non-zero fields from the given Options struct. +func WithOptions(lo Options) Option { + return func(o *Options) { + if lo.SchemaPath != "" { + o.SchemaPath = lo.SchemaPath + } + if lo.BatchSize > 0 { + o.BatchSize = lo.BatchSize + } + if lo.MutationWorkers > 0 { + o.MutationWorkers = lo.MutationWorkers + } + if lo.FileMatch != nil { + o.FileMatch = lo.FileMatch + } + if lo.SortFiles != nil { + o.SortFiles = lo.SortFiles + } + } +} diff --git a/load/options_test.go b/load/options_test.go new file mode 100644 index 0000000..7e616be --- /dev/null +++ b/load/options_test.go @@ -0,0 +1,352 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package load + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestOptionsDefaults(t *testing.T) { + var opts Options + + assert.Equal(t, DefaultBatchSize, opts.GetBatchSize(), "default batch size") + assert.Equal(t, DefaultMutationWorkers, opts.GetMutationWorkers(), "default mutation workers") +} + +func TestOptionsZeroValues(t *testing.T) { + opts := Options{BatchSize: 0, MutationWorkers: 0} + + assert.Equal(t, DefaultBatchSize, opts.GetBatchSize(), "zero batch size should use default") + assert.Equal(t, DefaultMutationWorkers, opts.GetMutationWorkers(), "zero workers should use default") +} + +func TestOptionsNegativeValues(t *testing.T) { + opts := Options{BatchSize: -1, MutationWorkers: -5} + + assert.Equal(t, DefaultBatchSize, opts.GetBatchSize(), "negative batch size should use default") + assert.Equal(t, DefaultMutationWorkers, opts.GetMutationWorkers(), "negative workers should use default") +} + +func TestOptionsExplicitValues(t *testing.T) { + opts := Options{BatchSize: 5000, MutationWorkers: 8} + + assert.Equal(t, 5000, opts.GetBatchSize()) + assert.Equal(t, 8, opts.GetMutationWorkers()) +} + +func TestWithBatchSizeOption(t *testing.T) { + var opts Options + WithBatchSize(10000)(&opts) + assert.Equal(t, 10000, opts.BatchSize) +} + +func TestWithMutationWorkersOption(t *testing.T) { + var opts Options + WithMutationWorkers(16)(&opts) + assert.Equal(t, 16, opts.MutationWorkers) +} + +func TestWithSchemaOption(t *testing.T) { + var opts Options + WithSchema("/path/to/schema.dgraph")(&opts) + assert.Equal(t, "/path/to/schema.dgraph", opts.SchemaPath) +} + +func TestWithOptionsFullOverride(t *testing.T) { + var opts Options + + WithOptions(Options{ + SchemaPath: "/schema.dgraph", + BatchSize: 5000, + MutationWorkers: 4, + })(&opts) + + assert.Equal(t, "/schema.dgraph", opts.SchemaPath) + assert.Equal(t, 5000, opts.BatchSize) + assert.Equal(t, 4, opts.MutationWorkers) +} + +func TestWithOptionsZeroFieldsIgnored(t *testing.T) { + opts := Options{ + SchemaPath: "/existing.dgraph", + BatchSize: 2000, + MutationWorkers: 8, + } + + WithOptions(Options{})(&opts) + + assert.Equal(t, "/existing.dgraph", opts.SchemaPath) + assert.Equal(t, 2000, opts.BatchSize) + assert.Equal(t, 8, opts.MutationWorkers) +} + +func TestWithOptionsPartialOverride(t *testing.T) { + opts := Options{ + SchemaPath: "/old.dgraph", + BatchSize: 2000, + MutationWorkers: 8, + } + + WithOptions(Options{BatchSize: 10000})(&opts) + + assert.Equal(t, "/old.dgraph", opts.SchemaPath, "SchemaPath should be preserved") + assert.Equal(t, 10000, opts.BatchSize, "BatchSize should be overridden") + assert.Equal(t, 8, opts.MutationWorkers, "MutationWorkers should be preserved") +} + +func TestOptionFuncsCompose(t *testing.T) { + var opts Options + + fns := []Option{ + WithBatchSize(1000), + WithMutationWorkers(4), + WithSchema("/a.dgraph"), + WithBatchSize(5000), + } + for _, fn := range fns { + fn(&opts) + } + + assert.Equal(t, "/a.dgraph", opts.SchemaPath) + assert.Equal(t, 5000, opts.BatchSize) + assert.Equal(t, 4, opts.MutationWorkers) +} + +// FileMatch tests + +func TestMatchFileNilMatchesAll(t *testing.T) { + var opts Options + assert.True(t, opts.MatchFile("anything.txt")) + assert.True(t, opts.MatchFile("data.rdf")) + assert.True(t, opts.MatchFile("")) +} + +func TestNewExtensionMatch(t *testing.T) { + m := NewExtensionMatch(".csv", ".tsv") + + assert.True(t, m.Match("data.csv")) + assert.True(t, m.Match("/dir/data.tsv")) + assert.False(t, m.Match("data.rdf")) +} + +func TestNewExtensionMatchNoExtensions(t *testing.T) { + m := NewExtensionMatch() + assert.False(t, m.Match("data.rdf")) + assert.False(t, m.Match("")) +} + +func TestNewExtensionMatchOverlappingSuffixes(t *testing.T) { + m := NewExtensionMatch(".gz", ".rdf.gz") + + assert.True(t, m.Match("data.rdf.gz"), ".rdf.gz matches .gz") + assert.True(t, m.Match("data.tar.gz"), ".tar.gz matches .gz") + assert.False(t, m.Match("data.rdf"), "plain .rdf should not match") +} + +func TestNewExtensionMatchEmptyPath(t *testing.T) { + m := NewExtensionMatch(".rdf") + assert.False(t, m.Match("")) +} + +func TestNewExtensionMatchFullPathMatching(t *testing.T) { + m := NewExtensionMatch(".rdf", ".rdf.gz") + + assert.True(t, m.Match("/var/data/import/users.rdf")) + assert.True(t, m.Match("/var/data/import/users.rdf.gz")) + assert.False(t, m.Match("/var/data/import/users.csv")) + assert.False(t, m.Match("/var/data/import/schema.dgraph")) +} + +func TestFileMatchFunc(t *testing.T) { + var f FileMatch = FileMatchFunc(func(path string) bool { + return path == "special.rdf" + }) + + assert.True(t, f.Match("special.rdf")) + assert.False(t, f.Match("other.rdf")) +} + +func TestWithFileMatch(t *testing.T) { + var opts Options + custom := NewExtensionMatch(".nq", ".nq.gz") + WithFileMatch(custom)(&opts) + + assert.NotNil(t, opts.FileMatch) + assert.True(t, opts.MatchFile("data.nq")) + assert.False(t, opts.MatchFile("data.csv")) +} + +func TestWithOptionsIncludesFileMatch(t *testing.T) { + custom := NewExtensionMatch(".nq") + var opts Options + + WithOptions(Options{FileMatch: custom})(&opts) + assert.NotNil(t, opts.FileMatch) + assert.True(t, opts.MatchFile("x.nq")) +} + +func TestWithOptionsNilFileMatchPreservesExisting(t *testing.T) { + custom := NewExtensionMatch(".nq") + opts := Options{FileMatch: custom} + + WithOptions(Options{})(&opts) + assert.NotNil(t, opts.FileMatch) + assert.True(t, opts.MatchFile("x.nq")) +} + +// FilterFiles method tests + +func TestFilterFilesNilMatchReturnsAll(t *testing.T) { + var opts Options + input := []string{"a.rdf", "b.json", "c.txt", "d.csv"} + assert.Equal(t, input, opts.FilterFiles(input)) +} + +func TestFilterFilesWithMatch(t *testing.T) { + opts := Options{ + FileMatch: NewExtensionMatch(".rdf", ".json"), + } + input := []string{"a.rdf", "b.json", "c.txt", "d.csv", "e.rdf.gz"} + assert.Equal(t, []string{"a.rdf", "b.json"}, opts.FilterFiles(input)) +} + +func TestFilterFilesEmptyInput(t *testing.T) { + opts := Options{ + FileMatch: NewExtensionMatch(".rdf"), + } + assert.Nil(t, opts.FilterFiles(nil)) + assert.Nil(t, opts.FilterFiles([]string{})) +} + +func TestFilterFilesNoMatches(t *testing.T) { + opts := Options{ + FileMatch: NewExtensionMatch(".nq"), + } + input := []string{"a.rdf", "b.json"} + assert.Nil(t, opts.FilterFiles(input)) +} + +func TestFilterFilesPreservesOrder(t *testing.T) { + opts := Options{ + FileMatch: NewExtensionMatch(".rdf"), + } + input := []string{"c.rdf", "a.txt", "b.rdf", "d.csv", "a.rdf"} + assert.Equal(t, []string{"c.rdf", "b.rdf", "a.rdf"}, opts.FilterFiles(input), + "filtered files should preserve original order") +} + +func TestFilterFilesDoesNotMutateInput(t *testing.T) { + opts := Options{ + FileMatch: NewExtensionMatch(".rdf"), + } + input := []string{"a.rdf", "b.txt", "c.rdf"} + inputCopy := make([]string, len(input)) + copy(inputCopy, input) + + opts.FilterFiles(input) + assert.Equal(t, inputCopy, input, "FilterFiles should not mutate the input slice") +} + +// Pipeline tests — FilterFiles then SortFiles + +func TestFilterThenSort(t *testing.T) { + opts := Options{ + FileMatch: NewExtensionMatch(".rdf", ".rdf.gz"), + SortFiles: FileSort(func(files []string) []string { + // Reverse sort + out := make([]string, len(files)) + for i, f := range files { + out[len(files)-1-i] = f + } + return out + }), + } + + input := []string{"z.rdf", "a.csv", "m.rdf.gz", "b.rdf", "x.json"} + + // Step 1: filter + filtered := opts.FilterFiles(input) + assert.Equal(t, []string{"z.rdf", "m.rdf.gz", "b.rdf"}, filtered) + + // Step 2: sort + sorted := opts.SortFiles(filtered) + assert.Equal(t, []string{"b.rdf", "m.rdf.gz", "z.rdf"}, sorted) +} + +func TestFilterWithoutSortLeavesOrder(t *testing.T) { + opts := Options{ + FileMatch: NewExtensionMatch(".rdf"), + // SortFiles intentionally nil + } + + input := []string{"c.rdf", "a.rdf", "b.rdf"} + filtered := opts.FilterFiles(input) + assert.Equal(t, []string{"c.rdf", "a.rdf", "b.rdf"}, filtered, + "without SortFiles, original order is preserved") +} + +func TestSortWithoutFilterUsesAllFiles(t *testing.T) { + opts := Options{ + // FileMatch intentionally nil — all files match + SortFiles: FileSort(func(files []string) []string { + // Reverse + out := make([]string, len(files)) + for i, f := range files { + out[len(files)-1-i] = f + } + return out + }), + } + + input := []string{"a.rdf", "b.json", "c.txt"} + filtered := opts.FilterFiles(input) + assert.Equal(t, input, filtered, "nil FileMatch returns all files") + + sorted := opts.SortFiles(filtered) + assert.Equal(t, []string{"c.txt", "b.json", "a.rdf"}, sorted) +} + +// SortFiles tests + +func TestWithFileSort(t *testing.T) { + var opts Options + + reverse := FileSort(func(files []string) []string { + out := make([]string, len(files)) + for i, f := range files { + out[len(files)-1-i] = f + } + return out + }) + WithFileSort(reverse)(&opts) + + assert.NotNil(t, opts.SortFiles) + result := opts.SortFiles([]string{"a.rdf", "b.rdf", "c.rdf"}) + assert.Equal(t, []string{"c.rdf", "b.rdf", "a.rdf"}, result) +} + +func TestSortFilesNilByDefault(t *testing.T) { + var opts Options + assert.Nil(t, opts.SortFiles) +} + +func TestWithOptionsIncludesSortFiles(t *testing.T) { + identity := FileSort(func(files []string) []string { return files }) + var opts Options + + WithOptions(Options{SortFiles: identity})(&opts) + assert.NotNil(t, opts.SortFiles) +} + +func TestWithOptionsNilSortFilesPreservesExisting(t *testing.T) { + identity := FileSort(func(files []string) []string { return files }) + opts := Options{SortFiles: identity} + + WithOptions(Options{})(&opts) + assert.NotNil(t, opts.SortFiles) +} diff --git a/loaddata.go b/loaddata.go new file mode 100644 index 0000000..fb32455 --- /dev/null +++ b/loaddata.go @@ -0,0 +1,323 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package modusgraph + +import ( + "context" + "fmt" + "io" + "os" + "path/filepath" + "sync" + "sync/atomic" + "time" + + "github.com/dgraph-io/dgo/v250/protos/api" + "github.com/dgraph-io/dgraph/v25/chunker" + "github.com/dgraph-io/dgraph/v25/filestore" + "github.com/dgraph-io/dgraph/v25/protos/pb" + "github.com/dgraph-io/dgraph/v25/x" + "github.com/go-logr/logr" + "github.com/matthewmcneely/modusgraph/load" + "golang.org/x/sync/errgroup" +) + +const ( + defaultMaxRoutines = 4 + defaultNumBatchesInBuf = 100 + defaultNqchBufferSize = 10000 + defaultProgressFrequency = 5 * time.Second +) + +// Mutator abstracts mutation submission for the LiveLoader. +// Implementations exist for both the embedded engine and gRPC clients. +type Mutator interface { + // Mutate submits a batch of NQuads. The returned map contains blank node + // to UID mappings from the server (gRPC path); for the embedded engine + // the map is nil because blank nodes are pre-resolved locally. + Mutate(ctx context.Context, mu *api.Mutation) (map[string]string, error) +} + +// namespaceMutator implements Mutator for the embedded engine path. +type namespaceMutator struct { + ns *Namespace +} + +func (m *namespaceMutator) Mutate(ctx context.Context, mu *api.Mutation) (map[string]string, error) { + _, err := m.ns.Mutate(ctx, []*api.Mutation{mu}) + return nil, err +} + +// UIDAllocator allocates UIDs for blank node resolution. +// For the embedded engine, the Engine type satisfies this directly. +// For gRPC, a bulk-allocating implementation leases UIDs from the Zero leader. +// A nil UIDAllocator means blank nodes are sent to the server as-is. +type UIDAllocator interface { + LeaseUIDs(n uint64) (*pb.AssignedIds, error) +} + +type liveLoader struct { + mut Mutator + uidAlloc UIDAllocator // nil when server allocates UIDs + blankNodes map[string]string + mutex sync.RWMutex + logger logr.Logger + batchSize int +} + +// Load reads a schema file and data directory, applying both to this namespace. +func (n *Namespace) Load(ctx context.Context, schemaPath, dataPath string) error { + schemaData, err := os.ReadFile(schemaPath) + if err != nil { + return fmt.Errorf("read schema file [%v]: %w", schemaPath, err) + } + if err := n.AlterSchema(ctx, string(schemaData)); err != nil { + return fmt.Errorf("alter schema: %w", err) + } + if err := n.LoadData(ctx, dataPath); err != nil { + return fmt.Errorf("load data: %w", err) + } + return nil +} + +// LoadData loads RDF or JSON data files from dataDir into this namespace. +func (n *Namespace) LoadData(inCtx context.Context, dataDir string) error { + ll := &liveLoader{ + mut: &namespaceMutator{ns: n}, + uidAlloc: n.engine, + blankNodes: make(map[string]string), + logger: n.engine.logger, + } + return loadData(inCtx, ll, dataDir, load.Options{}) +} + +// loadData runs the core data-loading pipeline: find files, spawn file +// processors, and feed mutations to concurrent workers. +func loadData(inCtx context.Context, ll *liveLoader, dataDir string, opts load.Options) error { + fs := filestore.NewFileStore(dataDir) + + var allFiles []string + if err := filepath.Walk(dataDir, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if !info.IsDir() { + allFiles = append(allFiles, path) + } + return nil + }); err != nil { + return fmt.Errorf("walk data dir [%s]: %w", dataDir, err) + } + + files := opts.FilterFiles(allFiles) + if opts.SortFiles != nil { + files = opts.SortFiles(files) + } + if len(files) == 0 { + return fmt.Errorf("no data files found in [%v]", dataDir) + } + + batchSize := opts.GetBatchSize() + numWorkers := opts.GetMutationWorkers() + ll.batchSize = batchSize + ll.logger.Info("Found data files to process", "count", len(files)) + + rootG, rootCtx := errgroup.WithContext(inCtx) + procG, procCtx := errgroup.WithContext(rootCtx) + procG.SetLimit(defaultMaxRoutines) + + start := time.Now() + var nquadsProcessed atomic.Int64 + nqch := make(chan *api.Mutation, defaultNqchBufferSize) + + // Progress reporter — runs outside the errgroup so it doesn't block + // completion. Stopped via context cancellation when loadData returns. + tickCtx, tickCancel := context.WithCancel(rootCtx) + defer tickCancel() + go func() { + ticker := time.NewTicker(defaultProgressFrequency) + defer ticker.Stop() + + var last int64 + for { + select { + case <-tickCtx.Done(): + return + case <-ticker.C: + cur := nquadsProcessed.Load() + elapsed := time.Since(start).Round(time.Second) + rate := float64(cur-last) / defaultProgressFrequency.Seconds() + ll.logger.Info("Data loading progress", "elapsed", x.FixedDuration(elapsed), + "nquadsProcessed", cur, + "writesPerSecond", fmt.Sprintf("%5.0f", rate)) + last = cur + } + } + }() + + // Mutation workers — with pre-allocated UIDs, mutations are independent + // and can execute concurrently. + for range numWorkers { + rootG.Go(func() error { + for nqs := range nqch { + if _, err := ll.mut.Mutate(rootCtx, nqs); err != nil { + return fmt.Errorf("apply mutations: %w", err) + } + nquadsProcessed.Add(int64(len(nqs.Set))) + } + return nil + }) + } + + for _, datafile := range files { + procG.Go(func() error { + return ll.processFile(procCtx, fs, datafile, nqch) + }) + } + + if err := procG.Wait(); err != nil { + rootG.Go(func() error { return err }) + } + + close(nqch) + return rootG.Wait() +} + +func (l *liveLoader) processFile(inCtx context.Context, fs filestore.FileStore, + filename string, nqch chan *api.Mutation) error { + + l.logger.Info("Processing data file", "filename", filename) + + rd, cleanup := fs.ChunkReader(filename, nil) + defer cleanup() + + loadType := chunker.DataFormat(filename, "") + if loadType == chunker.UnknownFormat { + isJSON, err := chunker.IsJSONData(rd) + if err == nil && isJSON { + loadType = chunker.JsonFormat + } else { + return fmt.Errorf("unable to determine data format for [%v]", filename) + } + } + + bs := l.batchSize + g, ctx := errgroup.WithContext(inCtx) + ck := chunker.NewChunker(loadType, bs) + nqbuf := ck.NQuads() + + g.Go(func() error { + buffer := make([]*api.NQuad, 0, defaultNumBatchesInBuf*bs) + + drain := func() { + for len(buffer) > 0 { + sz := bs + if len(buffer) < bs { + sz = len(buffer) + } + nqch <- &api.Mutation{Set: buffer[:sz]} + buffer = buffer[sz:] + } + } + + loop := true + for loop { + select { + case <-ctx.Done(): + return ctx.Err() + + case nqs, ok := <-nqbuf.Ch(): + if !ok { + loop = false + break + } + if len(nqs) == 0 { + continue + } + + var err error + for _, nq := range nqs { + nq.Subject, err = l.uid(nq.Namespace, nq.Subject) + if err != nil { + return fmt.Errorf("get UID for subject: %w", err) + } + if len(nq.ObjectId) > 0 { + nq.ObjectId, err = l.uid(nq.Namespace, nq.ObjectId) + if err != nil { + return fmt.Errorf("get UID for object: %w", err) + } + } + } + + buffer = append(buffer, nqs...) + if len(buffer) < defaultNumBatchesInBuf*bs { + continue + } + drain() + } + } + drain() + return nil + }) + + g.Go(func() error { + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + chunkBuf, errChunk := ck.Chunk(rd) + if errChunk != nil && errChunk != io.EOF { + return fmt.Errorf("chunk data: %w", errChunk) + } + if err := ck.Parse(chunkBuf); err != nil { + return fmt.Errorf("parse chunk: %w", err) + } + if errChunk != nil { + break + } + } + + nqbuf.Flush() + return nil + }) + + return g.Wait() +} + +func (l *liveLoader) uid(ns uint64, val string) (string, error) { + key := x.NamespaceAttr(ns, val) + + l.mutex.RLock() + uid, ok := l.blankNodes[key] + l.mutex.RUnlock() + if ok { + return uid, nil + } + + if l.uidAlloc == nil { + return val, nil + } + + l.mutex.Lock() + defer l.mutex.Unlock() + + uid, ok = l.blankNodes[key] + if ok { + return uid, nil + } + + asUID, err := l.uidAlloc.LeaseUIDs(1) + if err != nil { + return "", fmt.Errorf("allocate UID: %w", err) + } + + uid = fmt.Sprintf("%#x", asUID.StartId) + l.blankNodes[key] = uid + return uid, nil +} diff --git a/loaddata_grpc.go b/loaddata_grpc.go new file mode 100644 index 0000000..1988a25 --- /dev/null +++ b/loaddata_grpc.go @@ -0,0 +1,160 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package modusgraph + +import ( + "context" + "errors" + "fmt" + "os" + "sync" + "time" + + "github.com/dgraph-io/dgo/v250" + "github.com/dgraph-io/dgo/v250/protos/api" + "github.com/dgraph-io/dgraph/v25/protos/pb" + "github.com/matthewmcneely/modusgraph/load" +) + +// grpcUIDAllocator implements UIDAllocator for the gRPC path by calling +// dgo.Dgraph.AllocateUIDs in bulk. UIDs are leased in batches to minimise +// round-trips to the Zero leader. +type grpcUIDAllocator struct { + pool *clientPool + mu sync.Mutex + nextUID uint64 + maxUID uint64 // exclusive upper bound of the current lease +} + +const uidLeaseBatch uint64 = 10000 + +func (a *grpcUIDAllocator) LeaseUIDs(n uint64) (*pb.AssignedIds, error) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.nextUID+n > a.maxUID { + alloc := uidLeaseBatch + if n > uidLeaseBatch { + alloc = n + } + dc, err := a.pool.get() + if err != nil { + return nil, fmt.Errorf("get client from pool for UID allocation: %w", err) + } + start, end, err := dc.AllocateUIDs(context.TODO(), alloc) + a.pool.put(dc) + if err != nil { + return nil, fmt.Errorf("allocate UIDs: %w", err) + } + a.nextUID = start + a.maxUID = end + } + + uid := a.nextUID + a.nextUID += n + return &pb.AssignedIds{StartId: uid}, nil +} + +// grpcMutator implements Mutator for the gRPC (remote Dgraph) path. +// It retries aborted transactions using the same backoff policy as +// RetryPolicy (exponential backoff with jitter). +type grpcMutator struct { + pool *clientPool + policy RetryPolicy +} + +func (m *grpcMutator) Mutate(ctx context.Context, mu *api.Mutation) (map[string]string, error) { + policy := m.policy + if policy.MaxRetries <= 0 { + policy = DefaultRetryPolicy + } + mu.CommitNow = true + + for attempt := 0; ; attempt++ { + dc, err := m.pool.get() + if err != nil { + return nil, fmt.Errorf("get client from pool: %w", err) + } + + txn := dc.NewTxn() + resp, err := txn.Mutate(ctx, mu) + txn.Discard(ctx) + m.pool.put(dc) + + if err == nil { + return resp.GetUids(), nil + } + if !errors.Is(err, dgo.ErrAborted) || attempt >= policy.MaxRetries { + return nil, err + } + + d := policy.delay(attempt) + select { + case <-sleepTimer(d): + case <-ctx.Done(): + return nil, ctx.Err() + } + } +} + +// sleepTimer returns a channel that fires after d. Extracted for consistency +// with retry.go (both use the same mechanism). +func sleepTimer(d time.Duration) <-chan time.Time { + return time.After(d) +} + +// LoadData loads RDF or JSON data files from dataDir into the database. +func (c client) LoadData(ctx context.Context, dataDir string, opts ...load.Option) error { + options := load.Options{} + for _, opt := range opts { + opt(&options) + } + + // Apply schema if requested. + if options.SchemaPath != "" { + schemaData, err := os.ReadFile(options.SchemaPath) + if err != nil { + return fmt.Errorf("read schema file [%s]: %w", options.SchemaPath, err) + } + if err := c.alterSchema(ctx, string(schemaData)); err != nil { + return fmt.Errorf("alter schema: %w", err) + } + } + + // Both paths go through loadData() so that caller options (FileMatch, + // SortFiles, BatchSize, MutationWorkers) are always respected. + var ll *liveLoader + if c.engine != nil { + ll = &liveLoader{ + mut: &namespaceMutator{ns: c.engine.db0}, + uidAlloc: c.engine, + blankNodes: make(map[string]string), + logger: c.engine.logger, + } + } else { + ll = &liveLoader{ + mut: &grpcMutator{pool: c.pool}, + uidAlloc: &grpcUIDAllocator{pool: c.pool}, + blankNodes: make(map[string]string), + logger: c.logger, + } + } + return loadData(ctx, ll, dataDir, options) +} + +func (c client) alterSchema(ctx context.Context, schema string) error { + if c.engine != nil { + return c.engine.db0.AlterSchema(ctx, schema) + } + + dc, err := c.pool.get() + if err != nil { + return err + } + defer c.pool.put(dc) + + return dc.Alter(ctx, &api.Operation{Schema: schema}) +} diff --git a/loaddata_test.go b/loaddata_test.go new file mode 100644 index 0000000..92ae329 --- /dev/null +++ b/loaddata_test.go @@ -0,0 +1,472 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package modusgraph_test + +import ( + "bytes" + "compress/gzip" + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "testing" + + mg "github.com/matthewmcneely/modusgraph" + "github.com/matthewmcneely/modusgraph/load" + "github.com/stretchr/testify/require" +) + +// TestClientLoadDataFile tests LoadData via the file:// URI (embedded engine) path. +func TestClientLoadDataFile(t *testing.T) { + tmpDir := GetTempDir(t) + client, cleanup := CreateTestClient(t, "file://"+tmpDir) + defer cleanup() + + // Create RDF data directory and file. + rdfDir := filepath.Join(tmpDir, "rdf_data") + require.NoError(t, os.MkdirAll(rdfDir, 0755)) + + rdfData := `_:alice "Person" . +_:alice "Alice" . +_:alice "30"^^ . +_:bob "Person" . +_:bob "Bob" . +_:bob "25"^^ . +_:alice _:bob . +` + require.NoError(t, os.WriteFile(filepath.Join(rdfDir, "data.rdf"), []byte(rdfData), 0600)) + + // Create schema file (outside the data dir). + schemaFile := filepath.Join(tmpDir, "schema.dgraph") + schemaData := `name: string @index(exact, term) . +age: int . +friend: [uid] @reverse . +type Person { + name + age + friend +} +` + require.NoError(t, os.WriteFile(schemaFile, []byte(schemaData), 0600)) + + // Load data with schema. + ctx := context.Background() + err := client.LoadData(ctx, rdfDir, load.WithSchema(schemaFile)) + require.NoError(t, err) + + // Query for Alice and verify friend edge to Bob. + const query = `{ + q(func: eq(name, "Alice")) { + name + age + friend { + name + age + } + } + }` + + resp, err := client.QueryRaw(ctx, query, nil) + require.NoError(t, err) + + var result struct { + Q []struct { + Name string `json:"name"` + Age int `json:"age"` + Friend []struct { + Name string `json:"name"` + Age int `json:"age"` + } `json:"friend"` + } `json:"q"` + } + require.NoError(t, json.Unmarshal(resp, &result)) + + require.Len(t, result.Q, 1, "expected exactly one Alice node") + require.Equal(t, "Alice", result.Q[0].Name) + require.Equal(t, 30, result.Q[0].Age) + require.Len(t, result.Q[0].Friend, 1, "Alice should have exactly one friend") + require.Equal(t, "Bob", result.Q[0].Friend[0].Name) + require.Equal(t, 25, result.Q[0].Friend[0].Age) +} + +// NoSchemaNode is a test struct used by TestClientLoadDataFileNoSchema. +type NoSchemaNode struct { + UID string `json:"uid,omitempty"` + DType []string `json:"dgraph.type,omitempty"` + Title string `json:"title,omitempty" dgraph:"index=exact"` +} + +// TestClientLoadDataFileNoSchema tests LoadData without WithSchema — the schema +// must already exist in the database. +func TestClientLoadDataFileNoSchema(t *testing.T) { + tmpDir := GetTempDir(t) + client, cleanup := CreateTestClient(t, "file://"+tmpDir) + defer cleanup() + + ctx := context.Background() + + // Manually set up schema first via autoSchema. + err := client.UpdateSchema(ctx, &NoSchemaNode{}) + require.NoError(t, err) + + // Write RDF file with no schema option. + rdfDir := filepath.Join(tmpDir, "rdf_noschema") + require.NoError(t, os.MkdirAll(rdfDir, 0755)) + rdf := `_:a "NoSchemaNode" . +_:a "Hello" . +` + require.NoError(t, os.WriteFile(filepath.Join(rdfDir, "data.rdf"), []byte(rdf), 0600)) + + err = client.LoadData(ctx, rdfDir) + require.NoError(t, err) +} + +// TestClientLoadDataBadSchemaPath verifies that a bad schema path returns an error. +func TestClientLoadDataBadSchemaPath(t *testing.T) { + tmpDir := GetTempDir(t) + client, cleanup := CreateTestClient(t, "file://"+tmpDir) + defer cleanup() + + rdfDir := filepath.Join(tmpDir, "rdf_empty") + require.NoError(t, os.MkdirAll(rdfDir, 0755)) + require.NoError(t, os.WriteFile(filepath.Join(rdfDir, "data.rdf"), []byte(`_:a <name> "x" .`+"\n"), 0600)) + + err := client.LoadData(context.Background(), rdfDir, load.WithSchema("/nonexistent/schema.dgraph")) + require.Error(t, err) + require.Contains(t, err.Error(), "read schema file") +} + +// TestClientLoadDataEmptyDir verifies that an empty data directory returns an error. +func TestClientLoadDataEmptyDir(t *testing.T) { + tmpDir := GetTempDir(t) + client, cleanup := CreateTestClient(t, "file://"+tmpDir) + defer cleanup() + + emptyDir := filepath.Join(tmpDir, "empty_rdf") + require.NoError(t, os.MkdirAll(emptyDir, 0755)) + + err := client.LoadData(context.Background(), emptyDir) + require.Error(t, err) + require.Contains(t, err.Error(), "no data files found") +} + +// TestClientLoadDataWithIndividualOpts verifies that WithBatchSize and WithMutationWorkers +// are accepted and don't cause errors. +func TestClientLoadDataWithIndividualOpts(t *testing.T) { + tmpDir := GetTempDir(t) + client, cleanup := CreateTestClient(t, "file://"+tmpDir) + defer cleanup() + + rdfDir := filepath.Join(tmpDir, "rdf_opts") + require.NoError(t, os.MkdirAll(rdfDir, 0755)) + + rdf := `_:x <dgraph.type> "OptsTestNode" . +_:x <name> "test" . +` + require.NoError(t, os.WriteFile(filepath.Join(rdfDir, "data.rdf"), []byte(rdf), 0600)) + + schemaFile := filepath.Join(tmpDir, "opts_schema.dgraph") + require.NoError(t, os.WriteFile(schemaFile, []byte("name: string @index(exact) .\ntype OptsTestNode {\n name: string\n}\n"), 0600)) + + // Use all option funcs together. + err := client.LoadData(context.Background(), rdfDir, + load.WithSchema(schemaFile), + load.WithBatchSize(5000), + load.WithMutationWorkers(4), + ) + require.NoError(t, err) +} + +// TestClientLoadDataWithOptions verifies the struct-based option func. +func TestClientLoadDataWithOptions(t *testing.T) { + tmpDir := GetTempDir(t) + client, cleanup := CreateTestClient(t, "file://"+tmpDir) + defer cleanup() + + rdfDir := filepath.Join(tmpDir, "rdf_struct_opts") + require.NoError(t, os.MkdirAll(rdfDir, 0755)) + + rdf := `_:y <dgraph.type> "StructOptsNode" . +_:y <name> "struct-test" . +` + require.NoError(t, os.WriteFile(filepath.Join(rdfDir, "data.rdf"), []byte(rdf), 0600)) + + schemaFile := filepath.Join(tmpDir, "struct_opts_schema.dgraph") + require.NoError(t, os.WriteFile(schemaFile, []byte("name: string @index(exact) .\ntype StructOptsNode {\n name: string\n}\n"), 0600)) + + err := client.LoadData(context.Background(), rdfDir, + load.WithOptions(load.Options{ + SchemaPath: schemaFile, + BatchSize: 10000, + MutationWorkers: 8, + }), + ) + require.NoError(t, err) +} + +// TestClientLoadDataFileMatchFiltersFiles verifies that WithFileMatch controls +// which files are loaded. We place two RDF files in a directory but use a +// FileMatch that only accepts one of them. +func TestClientLoadDataFileMatchFiltersFiles(t *testing.T) { + tmpDir := GetTempDir(t) + client, cleanup := CreateTestClient(t, "file://"+tmpDir) + defer cleanup() + + ctx := context.Background() + + rdfDir := filepath.Join(tmpDir, "rdf_filematch") + require.NoError(t, os.MkdirAll(rdfDir, 0755)) + + // File 1: creates Alice + require.NoError(t, os.WriteFile(filepath.Join(rdfDir, "alice.rdf"), + []byte("_:alice <dgraph.type> \"FMPerson\" .\n_:alice <name> \"Alice\" .\n"), 0600)) + + // File 2: creates Bob — should be excluded by FileMatch + require.NoError(t, os.WriteFile(filepath.Join(rdfDir, "bob.rdf"), + []byte("_:bob <dgraph.type> \"FMPerson\" .\n_:bob <name> \"Bob\" .\n"), 0600)) + + schemaFile := filepath.Join(tmpDir, "fm_schema.dgraph") + require.NoError(t, os.WriteFile(schemaFile, + []byte("name: string @index(exact) .\ntype FMPerson {\n name: string\n}\n"), 0600)) + + // Only load alice.rdf + err := client.LoadData(ctx, rdfDir, + load.WithSchema(schemaFile), + load.WithFileMatch(load.FileMatchFunc(func(path string) bool { + return filepath.Base(path) == "alice.rdf" + })), + ) + require.NoError(t, err) + + resp, err := client.QueryRaw(ctx, `{ q(func: type(FMPerson)) { name } }`, nil) + require.NoError(t, err) + + var result struct { + Q []struct { + Name string `json:"name"` + } `json:"q"` + } + require.NoError(t, json.Unmarshal(resp, &result)) + require.Len(t, result.Q, 1, "only Alice should be loaded") + require.Equal(t, "Alice", result.Q[0].Name) +} + +// TestClientLoadDataMultipleFiles verifies loading multiple RDF files from one directory. +func TestClientLoadDataMultipleFiles(t *testing.T) { + tmpDir := GetTempDir(t) + client, cleanup := CreateTestClient(t, "file://"+tmpDir) + defer cleanup() + + ctx := context.Background() + + rdfDir := filepath.Join(tmpDir, "rdf_multi") + require.NoError(t, os.MkdirAll(rdfDir, 0755)) + + require.NoError(t, os.WriteFile(filepath.Join(rdfDir, "01_alice.rdf"), + []byte("_:alice <dgraph.type> \"MPerson\" .\n_:alice <name> \"Alice\" .\n"), 0600)) + require.NoError(t, os.WriteFile(filepath.Join(rdfDir, "02_bob.rdf"), + []byte("_:bob <dgraph.type> \"MPerson\" .\n_:bob <name> \"Bob\" .\n"), 0600)) + + schemaFile := filepath.Join(tmpDir, "multi_schema.dgraph") + require.NoError(t, os.WriteFile(schemaFile, + []byte("name: string @index(exact) .\ntype MPerson {\n name: string\n}\n"), 0600)) + + err := client.LoadData(ctx, rdfDir, load.WithSchema(schemaFile)) + require.NoError(t, err) + + resp, err := client.QueryRaw(ctx, `{ q(func: type(MPerson)) { count(uid) } }`, nil) + require.NoError(t, err) + + var result struct { + Q []struct { + Count int `json:"count"` + } `json:"q"` + } + require.NoError(t, json.Unmarshal(resp, &result)) + require.Len(t, result.Q, 1) + require.Equal(t, 2, result.Q[0].Count, "both files should be loaded") +} + +// TestClientLoadDataGzippedRDF verifies loading gzip-compressed RDF files. +func TestClientLoadDataGzippedRDF(t *testing.T) { + tmpDir := GetTempDir(t) + client, cleanup := CreateTestClient(t, "file://"+tmpDir) + defer cleanup() + + ctx := context.Background() + + rdfDir := filepath.Join(tmpDir, "rdf_gz") + require.NoError(t, os.MkdirAll(rdfDir, 0755)) + + // Write gzipped RDF + var buf bytes.Buffer + gz := gzip.NewWriter(&buf) + _, err := gz.Write([]byte("_:x <dgraph.type> \"GZPerson\" .\n_:x <name> \"Gzipped\" .\n")) + require.NoError(t, err) + require.NoError(t, gz.Close()) + require.NoError(t, os.WriteFile(filepath.Join(rdfDir, "data.rdf.gz"), buf.Bytes(), 0600)) + + schemaFile := filepath.Join(tmpDir, "gz_schema.dgraph") + require.NoError(t, os.WriteFile(schemaFile, + []byte("name: string @index(exact) .\ntype GZPerson {\n name: string\n}\n"), 0600)) + + err = client.LoadData(ctx, rdfDir, load.WithSchema(schemaFile)) + require.NoError(t, err) + + resp, err := client.QueryRaw(ctx, `{ q(func: eq(name, "Gzipped")) { name } }`, nil) + require.NoError(t, err) + require.Contains(t, string(resp), "Gzipped") +} + +// TestClientLoadDataBlankNodeAcrossFiles verifies that blank nodes resolve +// correctly when the same blank node name appears in different files. +func TestClientLoadDataBlankNodeAcrossFiles(t *testing.T) { + tmpDir := GetTempDir(t) + client, cleanup := CreateTestClient(t, "file://"+tmpDir) + defer cleanup() + + ctx := context.Background() + + rdfDir := filepath.Join(tmpDir, "rdf_xref") + require.NoError(t, os.MkdirAll(rdfDir, 0755)) + + // File 1: define Alice with a name + require.NoError(t, os.WriteFile(filepath.Join(rdfDir, "01_nodes.rdf"), + []byte("_:alice <dgraph.type> \"XRefPerson\" .\n_:alice <name> \"Alice\" .\n"+ + "_:bob <dgraph.type> \"XRefPerson\" .\n_:bob <name> \"Bob\" .\n"), 0600)) + + // File 2: add edge from Alice to Bob using same blank node names + require.NoError(t, os.WriteFile(filepath.Join(rdfDir, "02_edges.rdf"), + []byte("_:alice <friend> _:bob .\n"), 0600)) + + schemaFile := filepath.Join(tmpDir, "xref_schema.dgraph") + require.NoError(t, os.WriteFile(schemaFile, + []byte("name: string @index(exact) .\nfriend: [uid] .\ntype XRefPerson {\n name: string\n friend\n}\n"), 0600)) + + err := client.LoadData(ctx, rdfDir, load.WithSchema(schemaFile)) + require.NoError(t, err) + + resp, err := client.QueryRaw(ctx, `{ q(func: eq(name, "Alice")) { name friend { name } } }`, nil) + require.NoError(t, err) + + var result struct { + Q []struct { + Name string `json:"name"` + Friend []struct { + Name string `json:"name"` + } `json:"friend"` + } `json:"q"` + } + require.NoError(t, json.Unmarshal(resp, &result)) + require.Len(t, result.Q, 1) + require.Equal(t, "Alice", result.Q[0].Name) + require.Len(t, result.Q[0].Friend, 1, "Alice should have friend Bob from cross-file blank node") + require.Equal(t, "Bob", result.Q[0].Friend[0].Name) +} + +// TestClientLoadDataGRPC tests LoadData via the dgraph:// URI (gRPC) path. +// This is the critical test — it verifies blank node resolution works across +// batches over gRPC. +func TestClientLoadDataGRPC(t *testing.T) { + addr := os.Getenv("MODUSGRAPH_TEST_ADDR") + if addr == "" { + t.Skip("Skipping: MODUSGRAPH_TEST_ADDR not set") + } + + ctx := context.Background() + + // Create client manually (not via CreateTestClient) so we can DropAll first. + client, err := mg.NewClient("dgraph://" + addr) + require.NoError(t, err) + defer client.Close() + + require.NoError(t, client.DropAll(ctx)) + + // Build RDF data: 100 LoadTestPerson nodes, each linked to the previous one. + tmpDir := t.TempDir() + rdfDir := filepath.Join(tmpDir, "rdf_data") + require.NoError(t, os.MkdirAll(rdfDir, 0755)) + + var rdf string + for i := 0; i < 100; i++ { + blank := fmt.Sprintf("_:person%d", i) + rdf += fmt.Sprintf("%s <dgraph.type> \"LoadTestPerson\" .\n", blank) + rdf += fmt.Sprintf("%s <name> \"Person %d\" .\n", blank, i) + rdf += fmt.Sprintf("%s <age> \"%d\"^^<xs:int> .\n", blank, 20+i) + if i > 0 { + prev := fmt.Sprintf("_:person%d", i-1) + rdf += fmt.Sprintf("%s <friend> %s .\n", blank, prev) + } + } + require.NoError(t, os.WriteFile(filepath.Join(rdfDir, "data.rdf"), []byte(rdf), 0600)) + + // Create schema file. + schemaFile := filepath.Join(tmpDir, "schema.dgraph") + schemaData := `name: string @index(exact, term) . +age: int . +friend: [uid] @reverse . +type LoadTestPerson { + name + age + friend +} +` + require.NoError(t, os.WriteFile(schemaFile, []byte(schemaData), 0600)) + + // Load data with schema. + err = client.LoadData(ctx, rdfDir, load.WithSchema(schemaFile)) + require.NoError(t, err) + + // Verify count is 100. + countQuery := `{ + q(func: type(LoadTestPerson)) { + count(uid) + } + }` + resp, err := client.QueryRaw(ctx, countQuery, nil) + require.NoError(t, err) + + var countResult struct { + Q []struct { + Count int `json:"count"` + } `json:"q"` + } + require.NoError(t, json.Unmarshal(resp, &countResult)) + require.Len(t, countResult.Q, 1) + require.Equal(t, 100, countResult.Q[0].Count, "expected 100 LoadTestPerson nodes") + + // Verify blank node resolution: Person 99's friend should be Person 98. + friendQuery := `{ + q(func: eq(name, "Person 99")) { + name + friend { + name + } + } + }` + resp, err = client.QueryRaw(ctx, friendQuery, nil) + require.NoError(t, err) + + var friendResult struct { + Q []struct { + Name string `json:"name"` + Friend []struct { + Name string `json:"name"` + } `json:"friend"` + } `json:"q"` + } + require.NoError(t, json.Unmarshal(resp, &friendResult)) + + require.Len(t, friendResult.Q, 1, "expected exactly one Person 99 node") + require.Equal(t, "Person 99", friendResult.Q[0].Name) + require.Len(t, friendResult.Q[0].Friend, 1, "Person 99 should have exactly one friend") + require.Equal(t, "Person 98", friendResult.Q[0].Friend[0].Name, + "blank node resolution failed: Person 99's friend should be Person 98") + + // Clean up. + require.NoError(t, client.DropAll(ctx)) +}