diff --git a/client.go b/client.go index be9813b..7834db2 100644 --- a/client.go +++ b/client.go @@ -124,6 +124,7 @@ type clientOptions struct { maxEdgeTraversal int cacheSizeMB int maxRecvMsgSize int + grpcDialOptions []grpc.DialOption namespace string logger logr.Logger validator StructValidator @@ -189,6 +190,18 @@ func WithMaxRecvMsgSize(size int) ClientOpt { } } +// WithGRPCDialOption appends a custom grpc.DialOption applied when opening a +// remote (dgraph://) connection. It is the general escape hatch for gRPC dial +// settings the dedicated options do not cover — TLS transport credentials, +// interceptors, keepalive parameters, and so on. May be supplied multiple +// times; the options are applied in the order given, after any option implied +// by WithMaxRecvMsgSize. Ignored for embedded (file://) URIs. +func WithGRPCDialOption(opt grpc.DialOption) ClientOpt { + return func(o *clientOptions) { + o.grpcDialOptions = append(o.grpcDialOptions, opt) + } +} + // WithValidator sets a validator instance for struct validation. // The validator will be used to validate structs before insert, upsert, and update operations. // If no validator is provided, validation will be skipped. @@ -279,16 +292,26 @@ func NewClient(uri string, opts ...ClientOpt) (Client, error) { client.logger.V(2).Info("Opening new Dgraph connection", "uri", uri) return dgo.Open(uri) } + // Assemble any custom gRPC dial options. maxRecvMsgSize is folded + // into the same mechanism as WithGRPCDialOption so the two compose. + var dialOpts []grpc.DialOption if options.maxRecvMsgSize > 0 { + dialOpts = append(dialOpts, + grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(options.maxRecvMsgSize))) + } + dialOpts = append(dialOpts, options.grpcDialOptions...) + if len(dialOpts) > 0 { endpoint, dgoOpts, err := parseDgraphURI(uri) if err != nil { return nil, err } - dgoOpts = append(dgoOpts, dgo.WithGrpcOption( - grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(options.maxRecvMsgSize)))) + for _, opt := range dialOpts { + dgoOpts = append(dgoOpts, dgo.WithGrpcOption(opt)) + } factory = func() (*dgo.Dgraph, error) { client.logger.V(2).Info("Opening new Dgraph connection", - "uri", uri, "maxRecvMsgSize", options.maxRecvMsgSize) + "uri", uri, "maxRecvMsgSize", options.maxRecvMsgSize, + "grpcDialOptions", len(options.grpcDialOptions)) return dgo.NewClient(endpoint, dgoOpts...) } } @@ -430,9 +453,9 @@ func (c client) key() string { if c.options.embeddingProvider != nil { embeddingKey = fmt.Sprintf("%p", c.options.embeddingProvider) } - return fmt.Sprintf("%s:%t:%d:%d:%d:%d:%s:%s:%s", c.uri, c.options.autoSchema, c.options.poolSize, + return fmt.Sprintf("%s:%t:%d:%d:%d:%d:%s:%s:%s:%d", c.uri, c.options.autoSchema, c.options.poolSize, c.options.maxEdgeTraversal, c.options.cacheSizeMB, c.options.maxRecvMsgSize, - c.options.namespace, validatorKey, embeddingKey) + c.options.namespace, validatorKey, embeddingKey, len(c.options.grpcDialOptions)) } // embeddingProvider implements the embeddingClient interface, exposing the diff --git a/dial_options_test.go b/dial_options_test.go new file mode 100644 index 0000000..c64e257 --- /dev/null +++ b/dial_options_test.go @@ -0,0 +1,30 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package modusgraph + +import ( + "testing" + + "google.golang.org/grpc" +) + +func TestWithGRPCDialOptionAppends(t *testing.T) { + var o clientOptions + WithGRPCDialOption(grpc.WithUserAgent("a"))(&o) + WithGRPCDialOption(grpc.WithUserAgent("b"))(&o) + if got := len(o.grpcDialOptions); got != 2 { + t.Fatalf("expected 2 dial options, got %d", got) + } +} + +func TestKeyDistinguishesGRPCDialOptions(t *testing.T) { + base := client{uri: "dgraph://localhost:9080"} + withOpt := client{uri: "dgraph://localhost:9080"} + WithGRPCDialOption(grpc.WithUserAgent("x"))(&withOpt.options) + if base.key() == withOpt.key() { + t.Fatal("client.key() must differ when grpcDialOptions differ, else clients dedup incorrectly") + } +}