diff --git a/.gitignore b/.gitignore index 8ebb177..c97762e 100644 --- a/.gitignore +++ b/.gitignore @@ -26,3 +26,5 @@ Thumbs.db # Solr data (managed by Docker volume) solr-data/ /solr-mem-indexer +/solr-mem-bench +/solr-mem-backfill diff --git a/Makefile b/Makefile index e776d44..36b9e79 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: build build-indexer install install-indexer test tidy run dev up down logs reset docker-build docker-up config systemd-install systemd-uninstall indexer-install indexer-uninstall launchd-install launchd-install-server launchd-install-indexer launchd-uninstall skill-install skill-uninstall +.PHONY: build build-indexer build-bench build-backfill install install-indexer test tidy run dev up down logs reset docker-build docker-up config systemd-install systemd-uninstall indexer-install indexer-uninstall launchd-install launchd-install-server launchd-install-indexer launchd-uninstall skill-install skill-uninstall bench backfill backfill-dry # Go targets build: @@ -7,7 +7,31 @@ build: build-indexer: go build -o bin/solr-mem-indexer ./cmd/solr-mem-indexer -build-all: build build-indexer +build-all: build build-indexer build-bench build-backfill + +build-bench: + go build -o bin/solr-mem-bench ./cmd/solr-mem-bench + +build-backfill: + go build -o bin/solr-mem-backfill ./cmd/solr-mem-backfill + +# Backfill embeddings on existing memories. Requires OPENAI_API_KEY. +# Use backfill-dry to preview without writing. +BACKFILL_URL ?= http://pax89.local:8983/solr/memories +backfill: build-backfill + ./bin/solr-mem-backfill -solr-url $(BACKFILL_URL) + +backfill-dry: build-backfill + ./bin/solr-mem-backfill -solr-url $(BACKFILL_URL) -dry-run + +# Retrieval benchmark. Seeds the memories collection with namespaced bench-* +# docs (safe to run against a live collection — only touches bench-* IDs) and +# runs the shipped query set. Override BENCH_URL to point elsewhere. 
+BENCH_URL ?= http://pax89.local:8983/solr/memories +bench: build-bench + ./bin/solr-mem-bench -solr-url $(BENCH_URL) -seed \ + -corpus cmd/solr-mem-bench/testdata/corpus.jsonl \ + -queries cmd/solr-mem-bench/testdata/queries.jsonl install: go install ./cmd/solr-mem-server diff --git a/cmd/solr-mem-backfill/backfill.go b/cmd/solr-mem-backfill/backfill.go new file mode 100644 index 0000000..4e00631 --- /dev/null +++ b/cmd/solr-mem-backfill/backfill.go @@ -0,0 +1,240 @@ +// Package main is the solr-mem-backfill tool: it computes embeddings for +// existing memories that don't have one yet and writes them back via atomic +// update. Safe to run incrementally — it always queries "missing embedding" +// first, so a partial run can be resumed just by running it again. +package main + +import ( + "context" + "fmt" + "log" + "sync" + "time" + + "github.com/arreyder/solr-mem/internal/embed" + "github.com/arreyder/solr-mem/internal/solr" +) + +// Options controls a backfill run. +type Options struct { + BatchSize int // docs per Solr page (default 50) + Concurrency int // parallel embed calls (default 4) + DryRun bool // compute but don't write + MaxDocs int // cap total processed (0 = unlimited) + Force bool // re-embed docs even if they already have an embedding + Pause time.Duration // sleep between batches (rate-limit headroom) +} + +// Stats summarize a run. +type Stats struct { + Scanned int // docs returned from Solr + Skipped int // already had an embedding (and !Force) + Embedded int // docs we computed vectors for + Written int // docs we actually wrote back (0 when DryRun) + Errors int // embed failures; these get logged but don't abort +} + +// Run scans the memories collection for docs needing embeddings, computes +// them, and writes back. It returns when the backlog is empty, MaxDocs is +// hit, or ctx is canceled. 
+func Run(ctx context.Context, client *solr.Client, provider embed.Provider, opts Options) (Stats, error) { + if provider == nil { + return Stats{}, fmt.Errorf("backfill: no embedding provider configured") + } + if opts.BatchSize <= 0 { + opts.BatchSize = 50 + } + if opts.Concurrency <= 0 { + opts.Concurrency = 4 + } + + // Track IDs we've already attempted this run. Without this, a doc that + // fails to embed (and therefore never gets its field written) would keep + // showing up in every "-_exists_:embedding" query and loop forever. + seen := make(map[string]bool) + + var stats Stats + for { + if opts.MaxDocs > 0 && stats.Embedded >= opts.MaxDocs { + break + } + if ctx.Err() != nil { + return stats, ctx.Err() + } + + batch, err := fetchBatch(ctx, client, opts.Force, opts.BatchSize) + if err != nil { + return stats, fmt.Errorf("fetch batch: %w", err) + } + if len(batch) == 0 { + break + } + stats.Scanned += len(batch) + + // Build the work list: skip docs already attempted this run, and + // skip docs with existing embeddings unless Force. + work := make([]memoryDoc, 0, len(batch)) + for _, d := range batch { + if seen[d.ID] { + continue + } + seen[d.ID] = true + if !opts.Force && d.HasEmbedding { + stats.Skipped++ + continue + } + work = append(work, d) + } + if opts.MaxDocs > 0 && stats.Embedded+len(work) > opts.MaxDocs { + work = work[:opts.MaxDocs-stats.Embedded] + } + + // If every doc in this page was already seen, we've exhausted the + // backlog that this run can make progress on. 
+ if len(work) == 0 { + break + } + + updates := embedBatch(ctx, provider, work, opts.Concurrency, &stats) + + if len(updates) > 0 && !opts.DryRun { + if err := client.BulkUpdate(ctx, updates); err != nil { + return stats, fmt.Errorf("bulk update: %w", err) + } + stats.Written += len(updates) + } + + log.Printf("backfill: scanned=%d skipped=%d embedded=%d written=%d errors=%d", + stats.Scanned, stats.Skipped, stats.Embedded, stats.Written, stats.Errors) + + if opts.Pause > 0 { + select { + case <-ctx.Done(): + return stats, ctx.Err() + case <-time.After(opts.Pause): + } + } + } + return stats, nil +} + +type memoryDoc struct { + ID string + Title string + Content string + HasEmbedding bool +} + +// fetchBatch returns up to batchSize memory docs. When !force, scopes to +// docs missing the embedding field so subsequent calls make progress as we +// write back. +func fetchBatch(ctx context.Context, client *solr.Client, force bool, batchSize int) ([]memoryDoc, error) { + params := solr.QueryParams{ + Query: "*:*", + Rows: batchSize, + Fields: []string{"id", "title", "content", "embedding"}, + Sort: "created_at asc", + } + if !force { + // "-_exists_:embedding" is Solr's canonical missing-field filter + // and works for any field type, including DenseVectorField. + params.FilterQueries = append(params.FilterQueries, "-_exists_:embedding") + } + resp, err := client.Query(ctx, params) + if err != nil { + return nil, err + } + out := make([]memoryDoc, 0, len(resp.Docs)) + for _, d := range resp.Docs { + id, _ := d["id"].(string) + title, _ := d["title"].(string) + content, _ := d["content"].(string) + has := false + if v, ok := d["embedding"]; ok { + if arr, ok := v.([]any); ok && len(arr) > 0 { + has = true + } + } + out = append(out, memoryDoc{ID: id, Title: title, Content: content, HasEmbedding: has}) + } + return out, nil +} + +// embedBatch runs opts.Concurrency embed calls in parallel and returns the +// atomic-update payloads for docs that succeeded. 
Failures are logged and +// counted but don't stop the run. +func embedBatch(ctx context.Context, provider embed.Provider, docs []memoryDoc, concurrency int, stats *Stats) []map[string]any { + if len(docs) == 0 { + return nil + } + + type result struct { + ID string + Embedding []float32 + Err error + } + + in := make(chan memoryDoc) + out := make(chan result) + + var wg sync.WaitGroup + for i := 0; i < concurrency; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for d := range in { + text := buildEmbedText(d.Title, d.Content) + if text == "" { + out <- result{ID: d.ID, Err: fmt.Errorf("empty text")} + continue + } + cctx, cancel := context.WithTimeout(ctx, 30*time.Second) + v, err := provider.Embed(cctx, text) + cancel() + out <- result{ID: d.ID, Embedding: v, Err: err} + } + }() + } + + go func() { + for _, d := range docs { + select { + case <-ctx.Done(): + break + case in <- d: + } + } + close(in) + wg.Wait() + close(out) + }() + + updates := make([]map[string]any, 0, len(docs)) + for r := range out { + if r.Err != nil { + stats.Errors++ + log.Printf("backfill: embed %s failed: %v", r.ID, r.Err) + continue + } + stats.Embedded++ + updates = append(updates, map[string]any{ + "id": r.ID, + "embedding": map[string]any{"set": r.Embedding}, + }) + } + return updates +} + +// buildEmbedText joins title and content with a blank line. Mirrors the +// shape used at write time in the server (embed_helper.go) so backfilled +// vectors live in the same space as live writes. 
// buildEmbedText returns the text handed to the embedding provider for one
// memory: "title\n\ncontent" when both parts are present, otherwise whichever
// part is non-empty, and "" when neither is.
func buildEmbedText(title, content string) string {
	switch {
	case title == "":
		return content
	case content == "":
		return title
	default:
		return title + "\n\n" + content
	}
}

// fakeProvider is a test double for the embedding provider. Every input is
// appended to calls (guarded by mu). Inputs that have an entry in Fail get
// that entry's error back; all others receive the fixed 3-dim vector
// {0.1, 0.2, 0.3}.
type fakeProvider struct {
	mu    sync.Mutex
	calls []string
	Fail  map[string]error // text -> err
}

// record notes one Embed invocation under the mutex.
func (f *fakeProvider) record(text string) {
	f.mu.Lock()
	defer f.mu.Unlock()
	f.calls = append(f.calls, text)
}

// Embed records text and returns either the configured failure or the fixed
// vector. ctx is accepted for interface compatibility and otherwise unused.
func (f *fakeProvider) Embed(ctx context.Context, text string) ([]float32, error) {
	f.record(text)
	if failure, configured := f.Fail[text]; configured {
		return nil, failure
	}
	return []float32{0.1, 0.2, 0.3}, nil
}

// Dim reports the length of the vectors Embed returns.
func (f *fakeProvider) Dim() int { return 3 }

// Name identifies this provider as "fake".
func (f *fakeProvider) Name() string { return "fake" }
+type fakeSolr struct { + mu sync.Mutex + docs []map[string]any // mutable "collection" + queries int + updates [][]map[string]any // captured update payloads per request +} + +func newFakeSolr(initial []map[string]any) *fakeSolr { + return &fakeSolr{docs: initial} +} + +func (s *fakeSolr) handler() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.HasSuffix(r.URL.Path, "/select"): + s.handleSelect(w, r) + case strings.Contains(r.URL.Path, "/update"): + s.handleUpdate(w, r) + default: + http.Error(w, "not found: "+r.URL.Path, 404) + } + }) +} + +func (s *fakeSolr) handleSelect(w http.ResponseWriter, r *http.Request) { + s.mu.Lock() + defer s.mu.Unlock() + s.queries++ + + fqs := r.URL.Query()["fq"] + missingOnly := false + for _, fq := range fqs { + if fq == "-_exists_:embedding" { + missingOnly = true + break + } + } + + var out []map[string]any + for _, d := range s.docs { + if missingOnly { + if v, ok := d["embedding"]; ok && v != nil { + if arr, ok := v.([]any); ok && len(arr) > 0 { + continue + } + if _, ok := v.([]float32); ok { + continue + } + } + } + out = append(out, d) + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "response": map[string]any{ + "numFound": len(out), + "start": 0, + "docs": out, + }, + }) +} + +func (s *fakeSolr) handleUpdate(w http.ResponseWriter, r *http.Request) { + s.mu.Lock() + defer s.mu.Unlock() + body, _ := io.ReadAll(r.Body) + var payload []map[string]any + if err := json.Unmarshal(body, &payload); err != nil { + http.Error(w, err.Error(), 400) + return + } + s.updates = append(s.updates, payload) + + // Apply atomic updates to our in-memory corpus. 
+ for _, upd := range payload { + id, _ := upd["id"].(string) + for i, d := range s.docs { + if d["id"] != id { + continue + } + for k, v := range upd { + if k == "id" { + continue + } + if setter, ok := v.(map[string]any); ok { + if val, has := setter["set"]; has { + // Normalize to []any for consistent downstream checks. + if vec, ok := val.([]float32); ok { + arr := make([]any, len(vec)) + for j, f := range vec { + arr[j] = float64(f) + } + val = arr + } + s.docs[i][k] = val + } + } + } + } + } + _, _ = w.Write([]byte(`{"responseHeader":{"status":0}}`)) +} + +func TestRunEmbedsMissingDocs(t *testing.T) { + fs := newFakeSolr([]map[string]any{ + {"id": "a", "title": "A", "content": "alpha"}, + {"id": "b", "title": "B", "content": "beta"}, + }) + srv := httptest.NewServer(fs.handler()) + defer srv.Close() + + client := solr.NewClient(srv.URL) + provider := &fakeProvider{} + + stats, err := Run(context.Background(), client, provider, Options{BatchSize: 10, Concurrency: 2}) + if err != nil { + t.Fatalf("run: %v", err) + } + if stats.Embedded != 2 { + t.Errorf("expected 2 embedded, got %d", stats.Embedded) + } + if stats.Written != 2 { + t.Errorf("expected 2 written, got %d", stats.Written) + } + if stats.Errors != 0 { + t.Errorf("expected 0 errors, got %d", stats.Errors) + } + if len(provider.calls) != 2 { + t.Errorf("expected 2 provider calls, got %d", len(provider.calls)) + } +} + +func TestRunDryRunSkipsWrite(t *testing.T) { + fs := newFakeSolr([]map[string]any{ + {"id": "a", "title": "A", "content": "alpha"}, + }) + srv := httptest.NewServer(fs.handler()) + defer srv.Close() + + client := solr.NewClient(srv.URL) + provider := &fakeProvider{} + + stats, err := Run(context.Background(), client, provider, Options{DryRun: true}) + if err != nil { + t.Fatalf("run: %v", err) + } + if stats.Embedded != 1 { + t.Errorf("expected 1 embedded, got %d", stats.Embedded) + } + if stats.Written != 0 { + t.Errorf("dry-run should write nothing, got %d", stats.Written) + } + if 
len(fs.updates) != 0 { + t.Errorf("dry-run should not POST updates, got %d payloads", len(fs.updates)) + } +} + +func TestRunRespectsMaxDocs(t *testing.T) { + var docs []map[string]any + for i := 0; i < 10; i++ { + docs = append(docs, map[string]any{ + "id": fmt.Sprintf("d-%d", i), + "title": "t", + "content": "c", + }) + } + fs := newFakeSolr(docs) + srv := httptest.NewServer(fs.handler()) + defer srv.Close() + + client := solr.NewClient(srv.URL) + provider := &fakeProvider{} + + stats, err := Run(context.Background(), client, provider, Options{BatchSize: 3, MaxDocs: 5}) + if err != nil { + t.Fatalf("run: %v", err) + } + if stats.Embedded != 5 { + t.Errorf("MaxDocs=5: expected 5 embedded, got %d", stats.Embedded) + } +} + +func TestRunContinuesThroughEmbedErrors(t *testing.T) { + fs := newFakeSolr([]map[string]any{ + {"id": "good", "title": "G", "content": "g"}, + {"id": "bad", "title": "B", "content": "b"}, + }) + srv := httptest.NewServer(fs.handler()) + defer srv.Close() + + client := solr.NewClient(srv.URL) + provider := &fakeProvider{ + Fail: map[string]error{"B\n\nb": fmt.Errorf("simulated failure")}, + } + + stats, err := Run(context.Background(), client, provider, Options{Concurrency: 1}) + if err != nil { + t.Fatalf("run: %v", err) + } + if stats.Errors != 1 { + t.Errorf("expected 1 error, got %d", stats.Errors) + } + if stats.Embedded != 1 { + t.Errorf("expected 1 successful embed, got %d", stats.Embedded) + } + if stats.Written != 1 { + t.Errorf("expected 1 write, got %d", stats.Written) + } +} + +func TestRunStopsWhenNoMissingRemain(t *testing.T) { + // After the first batch updates every doc, the next fetch should return + // zero and we terminate. 
+ fs := newFakeSolr([]map[string]any{ + {"id": "a", "title": "A", "content": "a"}, + }) + srv := httptest.NewServer(fs.handler()) + defer srv.Close() + + client := solr.NewClient(srv.URL) + provider := &fakeProvider{} + + _, err := Run(context.Background(), client, provider, Options{BatchSize: 10}) + if err != nil { + t.Fatalf("run: %v", err) + } + // Expect exactly 2 selects: first returns the doc, second returns empty. + if fs.queries != 2 { + t.Errorf("expected 2 queries (work + empty probe), got %d", fs.queries) + } +} + +func TestRunRefusesWithoutProvider(t *testing.T) { + fs := newFakeSolr(nil) + srv := httptest.NewServer(fs.handler()) + defer srv.Close() + client := solr.NewClient(srv.URL) + + if _, err := Run(context.Background(), client, nil, Options{}); err == nil { + t.Error("expected error when provider is nil") + } +} + +func TestBuildEmbedText(t *testing.T) { + cases := []struct{ title, content, want string }{ + {"t", "c", "t\n\nc"}, + {"", "c", "c"}, + {"t", "", "t"}, + {"", "", ""}, + } + for _, c := range cases { + if got := buildEmbedText(c.title, c.content); got != c.want { + t.Errorf("buildEmbedText(%q,%q)=%q, want %q", c.title, c.content, got, c.want) + } + } +} diff --git a/cmd/solr-mem-backfill/main.go b/cmd/solr-mem-backfill/main.go new file mode 100644 index 0000000..a8468e1 --- /dev/null +++ b/cmd/solr-mem-backfill/main.go @@ -0,0 +1,62 @@ +package main + +import ( + "context" + "flag" + "log" + "os" + "os/signal" + "syscall" + "time" + + "github.com/arreyder/solr-mem/internal/embed" + "github.com/arreyder/solr-mem/internal/solr" +) + +func main() { + var ( + solrURL = flag.String("solr-url", getenv("SOLR_URL", "http://localhost:8983/solr/memories"), "Memories collection URL") + batchSize = flag.Int("batch-size", 50, "Docs per Solr page") + concurrency = flag.Int("concurrency", 4, "Parallel embed calls") + dryRun = flag.Bool("dry-run", false, "Compute embeddings but do not write") + maxDocs = flag.Int("max-docs", 0, "Cap total docs 
embedded (0 = unlimited)") + force = flag.Bool("force", false, "Re-embed docs that already have an embedding") + pauseMS = flag.Int("pause-ms", 0, "Sleep between batches in milliseconds") + ) + flag.Parse() + + provider, err := embed.FromEnv() + if err != nil { + log.Fatalf("embed provider init: %v", err) + } + if provider == nil { + log.Fatalf("no embedding provider configured — set OPENAI_API_KEY") + } + log.Printf("backfill: provider=%s dim=%d", provider.Name(), provider.Dim()) + + client := solr.NewClient(*solrURL) + + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer cancel() + + stats, err := Run(ctx, client, provider, Options{ + BatchSize: *batchSize, + Concurrency: *concurrency, + DryRun: *dryRun, + MaxDocs: *maxDocs, + Force: *force, + Pause: time.Duration(*pauseMS) * time.Millisecond, + }) + if err != nil { + log.Fatalf("backfill: %v", err) + } + log.Printf("backfill done: scanned=%d skipped=%d embedded=%d written=%d errors=%d", + stats.Scanned, stats.Skipped, stats.Embedded, stats.Written, stats.Errors) +} + +func getenv(k, def string) string { + if v := os.Getenv(k); v != "" { + return v + } + return def +} diff --git a/cmd/solr-mem-bench/main.go b/cmd/solr-mem-bench/main.go new file mode 100644 index 0000000..d08f52f --- /dev/null +++ b/cmd/solr-mem-bench/main.go @@ -0,0 +1,230 @@ +// solr-mem-bench is a small harness for measuring memory retrieval quality. +// +// Usage: +// +// solr-mem-bench -solr-url http://localhost:8983/solr -collection memories_bench \ +// -corpus testdata/corpus.jsonl -queries testdata/queries.jsonl -seed +// +// The -seed flag clears the target collection and writes the corpus before +// running queries. Without -seed, the harness assumes the collection is +// already populated with the gold IDs referenced by queries. 
+// +// Gold format (JSONL, one per line): +// +// corpus: {"id":"mem-001","title":"...","content":"...","tags":["..."]} +// queries: {"id":"q-001","text":"database N+1 fix","gold":["mem-007"]} +package main + +import ( + "bufio" + "context" + "encoding/json" + "flag" + "fmt" + "log" + "os" + "time" + + "github.com/arreyder/solr-mem/internal/solr" +) + +type corpusItem struct { + ID string `json:"id"` + Title string `json:"title"` + Content string `json:"content"` + Tags []string `json:"tags,omitempty"` +} + +type queryItem struct { + ID string `json:"id"` + Text string `json:"text"` + Gold []string `json:"gold"` +} + +func main() { + var ( + solrURL = flag.String("solr-url", "http://localhost:8983/solr/memories", "Base URL of the memories collection") + corpusPath = flag.String("corpus", "cmd/solr-mem-bench/testdata/corpus.jsonl", "Path to corpus JSONL") + queryPath = flag.String("queries", "cmd/solr-mem-bench/testdata/queries.jsonl", "Path to queries JSONL") + seed = flag.Bool("seed", false, "Clear the collection and write the corpus before running queries") + topK = flag.Int("topk", 10, "Number of hits to retrieve per query") + variant = flag.String("variant", "bm25", "Retrieval variant name (for labeling output)") + ) + flag.Parse() + + corpus, err := loadCorpus(*corpusPath) + if err != nil { + log.Fatalf("load corpus: %v", err) + } + queries, err := loadQueries(*queryPath) + if err != nil { + log.Fatalf("load queries: %v", err) + } + log.Printf("loaded %d corpus items, %d queries", len(corpus), len(queries)) + + client := solr.NewClient(*solrURL) + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + + if *seed { + log.Printf("seeding collection at %s (clearing first)", *solrURL) + if err := seedCorpus(ctx, client, corpus); err != nil { + log.Fatalf("seed: %v", err) + } + // Give Solr a moment to commit. 
+ time.Sleep(1 * time.Second) + } + + agg := NewAggregate(1, 3, 5, 10) + perQuery := make([]queryResult, 0, len(queries)) + + start := time.Now() + for _, q := range queries { + hits, err := runQuery(ctx, client, q.Text, *topK) + if err != nil { + log.Printf("query %s failed: %v", q.ID, err) + continue + } + agg.Add(q.Gold, hits) + perQuery = append(perQuery, queryResult{ID: q.ID, Text: q.Text, Hits: hits, Gold: q.Gold}) + } + elapsed := time.Since(start) + + agg.Finalize() + printReport(*variant, agg, perQuery, elapsed) +} + +type queryResult struct { + ID string + Text string + Hits []string + Gold []string +} + +func loadCorpus(path string) ([]corpusItem, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + var out []corpusItem + s := bufio.NewScanner(f) + s.Buffer(make([]byte, 0, 64*1024), 1024*1024) + for s.Scan() { + line := s.Bytes() + if len(line) == 0 { + continue + } + var item corpusItem + if err := json.Unmarshal(line, &item); err != nil { + return nil, fmt.Errorf("parse corpus line: %w", err) + } + out = append(out, item) + } + return out, s.Err() +} + +func loadQueries(path string) ([]queryItem, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + var out []queryItem + s := bufio.NewScanner(f) + s.Buffer(make([]byte, 0, 64*1024), 1024*1024) + for s.Scan() { + line := s.Bytes() + if len(line) == 0 { + continue + } + var item queryItem + if err := json.Unmarshal(line, &item); err != nil { + return nil, fmt.Errorf("parse query line: %w", err) + } + out = append(out, item) + } + return out, s.Err() +} + +// seedCorpus writes the bench corpus to Solr. It scopes its cleanup to +// IDs starting with "bench-" so it is safe to run against a live memories +// collection — it will only ever touch its own namespaced docs. 
+func seedCorpus(ctx context.Context, client *solr.Client, items []corpusItem) error { + for _, it := range items { + if !isBenchID(it.ID) { + return fmt.Errorf("corpus id %q must start with 'bench-' for safe seeding", it.ID) + } + } + if err := client.DeleteByQuery(ctx, "id:bench-*"); err != nil { + return fmt.Errorf("clear bench docs: %w", err) + } + docs := make([]solr.Document, 0, len(items)) + now := time.Now().UTC() + for _, it := range items { + docs = append(docs, solr.Document{ + ID: it.ID, + Title: it.Title, + Content: it.Content, + Tags: it.Tags, + CreatedAt: now, + UpdatedAt: now, + Lifetime: "permanent", + Format: "prose", + }) + } + return client.Add(ctx, docs...) +} + +func isBenchID(s string) bool { + return len(s) > 6 && s[:6] == "bench-" +} + +func runQuery(ctx context.Context, client *solr.Client, text string, topK int) ([]string, error) { + params := solr.QueryParams{ + Query: text, + Rows: topK, + Fields: []string{"id"}, + } + resp, err := client.Query(ctx, params) + if err != nil { + return nil, err + } + out := make([]string, 0, len(resp.Docs)) + for _, d := range resp.Docs { + if id, ok := d["id"].(string); ok { + out = append(out, id) + } + } + return out, nil +} + +func printReport(variant string, agg *Aggregate, results []queryResult, elapsed time.Duration) { + fmt.Printf("# solr-mem retrieval benchmark\n\n") + fmt.Printf("Variant: `%s` | Queries: %d | Elapsed: %s\n\n", variant, agg.Queries, elapsed.Round(time.Millisecond)) + fmt.Printf("| Metric | Value |\n|---|---|\n") + for _, k := range []int{1, 3, 5, 10} { + if v, ok := agg.RecallAt[k]; ok { + fmt.Printf("| R@%d | %.3f |\n", k, v) + } + } + fmt.Printf("| MRR | %.3f |\n\n", agg.MRR) + + fmt.Printf("## Per-query breakdown\n\n") + fmt.Printf("| Query | Gold | R@5 | MRR | Top hits |\n|---|---|---|---|---|\n") + for _, r := range results { + top := r.Hits + if len(top) > 5 { + top = top[:5] + } + fmt.Printf("| %s (%s) | %v | %.2f | %.2f | %v |\n", + r.ID, truncate(r.Text, 40), r.Gold, 
RecallAtK(r.Gold, r.Hits, 5), MRR(r.Gold, r.Hits), top) + } +} + +func truncate(s string, n int) string { + if len(s) <= n { + return s + } + return s[:n-1] + "…" +} diff --git a/cmd/solr-mem-bench/metrics.go b/cmd/solr-mem-bench/metrics.go new file mode 100644 index 0000000..6853dcc --- /dev/null +++ b/cmd/solr-mem-bench/metrics.go @@ -0,0 +1,78 @@ +package main + +// RecallAtK returns the fraction of gold IDs that appear in the top-K hits. +// gold is the set of relevant IDs for a query; hits is the ranked result list. +func RecallAtK(gold, hits []string, k int) float64 { + if len(gold) == 0 { + return 0 + } + if k <= 0 || k > len(hits) { + k = len(hits) + } + goldSet := make(map[string]struct{}, len(gold)) + for _, g := range gold { + goldSet[g] = struct{}{} + } + found := 0 + for i := 0; i < k; i++ { + if _, ok := goldSet[hits[i]]; ok { + found++ + } + } + return float64(found) / float64(len(gold)) +} + +// MRR returns the reciprocal rank of the first gold hit in the ranked list, +// or 0 if no gold item was retrieved. Ranks are 1-indexed. +func MRR(gold, hits []string) float64 { + if len(gold) == 0 || len(hits) == 0 { + return 0 + } + goldSet := make(map[string]struct{}, len(gold)) + for _, g := range gold { + goldSet[g] = struct{}{} + } + for i, h := range hits { + if _, ok := goldSet[h]; ok { + return 1.0 / float64(i+1) + } + } + return 0 +} + +// Aggregate holds per-run averaged metrics. +type Aggregate struct { + Queries int + RecallAt map[int]float64 // k -> mean recall + MRR float64 +} + +// NewAggregate accumulates per-query metrics over a set of queries. +func NewAggregate(ks ...int) *Aggregate { + m := make(map[int]float64, len(ks)) + for _, k := range ks { + m[k] = 0 + } + return &Aggregate{RecallAt: m} +} + +// Add records one query's hits against its gold. 
+func (a *Aggregate) Add(gold, hits []string) { + a.Queries++ + for k := range a.RecallAt { + a.RecallAt[k] += RecallAtK(gold, hits, k) + } + a.MRR += MRR(gold, hits) +} + +// Finalize divides sums by query count, producing mean metrics. +func (a *Aggregate) Finalize() { + if a.Queries == 0 { + return + } + n := float64(a.Queries) + for k, v := range a.RecallAt { + a.RecallAt[k] = v / n + } + a.MRR /= n +} diff --git a/cmd/solr-mem-bench/metrics_test.go b/cmd/solr-mem-bench/metrics_test.go new file mode 100644 index 0000000..a447ee5 --- /dev/null +++ b/cmd/solr-mem-bench/metrics_test.go @@ -0,0 +1,111 @@ +package main + +import ( + "math" + "testing" +) + +func approx(a, b float64) bool { return math.Abs(a-b) < 1e-9 } + +func TestRecallAtKPerfect(t *testing.T) { + gold := []string{"a", "b"} + hits := []string{"a", "b", "c", "d", "e"} + if got := RecallAtK(gold, hits, 5); !approx(got, 1.0) { + t.Errorf("R@5 should be 1.0, got %f", got) + } +} + +func TestRecallAtKPartial(t *testing.T) { + gold := []string{"a", "b", "c"} + hits := []string{"a", "x", "y", "z", "b"} + // 2 of 3 gold found in top 5. + if got := RecallAtK(gold, hits, 5); !approx(got, 2.0/3.0) { + t.Errorf("expected 2/3, got %f", got) + } + // Only 1 in top 2. 
+ if got := RecallAtK(gold, hits, 2); !approx(got, 1.0/3.0) { + t.Errorf("expected 1/3, got %f", got) + } +} + +func TestRecallAtKNoHits(t *testing.T) { + gold := []string{"a"} + hits := []string{"x", "y", "z"} + if got := RecallAtK(gold, hits, 3); !approx(got, 0) { + t.Errorf("expected 0, got %f", got) + } +} + +func TestRecallAtKKLargerThanHits(t *testing.T) { + gold := []string{"a"} + hits := []string{"a", "b"} + if got := RecallAtK(gold, hits, 10); !approx(got, 1.0) { + t.Errorf("k > len(hits) should not error, got %f", got) + } +} + +func TestRecallAtKEmptyGold(t *testing.T) { + if got := RecallAtK(nil, []string{"a"}, 5); got != 0 { + t.Errorf("no gold → 0, got %f", got) + } +} + +func TestMRRFirstPosition(t *testing.T) { + gold := []string{"a"} + hits := []string{"a", "b", "c"} + if got := MRR(gold, hits); !approx(got, 1.0) { + t.Errorf("first-rank gold → 1.0, got %f", got) + } +} + +func TestMRRThirdPosition(t *testing.T) { + gold := []string{"c"} + hits := []string{"a", "b", "c"} + if got := MRR(gold, hits); !approx(got, 1.0/3.0) { + t.Errorf("rank 3 → 1/3, got %f", got) + } +} + +func TestMRRNoMatch(t *testing.T) { + gold := []string{"z"} + hits := []string{"a", "b", "c"} + if got := MRR(gold, hits); got != 0 { + t.Errorf("no match → 0, got %f", got) + } +} + +func TestMRRMultipleGoldTakesBest(t *testing.T) { + // Two gold items; MRR uses the best (earliest) rank. 
+ gold := []string{"c", "a"} + hits := []string{"a", "b", "c"} + if got := MRR(gold, hits); !approx(got, 1.0) { + t.Errorf("best rank should win: expected 1.0, got %f", got) + } +} + +func TestAggregateEndToEnd(t *testing.T) { + agg := NewAggregate(5, 10) + agg.Add([]string{"a"}, []string{"a", "b", "c"}) // R@5=1, R@10=1, MRR=1 + agg.Add([]string{"x"}, []string{"a", "b", "c"}) // all zero + agg.Add([]string{"c"}, []string{"a", "b", "c"}) // R@5=1, R@10=1, MRR=1/3 + agg.Finalize() + + // Means across 3 queries: R@5 = 2/3, R@10 = 2/3, MRR = (1 + 0 + 1/3)/3 = 4/9 + if !approx(agg.RecallAt[5], 2.0/3.0) { + t.Errorf("R@5 mean expected 2/3, got %f", agg.RecallAt[5]) + } + if !approx(agg.RecallAt[10], 2.0/3.0) { + t.Errorf("R@10 mean expected 2/3, got %f", agg.RecallAt[10]) + } + if !approx(agg.MRR, 4.0/9.0) { + t.Errorf("MRR mean expected 4/9, got %f", agg.MRR) + } +} + +func TestAggregateEmpty(t *testing.T) { + agg := NewAggregate(5) + agg.Finalize() // should not panic on zero queries + if agg.MRR != 0 || agg.RecallAt[5] != 0 { + t.Errorf("empty aggregate should stay zero") + } +} diff --git a/cmd/solr-mem-bench/testdata/corpus.jsonl b/cmd/solr-mem-bench/testdata/corpus.jsonl new file mode 100644 index 0000000..b71784f --- /dev/null +++ b/cmd/solr-mem-bench/testdata/corpus.jsonl @@ -0,0 +1,30 @@ +{"id":"bench-001","title":"JWT auth via jose middleware","content":"Project uses jose (not jsonwebtoken) for Edge runtime compatibility. Middleware in src/middleware/auth.ts validates Authorization: Bearer tokens, decodes JWT claims, attaches user to request context. Token signing uses RS256 with keys rotated weekly.","tags":["auth","jwt","middleware","edge"]} +{"id":"bench-002","title":"N+1 query fix on orders list","content":"OrdersController.list was fetching each order's customer in a separate SELECT, causing N+1. Switched to a single JOIN with customers table. p95 latency dropped from 1200ms to 140ms on the orders index page. 
Related to ticket ORD-413.","tags":["database","performance","n+1","orders"]} +{"id":"bench-003","title":"Rate limiting with Redis token bucket","content":"API gateway uses a sliding-window token-bucket algorithm backed by Redis. 100 requests per minute per API key, 1000 per org. Implementation in pkg/ratelimit/bucket.go. Tokens refill in 600ms increments to avoid thundering herd at minute boundaries.","tags":["rate-limit","redis","api","gateway"]} +{"id":"bench-004","title":"Postgres connection pool sizing","content":"Production pool maxed at 100 connections per service instance, 5 instances → 500 total. PgBouncer in transaction mode in front. Lowering pool size to 50 per instance reduced idle connection overhead by 30% with no throughput loss.","tags":["postgres","connection-pool","pgbouncer","performance"]} +{"id":"bench-005","title":"React context provider pattern for feature flags","content":"FeatureFlagProvider wraps the app root, hydrates flags once from /api/flags at mount, exposes useFeatureFlag(name) hook. Flag changes from admin UI push via WebSocket and trigger provider re-render. Stale-while-revalidate fallback to localStorage for offline.","tags":["react","feature-flags","websocket","frontend"]} +{"id":"bench-006","title":"Memory leak in WebSocket reconnect handler","content":"The reconnect exponential-backoff timer kept a closure reference to the old socket, preventing GC. Fix: explicitly null socket.onmessage and socket.onclose in cleanup, store the current timer ID outside the closure so successive retries replace it.","tags":["websocket","memory-leak","javascript","frontend"]} +{"id":"bench-007","title":"Slow Solr range facet on timestamp field","content":"facet.range on created_at with gap=+1DAY over 90 days was taking 3s. Switched to docValues=true on the pdate field and enabled facet.method=dv. Latency dropped to 80ms. 
Applies to any range faceting on high-cardinality date fields.","tags":["solr","facet","performance","date"]} +{"id":"bench-008","title":"Docker Compose healthcheck for Solr","content":"solr service in docker-compose.yml needs healthcheck: curl -sf http://localhost:8983/solr/admin/ping. Without it, dependent services start before the configset is loaded, causing schema-not-found errors on first run. 10s interval, 5 retries.","tags":["docker","solr","healthcheck","devops"]} +{"id":"bench-009","title":"TypeScript strict mode migration","content":"Migrated tsconfig with strict: true in stages: strictNullChecks first (revealed 180 null-deref bugs), then noImplicitAny (caught 40 places passing any to third-party APIs). Suppress with // @ts-expect-error initially, burn down incrementally.","tags":["typescript","migration","strict","types"]} +{"id":"bench-010","title":"Kafka consumer lag alerting","content":"Alert fires when consumer group lag exceeds 10k for 5 minutes on any partition. Check by: kafka-consumer-groups --describe --group X. Most common root cause: slow downstream DB writes. Fix is usually raising batch size, not scaling consumers.","tags":["kafka","alerting","lag","monitoring"]} +{"id":"bench-011","title":"GraphQL N+1 via DataLoader","content":"Nested resolvers for User.posts and Post.comments were firing one query per row. Wrapped each resolver in a DataLoader that batches IDs per request. Cut DB round-trips on the feed query from 300 to 3.","tags":["graphql","n+1","dataloader","database"]} +{"id":"bench-012","title":"AWS S3 presigned URL expiry","content":"Upload URLs signed with 15-min expiry. Users on slow networks occasionally fail right at the boundary. Bumped to 60 min for uploads, kept 5 min for downloads. 
Monitoring the RequestExpired 4xx count on the bucket metrics dashboard.","tags":["aws","s3","presigned-url","upload"]} +{"id":"bench-013","title":"Go generics for repository pattern","content":"Replaced a dozen Find/Create/Delete methods across 6 repos with one generic[T any] Repository struct. Used constraints.Ordered for sortable fields. Reduced LOC by ~40% and caught two places where type mismatches were silently compiling.","tags":["go","generics","repository","refactor"]} +{"id":"bench-014","title":"Flaky integration test on CI","content":"auth_integration_test.go failed 1-in-20 on GitHub Actions, always passed locally. Root cause: test reused a Redis key across parallel goroutines; local runs happened to serialize. Fix: use t.Name() as key prefix to isolate.","tags":["flaky","test","redis","ci"]} +{"id":"bench-015","title":"CORS preflight cache tuning","content":"Access-Control-Max-Age defaulted to 5s; browsers were firing OPTIONS every request. Raised to 86400 (1 day). Preflight traffic dropped 98%. Works well because our CORS config is stable.","tags":["cors","http","performance","browser"]} +{"id":"bench-016","title":"Pulumi stack for ephemeral preview envs","content":"Each PR gets a preview environment: Pulumi stack named pr-, S3 bucket, RDS serverless, domain cnamed to preview.example.com. Cleanup lambda removes stacks after PR close or 7 days inactivity. Cost per stack: ~$2/day.","tags":["pulumi","infra","preview","aws"]} +{"id":"bench-017","title":"Prometheus histogram bucket tuning","content":"Default buckets are tuned for web latency (le=0.005..10). For background job duration we used le=1,5,10,30,60,120,300. Too-coarse buckets hide p99 changes; too-fine explodes time-series cardinality. Revisit quarterly.","tags":["prometheus","metrics","histogram","monitoring"]} +{"id":"bench-018","title":"ElasticSearch to Solr migration","content":"Migrated search from ES 7.x to Solr 9. 
Key differences: Solr uses managed-schema.xml vs ES mappings, edismax is the edismax parser (no dis_max clause), field boosting syntax is title^2 not title:(query)^2. Faceting syntax maps cleanly 1:1.","tags":["elasticsearch","solr","migration","search"]} +{"id":"bench-019","title":"Stripe webhook idempotency","content":"Store the webhook event ID in a processed_events table with a unique constraint. Reject duplicates with 200 (not 4xx, Stripe will retry). 90-day retention then expire. Saw duplicate deliveries ~0.3% of the time.","tags":["stripe","webhook","idempotency","payments"]} +{"id":"bench-020","title":"pgx vs database/sql connection handling","content":"Switched from lib/pq to pgx/v5. pgx.Conn is stateful per-connection (prepared statements cached locally) so it doesn't play with PgBouncer in transaction mode. Use pgx.Pool with simple_protocol=true or stay with database/sql.","tags":["postgres","pgx","go","connection"]} +{"id":"bench-021","title":"nginx buffer size for large POST bodies","content":"Uploads >1MB were hitting client_body_buffer_size limits and spilling to /tmp on nginx hosts. Raised to 10m and moved temp dir to tmpfs. 413 errors went to zero. Also bumped client_max_body_size to 100m for the upload endpoint.","tags":["nginx","upload","buffer","http"]} +{"id":"bench-022","title":"Redis key prefix conventions","content":"We use :: (e.g. auth:session:abc123). Makes KEYS/SCAN across a namespace easy and lets us purge per-service with a single pattern. Avoid ::: double-colons; Redis pattern matching is happier with single separators.","tags":["redis","conventions","cache","keys"]} +{"id":"bench-023","title":"Go pprof on production pods","content":"Enabled net/http/pprof on /debug/pprof, bound to localhost only, exposed via kubectl port-forward on demand. Used pprof -http to analyze heap dumps. 
Found a string-append loop in JSON marshaling that was 40% of allocations.","tags":["go","pprof","performance","profiling"]} +{"id":"bench-024","title":"Datadog log pattern cardinality","content":"Log messages containing user-supplied UUIDs blew up pattern aggregation. Now we scrub UUIDs to in log formatter before shipping. Cardinality on the /patterns query dropped 100x, making anomaly detection useful again.","tags":["datadog","logs","cardinality","observability"]} +{"id":"bench-025","title":"Linear API rate limit","content":"Linear's GraphQL API caps at 1500 requests per hour per API key. Issue sync job batches 50 issues per request. Retry-After header is respected. Hit the cap once during a full re-import — now we paginate with 1s jitter between pages.","tags":["linear","api","rate-limit","integration"]} +{"id":"bench-026","title":"macOS launchd user agent for background service","content":"Created ~/Library/LaunchAgents/com.example.service.plist with KeepAlive and RunAtLoad. Logs to /tmp, environment vars set under EnvironmentVariables dict. launchctl load starts immediately and on login. Beats writing a custom systemd wrapper.","tags":["macos","launchd","service","daemon"]} +{"id":"bench-027","title":"Frontend bundle size budget enforcement","content":"bundle-analyzer in CI fails if main.js > 400KB gzipped. Broke twice this quarter — once from an accidental import of all of lodash, once from a moment.js locale bundle. Now we lint for barrel imports and use dayjs exclusively.","tags":["frontend","bundle","webpack","ci"]} +{"id":"bench-028","title":"PostgreSQL partial index for sparse column","content":"created_at IS NOT NULL on a 50M row table; only ~200k rows have it set. Regular btree took 800MB; partial index (WHERE created_at IS NOT NULL) is 12MB. 
Query planner picks it correctly after ANALYZE.","tags":["postgres","index","partial","performance"]} +{"id":"bench-029","title":"Claude Code hooks for auto-capture","content":"settings.json PostToolUse hook runs a shell script that parses the tool event JSON, filters for file-touching tools, and POSTs to a local observer endpoint. UserPromptSubmit mirrors prompts to a journaling log. Zero-effort session recording.","tags":["claude-code","hooks","observability","agent"]} +{"id":"bench-030","title":"Solr more-like-this for related memories","content":"MLT on title,content with mintf=2 mindf=2 gives decent recall for 'memories similar to X'. Filter out the source doc id. For memory system: raise mintf for short memories so we don't match on single common tokens.","tags":["solr","mlt","similarity","search"]} diff --git a/cmd/solr-mem-bench/testdata/queries.jsonl b/cmd/solr-mem-bench/testdata/queries.jsonl new file mode 100644 index 0000000..df1a2b5 --- /dev/null +++ b/cmd/solr-mem-bench/testdata/queries.jsonl @@ -0,0 +1,25 @@ +{"id":"q-01","text":"JWT authentication middleware","gold":["bench-001"]} +{"id":"q-02","text":"rate limiting API gateway","gold":["bench-003"]} +{"id":"q-03","text":"Postgres connection pool sizing","gold":["bench-004"]} +{"id":"q-04","text":"feature flag React provider","gold":["bench-005"]} +{"id":"q-05","text":"Stripe webhook idempotency","gold":["bench-019"]} +{"id":"q-06","text":"Kafka consumer lag alerts","gold":["bench-010"]} +{"id":"q-07","text":"database performance optimization","gold":["bench-002","bench-011","bench-028","bench-007"]} +{"id":"q-08","text":"secure token validation edge runtime","gold":["bench-001"]} +{"id":"q-09","text":"fixing slow search facets on dates","gold":["bench-007"]} +{"id":"q-10","text":"preventing duplicate payment processing","gold":["bench-019"]} +{"id":"q-11","text":"browser preflight caching","gold":["bench-015"]} +{"id":"q-12","text":"monitoring job duration histograms","gold":["bench-017"]} 
+{"id":"q-13","text":"find related records efficiently","gold":["bench-011","bench-030"]} +{"id":"q-14","text":"fixing a memory leak in reconnection logic","gold":["bench-006"]} +{"id":"q-15","text":"ephemeral environments per pull request","gold":["bench-016"]} +{"id":"q-16","text":"upload large files HTTP body limit","gold":["bench-021","bench-012"]} +{"id":"q-17","text":"type safety migration staged rollout","gold":["bench-009"]} +{"id":"q-18","text":"flaky tests on continuous integration","gold":["bench-014"]} +{"id":"q-19","text":"moving from Elasticsearch","gold":["bench-018"]} +{"id":"q-20","text":"launchd plist for a background service","gold":["bench-026"]} +{"id":"q-21","text":"Go profiling heap allocations","gold":["bench-023"]} +{"id":"q-22","text":"N+1 query with GraphQL resolvers","gold":["bench-011","bench-002"]} +{"id":"q-23","text":"automatic observation capture from agent hooks","gold":["bench-029"]} +{"id":"q-24","text":"generic repository pattern in Go","gold":["bench-013"]} +{"id":"q-25","text":"scrubbing high-cardinality identifiers from logs","gold":["bench-024"]} diff --git a/cmd/solr-mem-server/broker.go b/cmd/solr-mem-server/broker.go index 7a55270..ad77584 100644 --- a/cmd/solr-mem-server/broker.go +++ b/cmd/solr-mem-server/broker.go @@ -36,6 +36,11 @@ const ( // defaultRunSweeperInterval is how often the stale-run sweeper runs. defaultRunSweeperInterval = 5 * time.Minute + + // brokerSessionCap is the max number of memory items per session_id in a + // single packet. Code items (no session) are unaffected. With maxItems=5 + // this leaves room for at least 3 distinct sessions (or a mix with code). + brokerSessionCap = 2 ) // WorkObservation is a single structured report from a worker agent. 
@@ -65,6 +70,7 @@ type MemoryPacketItem struct { MemoryType string `json:"memory_type,omitempty"` FilePath string `json:"file_path,omitempty"` SymbolName string `json:"symbol_name,omitempty"` + SessionID string `json:"session_id,omitempty"` // populated for memory items; used for session-diversified ranking } // MemoryPacket is a precomputed bundle of relevant context for a worker agent. @@ -320,13 +326,17 @@ func (b *Broker) doBuild(obs WorkObservation, buildSeq int) *MemoryPacket { // 3. Score and dedupe. scored := scoreCandidates(candidates, obs) - // 4. Pick top items (max 5). + // 4. Cap per session so one chatty session can't dominate the packet. + // Code items have empty SessionID and are unaffected. + scored = diversifyBySession(scored, func(i MemoryPacketItem) string { return i.SessionID }, brokerSessionCap) + + // 5. Pick top items (max 5). const maxItems = 5 if len(scored) > maxItems { scored = scored[:maxItems] } - // 5. Determine delivery class. + // 6. Determine delivery class. delivery := DeliveryCheckpoint for _, item := range scored { // Promote to interrupt if we found a high-relevance hazard or prior solution. @@ -336,7 +346,7 @@ func (b *Broker) doBuild(obs WorkObservation, buildSeq int) *MemoryPacket { } } - // 6. Build summary. + // 7. Build summary. summary := buildPacketSummary(scored, obs) return &MemoryPacket{ @@ -371,6 +381,7 @@ func (b *Broker) searchMemories(ctx context.Context, queryTerms string, obs Work title, _ := doc["title"].(string) content, _ := doc["content"].(string) memType, _ := doc["memory_type"].(string) + sessionID, _ := doc["session_id"].(string) tags := getStringSliceFromDoc(doc, "tags") // Truncate content for summary. 
@@ -388,6 +399,7 @@ func (b *Broker) searchMemories(ctx context.Context, queryTerms string, obs Work Reason: "memory search hit", Tags: tags, MemoryType: memType, + SessionID: sessionID, }) } return items diff --git a/cmd/solr-mem-server/broker_tool.go b/cmd/solr-mem-server/broker_tool.go index 4b5de62..0f1ee69 100644 --- a/cmd/solr-mem-server/broker_tool.go +++ b/cmd/solr-mem-server/broker_tool.go @@ -14,17 +14,24 @@ func observeWorkTool(broker *Broker) ToolHandler { return nil, fmt.Errorf("run_id is required") } + // Scrub free-text fields; entities/code_refs/repo/phase are symbolic + // and unlikely to carry secrets. + task, _ := scrubString(getString(args, "task")) + subgoal, _ := scrubString(getString(args, "subgoal")) + uncertainty, _ := scrubString(getString(args, "uncertainty")) + nextAction, _ := scrubString(getString(args, "next_action")) + obs := WorkObservation{ RunID: runID, AgentID: getString(args, "agent_id"), Repo: getString(args, "repo"), Phase: getString(args, "phase"), - Task: getString(args, "task"), - Subgoal: getString(args, "subgoal"), + Task: task, + Subgoal: subgoal, Entities: getStringSlice(args, "entities"), CodeRefs: getStringSlice(args, "code_refs"), - Uncertainty: getString(args, "uncertainty"), - NextAction: getString(args, "next_action"), + Uncertainty: uncertainty, + NextAction: nextAction, } result := broker.Observe(obs) diff --git a/cmd/solr-mem-server/bulk_store_tool.go b/cmd/solr-mem-server/bulk_store_tool.go index df8ebfb..7af99cf 100644 --- a/cmd/solr-mem-server/bulk_store_tool.go +++ b/cmd/solr-mem-server/bulk_store_tool.go @@ -3,8 +3,11 @@ package main import ( "context" "fmt" + "log" "time" + "github.com/arreyder/solr-mem/internal/contenthash" + "github.com/arreyder/solr-mem/internal/privacy" "github.com/arreyder/solr-mem/internal/solr" "github.com/google/uuid" ) @@ -26,6 +29,11 @@ func bulkStoreMemoriesTool(ctx context.Context, args map[string]any) (any, error now := time.Now().UTC() var docs []solr.Document var ids 
[]string + var duplicates []map[string]any + totalScrubbed := 0 + + windowSec := getInt(args, "dedup_window_seconds", defaultDedupWindowSec) + onDup := normalizeOnDuplicate(getString(args, "on_duplicate")) for i, raw := range memories { m, ok := raw.(map[string]any) @@ -41,43 +49,94 @@ func bulkStoreMemoriesTool(ctx context.Context, args map[string]any) (any, error lifetime := normalizeLifetime(getString(m, "lifetime")) expiresAt := resolveExpiration(lifetime, getString(m, "expires_at")) - id := uuid.New().String() - ids = append(ids, id) - format := getString(m, "format") if format == "" { format = "prose" } + title := getString(m, "title") + metadata := getString(m, "metadata") + scrubbedContent, contentHits := scrubString(content) + scrubbedTitle, titleHits := scrubString(title) + allHits := privacy.MergeHits(contentHits, titleHits) + metadata = privacy.MergeMetadata(metadata, allHits) + for _, v := range allHits { + totalScrubbed += v + } + + tags := getStringSlice(m, "tags") + hash := contenthash.Compute(scrubbedTitle, scrubbedContent, tags) + + if hash != "" && windowSec > 0 && onDup != "force" { + existingID, err := findRecentByHash(ctx, solrClient, hash, windowSec) + if err != nil { + log.Printf("bulk_store: dedup lookup failed at index %d (continuing as new): %v", i, err) + } else if existingID != "" { + action := "skipped" + if onDup == "merge" { + upd := map[string]any{"updated_at": now.Format(time.RFC3339)} + if err := solrClient.Update(ctx, existingID, upd); err != nil { + log.Printf("bulk_store: merge update failed at index %d: %v", i, err) + } else { + action = "merged" + } + } + duplicates = append(duplicates, map[string]any{ + "index": i, + "id": existingID, + "action": action, + }) + continue + } + } + + id := uuid.New().String() + ids = append(ids, id) + docs = append(docs, solr.Document{ - ID: id, - AgentID: getString(m, "agent_id"), - MemoryType: getString(m, "memory_type"), - Content: content, - Title: getString(m, "title"), - Tags: 
getStringSlice(m, "tags"), - Source: getString(m, "source"), - Importance: getFloat(m, "importance", 0.5), - Metadata: getString(m, "metadata"), - CreatedAt: now, - UpdatedAt: now, - ExpiresAt: expiresAt, - Lifetime: lifetime, - SessionID: getString(m, "session_id"), - RelatedIDs: getStringSlice(m, "related_ids"), - Format: format, + ID: id, + AgentID: getString(m, "agent_id"), + MemoryType: getString(m, "memory_type"), + Content: scrubbedContent, + Title: scrubbedTitle, + Tags: tags, + Source: getString(m, "source"), + Importance: getFloat(m, "importance", 0.5), + Metadata: metadata, + CreatedAt: now, + UpdatedAt: now, + ExpiresAt: expiresAt, + Lifetime: lifetime, + SessionID: getString(m, "session_id"), + RelatedIDs: getStringSlice(m, "related_ids"), + Format: format, + ContentHash: hash, + Embedding: embedForStore(ctx, scrubbedTitle, scrubbedContent), }) } - if err := solrClient.Add(ctx, docs...); err != nil { - return nil, fmt.Errorf("failed to bulk store memories: %w", err) + if len(docs) > 0 { + if err := solrClient.Add(ctx, docs...); err != nil { + return nil, fmt.Errorf("failed to bulk store memories: %w", err) + } + } + + text := fmt.Sprintf("Stored %d memories.\nIDs: %v", len(docs), ids) + structured := map[string]any{ + "count": len(docs), + "ids": ids, + } + if len(duplicates) > 0 { + text += fmt.Sprintf("\nDedup: %d duplicates within %ds window", len(duplicates), windowSec) + structured["duplicates"] = duplicates + } + if totalScrubbed > 0 { + text += fmt.Sprintf("\nPrivacy: redacted %d secret(s) across all memories", totalScrubbed) + structured["scrub_count"] = totalScrubbed } return ToolOutput{ - Text: fmt.Sprintf("Successfully stored %d memories.\nIDs: %v", len(docs), ids), - Structured: map[string]any{ - "count": len(docs), - "ids": ids, - }, + Text: text, + Structured: structured, }, nil } diff --git a/cmd/solr-mem-server/dedup.go b/cmd/solr-mem-server/dedup.go new file mode 100644 index 0000000..6ec9d1a --- /dev/null +++ 
b/cmd/solr-mem-server/dedup.go @@ -0,0 +1,51 @@ +package main + +import ( + "context" + "fmt" + + "github.com/arreyder/solr-mem/internal/solr" +) + +// defaultDedupWindowSec is the default window for content-hash dedup on writes. +// Callers can override per-call with dedup_window_seconds; 0 disables. +const defaultDedupWindowSec = 300 + +// findRecentByHash returns the ID of the most recent memory with content_hash +// equal to hash and created_at within the last windowSec seconds. Empty string +// means no match (or the lookup errored — callers treat it as "not a dup"). +func findRecentByHash(ctx context.Context, client *solr.Client, hash string, windowSec int) (string, error) { + if hash == "" || windowSec <= 0 || client == nil { + return "", nil + } + params := solr.QueryParams{ + Query: "*:*", + FilterQueries: []string{ + fmt.Sprintf("content_hash:%s", hash), + fmt.Sprintf("created_at:[NOW-%dSECONDS TO *]", windowSec), + }, + Fields: []string{"id"}, + Sort: "created_at desc", + Rows: 1, + } + resp, err := client.Query(ctx, params) + if err != nil { + return "", err + } + if len(resp.Docs) == 0 { + return "", nil + } + id, _ := resp.Docs[0]["id"].(string) + return id, nil +} + +// normalizeOnDuplicate canonicalizes the on_duplicate argument. +// "" / unrecognized → "skip". 
// normalizeOnDuplicate maps the on_duplicate argument onto one of the three
// supported policies: "skip", "merge" or "force". Matching is exact and
// case-sensitive; the empty string and any unrecognized value yield "skip".
func normalizeOnDuplicate(s string) string {
	if s == "merge" || s == "force" {
		return s
	}
	// "skip" itself, "", and everything else all collapse to the default.
	return "skip"
}
// diversifyBySession walks a ranked slice in order and keeps at most cap items
// for each non-empty group key returned by key. Items whose key is "" are
// always kept and never counted. Relative order of the kept items is
// preserved, so higher-ranked items win over later ones of the same group.
//
// When cap <= 0 or items is empty the input slice itself is returned
// unchanged; otherwise a newly allocated slice is returned.
func diversifyBySession[T any](items []T, key func(T) string, cap int) []T {
	if len(items) == 0 || cap <= 0 {
		return items
	}

	kept := make([]T, 0, len(items))
	perGroup := make(map[string]int, len(items))
	for i := range items {
		group := key(items[i])
		if group != "" {
			if perGroup[group] >= cap {
				// This group already filled its quota; drop the item.
				continue
			}
			perGroup[group]++
		}
		kept = append(kept, items[i])
	}
	return kept
}
+ in := []divItem{ + {"a", ""}, + {"b", ""}, + {"c", ""}, + {"d", "s1"}, + {"e", "s1"}, + {"f", "s1"}, // dropped + } + got := diversifyBySession(in, sessionKey, 2) + want := []string{"a", "b", "c", "d", "e"} + if !equalStrings(ids(got), want) { + t.Errorf("empty key: got %v, want %v", ids(got), want) + } +} + +func TestDiversifyBySessionDisabled(t *testing.T) { + in := []divItem{{"a", "s1"}, {"b", "s1"}, {"c", "s1"}} + if got := diversifyBySession(in, sessionKey, 0); len(got) != 3 { + t.Errorf("cap=0 should passthrough, got %d items", len(got)) + } + if got := diversifyBySession(in, sessionKey, -1); len(got) != 3 { + t.Errorf("cap<0 should passthrough, got %d items", len(got)) + } +} + +func TestDiversifyBySessionEmpty(t *testing.T) { + if got := diversifyBySession[divItem](nil, sessionKey, 2); got != nil { + t.Errorf("nil input: got %v, want nil", got) + } +} + +func equalStrings(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} diff --git a/cmd/solr-mem-server/embed_helper.go b/cmd/solr-mem-server/embed_helper.go new file mode 100644 index 0000000..ca77b4a --- /dev/null +++ b/cmd/solr-mem-server/embed_helper.go @@ -0,0 +1,39 @@ +package main + +import ( + "context" + "log" + "strings" + "time" +) + +// embedForStore produces an embedding for a memory write. It combines title +// and content (title first, it's usually the most informative signal) and +// calls the configured provider with a short timeout. On error or when no +// provider is configured, returns nil so callers can still store the doc. +func embedForStore(ctx context.Context, title, content string) []float32 { + if embedProvider == nil { + return nil + } + text := strings.TrimSpace(title) + if content != "" { + if text != "" { + text += "\n\n" + } + text += content + } + if text == "" { + return nil + } + + // Short timeout: a slow provider shouldn't block a memory write. 
+ cctx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + v, err := embedProvider.Embed(cctx, text) + if err != nil { + log.Printf("embed: failed (continuing without vector): %v", err) + return nil + } + return v +} diff --git a/cmd/solr-mem-server/knn.go b/cmd/solr-mem-server/knn.go new file mode 100644 index 0000000..15a7812 --- /dev/null +++ b/cmd/solr-mem-server/knn.go @@ -0,0 +1,32 @@ +package main + +import ( + "context" + "fmt" + "strconv" + "strings" + + "github.com/arreyder/solr-mem/internal/solr" +) + +// knnQueryString formats a Solr KNN local-params query for the embedding field. +// Produces: {!knn f=embedding topK=N}[v1,v2,...] +func knnQueryString(vec []float32, topK int) string { + parts := make([]string, len(vec)) + for i, v := range vec { + parts[i] = strconv.FormatFloat(float64(v), 'f', -1, 32) + } + return fmt.Sprintf("{!knn f=embedding topK=%d}[%s]", topK, strings.Join(parts, ",")) +} + +// runKNN performs a KNN search against the embedding field. Filters from the +// caller (agent_id, tags, etc.) are passed through as fq so semantic hits +// respect the same scoping as BM25 hits. 
+func runKNN(ctx context.Context, client *solr.Client, vec []float32, filters []string, topK int) (*solr.QueryResponse, error) { + params := solr.QueryParams{ + Query: knnQueryString(vec, topK), + FilterQueries: filters, + Rows: topK, + } + return client.Query(ctx, params) +} diff --git a/cmd/solr-mem-server/main.go b/cmd/solr-mem-server/main.go index e9a4c64..13942cf 100644 --- a/cmd/solr-mem-server/main.go +++ b/cmd/solr-mem-server/main.go @@ -8,12 +8,14 @@ import ( "net/http" "os" + "github.com/arreyder/solr-mem/internal/embed" "github.com/arreyder/solr-mem/internal/solr" "github.com/modelcontextprotocol/go-sdk/mcp" ) var solrClient *solr.Client var codeClient *solr.Client +var embedProvider embed.Provider func main() { solrURL := os.Getenv("SOLR_URL") @@ -28,6 +30,23 @@ func main() { } codeClient = solr.NewClient(codeURL) + // Optional embedding provider. If neither OLLAMA_EMBEDDING_URL nor + // OPENAI_API_KEY is set, embedProvider stays nil and the server runs in + // BM25-only mode. + if p, err := embed.FromEnv(); err != nil { + log.Printf("embedding provider init failed: %v (continuing BM25-only)", err) + } else if p != nil { + embedProvider = p + if p.Dim() > 0 { + log.Printf("embedding provider: %s (dim=%d)", p.Name(), p.Dim()) + } else { + // Ollama's dim is learned on the first successful call. 
// privacyScrubOnce guards the one-time read of SOLR_MEM_PRIVACY_SCRUB and
// privacyScrubOff caches the outcome (true when scrubbing was switched off).
var (
	privacyScrubOnce sync.Once
	privacyScrubOff  bool
)

// privacyScrubEnabled reports whether secret scrubbing is on. Scrubbing is on
// by default and is switched off when SOLR_MEM_PRIVACY_SCRUB is "off",
// "false", "0" or "disabled" (compared case-insensitively, surrounding
// whitespace ignored). The variable is read on the first call only; later
// changes to the environment have no effect.
func privacyScrubEnabled() bool {
	privacyScrubOnce.Do(func() {
		switch strings.ToLower(strings.TrimSpace(os.Getenv("SOLR_MEM_PRIVACY_SCRUB"))) {
		case "off", "false", "0", "disabled":
			privacyScrubOff = true
			log.Println("privacy: scrub disabled by SOLR_MEM_PRIVACY_SCRUB")
		}
	})
	return !privacyScrubOff
}
+func scrubString(s string) (string, map[string]int) { + if !privacyScrubEnabled() || s == "" { + return s, nil + } + r := privacy.Scrub(s) + if r.Count() == 0 { + return r.Content, nil + } + return r.Content, r.Hits +} diff --git a/cmd/solr-mem-server/search_tool.go b/cmd/solr-mem-server/search_tool.go index a18ae80..cc23c95 100644 --- a/cmd/solr-mem-server/search_tool.go +++ b/cmd/solr-mem-server/search_tool.go @@ -4,8 +4,10 @@ import ( "context" "encoding/json" "fmt" + "log" "strings" + "github.com/arreyder/solr-mem/internal/retrieval" "github.com/arreyder/solr-mem/internal/solr" ) @@ -15,42 +17,85 @@ func searchMemoriesTool(ctx context.Context, args map[string]any) (any, error) { return nil, fmt.Errorf("query is required") } - params := solr.QueryParams{ - Query: query, - Rows: getInt(args, "limit", 10), - Highlight: getBool(args, "highlight", true), - Facet: getBool(args, "facet", false), + rows := getInt(args, "limit", 10) + + // Build filter queries shared by both BM25 and KNN passes. This keeps + // semantic hits scoped to the same agent / session / tag set as keyword + // hits. + filters := buildSearchFilters(args) + + bm25Params := solr.QueryParams{ + Query: query, + Rows: rows, + Highlight: getBool(args, "highlight", true), + Facet: getBool(args, "facet", false), + FilterQueries: filters, + } + if bm25Params.Facet { + bm25Params.FacetFields = []string{"memory_type", "tags", "agent_id", "source"} } - // Build filter queries + bm25Resp, err := solrClient.Query(ctx, bm25Params) + if err != nil { + return nil, fmt.Errorf("search failed: %w", err) + } + + // Semantic (KNN) stream: only runs when the server has an embed provider + // configured and the caller hasn't opted out via semantic=false. 
+ semantic := getBool(args, "semantic", true) && embedProvider != nil + knnTopK := getInt(args, "knn_topk", rows*3) // fetch wider so fusion has something to cross-pollinate + + resp := bm25Resp + if semantic { + if fused, err := runHybrid(ctx, query, filters, bm25Resp, knnTopK, rows); err == nil && fused != nil { + resp = fused + } else if err != nil { + log.Printf("search: semantic fallback to BM25-only: %v", err) + } + } + + // Cap per session so one chatty session can't dominate results. + // Default 3; pass 0 to disable. + sessionCap := getInt(args, "session_cap", 3) + if sessionCap > 0 { + resp.Docs = diversifyBySession(resp.Docs, func(d map[string]any) string { + s, _ := d["session_id"].(string) + return s + }, sessionCap) + } + + return ToolOutput{ + Text: formatSearchResults(resp), + Structured: resp, + }, nil +} + +// buildSearchFilters translates the search_memories args into Solr fq clauses. +func buildSearchFilters(args map[string]any) []string { + var filters []string if v := getString(args, "agent_id"); v != "" { - params.FilterQueries = append(params.FilterQueries, fmt.Sprintf("agent_id:%q", v)) + filters = append(filters, fmt.Sprintf("agent_id:%q", v)) } if v := getString(args, "memory_type"); v != "" { - params.FilterQueries = append(params.FilterQueries, fmt.Sprintf("memory_type:%q", v)) + filters = append(filters, fmt.Sprintf("memory_type:%q", v)) } if v := getString(args, "source"); v != "" { - params.FilterQueries = append(params.FilterQueries, fmt.Sprintf("source:%q", v)) + filters = append(filters, fmt.Sprintf("source:%q", v)) } if tags := getStringSlice(args, "tags"); len(tags) > 0 { for _, tag := range tags { - params.FilterQueries = append(params.FilterQueries, fmt.Sprintf("tags:%q", tag)) + filters = append(filters, fmt.Sprintf("tags:%q", tag)) } } if v := getString(args, "session_id"); v != "" { - params.FilterQueries = append(params.FilterQueries, fmt.Sprintf("session_id:%q", v)) + filters = append(filters, fmt.Sprintf("session_id:%q", 
v)) } if v := getString(args, "lifetime"); v != "" { - params.FilterQueries = append(params.FilterQueries, fmt.Sprintf("lifetime:%q", v)) + filters = append(filters, fmt.Sprintf("lifetime:%q", v)) } - - // Importance filter - impMin := getString(args, "importance_min") - if impMin != "" { - params.FilterQueries = append(params.FilterQueries, fmt.Sprintf("importance:[%s TO *]", impMin)) + if impMin := getString(args, "importance_min"); impMin != "" { + filters = append(filters, fmt.Sprintf("importance:[%s TO *]", impMin)) } - - // Date range filters from := getString(args, "from") to := getString(args, "to") if from != "" || to != "" { @@ -62,24 +107,73 @@ func searchMemoriesTool(ctx context.Context, args map[string]any) (any, error) { if to != "" { toVal = to } - params.FilterQueries = append(params.FilterQueries, fmt.Sprintf("created_at:[%s TO %s]", fromVal, toVal)) + filters = append(filters, fmt.Sprintf("created_at:[%s TO %s]", fromVal, toVal)) } + return filters +} - if params.Facet { - params.FacetFields = []string{"memory_type", "tags", "agent_id", "source"} +// runHybrid embeds the query, runs a KNN search alongside the BM25 hits +// already in bm25Resp, and fuses them with RRF. Returns a synthetic +// QueryResponse whose Docs are in fused order and whose Facets/Highlighting +// are inherited from the BM25 response. Returns (nil, nil) if the embed +// failed softly (KNN skipped) — caller stays on BM25. 
+func runHybrid(ctx context.Context, query string, filters []string, bm25Resp *solr.QueryResponse, knnTopK, finalRows int) (*solr.QueryResponse, error) { + vec, err := embedProvider.Embed(ctx, query) + if err != nil { + return nil, fmt.Errorf("embed query: %w", err) } - - resp, err := solrClient.Query(ctx, params) + knnResp, err := runKNN(ctx, solrClient, vec, filters, knnTopK) if err != nil { - return nil, fmt.Errorf("search failed: %w", err) + return nil, fmt.Errorf("knn query: %w", err) } - return ToolOutput{ - Text: formatSearchResults(resp), - Structured: resp, + bm25IDs := idsFromDocs(bm25Resp.Docs) + knnIDs := idsFromDocs(knnResp.Docs) + + fused := retrieval.FuseIDs([]retrieval.Stream{ + {Name: "bm25", IDs: bm25IDs}, + {Name: "vec", IDs: knnIDs}, + }, retrieval.RRFWithTopK(finalRows)) + + // Hydrate: prefer the BM25 doc (has highlighting) over the KNN doc for + // any ID present in both. + byID := make(map[string]map[string]any, len(bm25Resp.Docs)+len(knnResp.Docs)) + for _, d := range knnResp.Docs { + if id, ok := d["id"].(string); ok { + byID[id] = d + } + } + for _, d := range bm25Resp.Docs { + if id, ok := d["id"].(string); ok { + byID[id] = d + } + } + + ordered := make([]map[string]any, 0, len(fused)) + for _, id := range fused { + if d, ok := byID[id]; ok { + ordered = append(ordered, d) + } + } + + return &solr.QueryResponse{ + NumFound: len(ordered), + Docs: ordered, + Highlighting: bm25Resp.Highlighting, + Facets: bm25Resp.Facets, }, nil } +func idsFromDocs(docs []map[string]any) []string { + out := make([]string, 0, len(docs)) + for _, d := range docs { + if id, ok := d["id"].(string); ok { + out = append(out, id) + } + } + return out +} + func formatSearchResults(resp *solr.QueryResponse) string { var sb strings.Builder sb.WriteString(fmt.Sprintf("Found %d memories.\n\n", resp.NumFound)) diff --git a/cmd/solr-mem-server/store_tool.go b/cmd/solr-mem-server/store_tool.go index 88ddf47..3f1f577 100644 --- a/cmd/solr-mem-server/store_tool.go +++ 
b/cmd/solr-mem-server/store_tool.go @@ -3,8 +3,11 @@ package main import ( "context" "fmt" + "log" "time" + "github.com/arreyder/solr-mem/internal/contenthash" + "github.com/arreyder/solr-mem/internal/privacy" "github.com/arreyder/solr-mem/internal/solr" "github.com/google/uuid" ) @@ -24,23 +27,68 @@ func storeMemoryTool(ctx context.Context, args map[string]any) (any, error) { format = "prose" } + title := getString(args, "title") + metadata := getString(args, "metadata") + scrubbedContent, contentHits := scrubString(content) + scrubbedTitle, titleHits := scrubString(title) + allHits := privacy.MergeHits(contentHits, titleHits) + metadata = privacy.MergeMetadata(metadata, allHits) + + tags := getStringSlice(args, "tags") + hash := contenthash.Compute(scrubbedTitle, scrubbedContent, tags) + windowSec := getInt(args, "dedup_window_seconds", defaultDedupWindowSec) + onDup := normalizeOnDuplicate(getString(args, "on_duplicate")) + + if hash != "" && windowSec > 0 && onDup != "force" { + existingID, err := findRecentByHash(ctx, solrClient, hash, windowSec) + if err != nil { + log.Printf("store_memory: dedup lookup failed (continuing as new): %v", err) + } else if existingID != "" { + if onDup == "merge" { + upd := map[string]any{"updated_at": now.Format(time.RFC3339)} + if err := solrClient.Update(ctx, existingID, upd); err != nil { + log.Printf("store_memory: merge update failed: %v", err) + } + return ToolOutput{ + Text: fmt.Sprintf("Memory already stored within the last %ds. Merged into existing.\nID: %s", windowSec, existingID), + Structured: map[string]any{ + "id": existingID, + "duplicate": true, + "action": "merged", + }, + }, nil + } + // skip (default) + return ToolOutput{ + Text: fmt.Sprintf("Memory already stored within the last %ds. 
Skipped.\nID: %s", windowSec, existingID), + Structured: map[string]any{ + "id": existingID, + "duplicate": true, + "action": "skipped", + }, + }, nil + } + } + doc := solr.Document{ - ID: uuid.New().String(), - AgentID: getString(args, "agent_id"), - MemoryType: getString(args, "memory_type"), - Content: content, - Title: getString(args, "title"), - Tags: getStringSlice(args, "tags"), - Source: getString(args, "source"), - Importance: getFloat(args, "importance", 0.5), - Metadata: getString(args, "metadata"), - CreatedAt: now, - UpdatedAt: now, - ExpiresAt: expiresAt, - Lifetime: lifetime, - SessionID: getString(args, "session_id"), - RelatedIDs: getStringSlice(args, "related_ids"), - Format: format, + ID: uuid.New().String(), + AgentID: getString(args, "agent_id"), + MemoryType: getString(args, "memory_type"), + Content: scrubbedContent, + Title: scrubbedTitle, + Tags: tags, + Source: getString(args, "source"), + Importance: getFloat(args, "importance", 0.5), + Metadata: metadata, + CreatedAt: now, + UpdatedAt: now, + ExpiresAt: expiresAt, + Lifetime: lifetime, + SessionID: getString(args, "session_id"), + RelatedIDs: getStringSlice(args, "related_ids"), + Format: format, + ContentHash: hash, + Embedding: embedForStore(ctx, scrubbedTitle, scrubbedContent), } if err := solrClient.Add(ctx, doc); err != nil { @@ -53,13 +101,23 @@ func storeMemoryTool(ctx context.Context, args map[string]any) (any, error) { result += fmt.Sprintf("\nExpires: %s", expiresAt) } + structured := map[string]any{ + "id": doc.ID, + "lifetime": doc.Lifetime, + "expires_at": expiresAt, + "created_at": doc.CreatedAt.Format(time.RFC3339), + } + if n := len(allHits); n > 0 { + total := 0 + for _, v := range allHits { + total += v + } + structured["scrub_count"] = total + result += fmt.Sprintf("\nPrivacy: redacted %d secret(s) across %d kind(s)", total, n) + } + return ToolOutput{ - Text: result, - Structured: map[string]any{ - "id": doc.ID, - "lifetime": doc.Lifetime, - "expires_at": expiresAt, - 
"created_at": doc.CreatedAt.Format(time.RFC3339), - }, + Text: result, + Structured: structured, }, nil } diff --git a/cmd/solr-mem-server/tools.go b/cmd/solr-mem-server/tools.go index 179eceb..9b87d11 100644 --- a/cmd/solr-mem-server/tools.go +++ b/cmd/solr-mem-server/tools.go @@ -26,25 +26,29 @@ func ToolSchemas() []ToolDefinition { **When to use**: Save observations, reflections, facts, conversations, tasks, or decisions for later retrieval. **Required**: content (the memory text) -**Optional**: agent_id, memory_type, title, tags, source, importance, metadata, lifetime, session_id, related_ids, expires_at +**Optional**: agent_id, memory_type, title, tags, source, importance, metadata, lifetime, session_id, related_ids, expires_at, format, dedup_window_seconds, on_duplicate **Lifetime values**: permanent (default, never expires), session (cleaned up with session), ephemeral (1 hour TTL), temporary (7 day TTL) +**Dedup**: By default an identical memory (same title + content + tags) stored within 300 seconds is skipped and the existing ID is returned. Pass dedup_window_seconds=0 to disable, on_duplicate="merge" to bump updated_at on the existing doc, or on_duplicate="force" to always insert. + **Content format**: Use compact structured formats (YAML, key-value, tables) over prose. Put searchable summary in title, machine-readable data in metadata (JSON), and categorization in tags. 
See server instructions for examples.`, InputSchema: NewObjectSchema(map[string]any{ - "content": prop("string", "The main text content of the memory (required)"), - "agent_id": prop("string", "ID of the agent storing this memory"), - "memory_type": prop("string", "Type: observation, reflection, fact, conversation, task, decision"), - "title": prop("string", "Short title or summary of the memory"), - "tags": arrayPropSchema(prop("string", "Tag"), "Categorization tags"), - "source": prop("string", "Where this memory came from (e.g., conversation, tool, file)"), - "importance": numberProp("Importance score from 0.0 to 1.0", floatPtr(0), floatPtr(1)), - "metadata": prop("string", "JSON string of arbitrary metadata"), - "lifetime": prop("string", "Memory lifetime: permanent (default), session, ephemeral (1h), temporary (7d)"), - "session_id": prop("string", "Session/conversation ID to group memories"), - "related_ids": arrayPropSchema(prop("string", "ID"), "IDs of related memories"), - "expires_at": prop("string", "Explicit expiration date (ISO 8601). Overrides lifetime."), - "format": prop("string", "Content format: yaml, markdown, json, table, prose (default: prose). 
Helps agents choose parsing strategy."), + "content": prop("string", "The main text content of the memory (required)"), + "agent_id": prop("string", "ID of the agent storing this memory"), + "memory_type": prop("string", "Type: observation, reflection, fact, conversation, task, decision"), + "title": prop("string", "Short title or summary of the memory"), + "tags": arrayPropSchema(prop("string", "Tag"), "Categorization tags"), + "source": prop("string", "Where this memory came from (e.g., conversation, tool, file)"), + "importance": numberProp("Importance score from 0.0 to 1.0", floatPtr(0), floatPtr(1)), + "metadata": prop("string", "JSON string of arbitrary metadata"), + "lifetime": prop("string", "Memory lifetime: permanent (default), session, ephemeral (1h), temporary (7d)"), + "session_id": prop("string", "Session/conversation ID to group memories"), + "related_ids": arrayPropSchema(prop("string", "ID"), "IDs of related memories"), + "expires_at": prop("string", "Explicit expiration date (ISO 8601). Overrides lifetime."), + "format": prop("string", "Content format: yaml, markdown, json, table, prose (default: prose). Helps agents choose parsing strategy."), + "dedup_window_seconds": integerProp("Skip-if-duplicate window in seconds (default 300, 0 disables)", intPtr(0), intPtr(86400)), + "on_duplicate": prop("string", "Action when a duplicate is found: skip (default), merge (bump updated_at), force (always insert)"), }, "content"), }, Handler: storeMemoryTool, @@ -57,7 +61,9 @@ func ToolSchemas() []ToolDefinition { **When to use**: Find relevant memories by content, tags, type, agent, or time range. Uses edismax with field boosting (content^3, title^2, tags^1.5) and recency boost. 
**Required**: query (search text) -**Optional**: agent_id, memory_type, tags, source, importance_min, from, to, limit, highlight, facet, session_id, lifetime`, +**Optional**: agent_id, memory_type, tags, source, importance_min, from, to, limit, highlight, facet, session_id, lifetime, session_cap, semantic, knn_topk + +**Semantic search**: When the server has an embedding provider configured (OPENAI_API_KEY set), results combine BM25 keyword hits with KNN vector hits via Reciprocal Rank Fusion. Pass semantic=false to force BM25-only. knn_topk widens the KNN pool before fusion (default: 3× limit).`, InputSchema: NewObjectSchema(map[string]any{ "query": prop("string", "Full-text search query (required)"), "agent_id": prop("string", "Filter by agent ID"), @@ -72,6 +78,9 @@ func ToolSchemas() []ToolDefinition { "facet": prop("boolean", "Include facet counts (default: false)"), "session_id": prop("string", "Filter by session ID"), "lifetime": prop("string", "Filter by lifetime (permanent, session, ephemeral, temporary)"), + "session_cap": integerProp("Max hits per session_id after ranking (default 3, 0 = disable)", intPtr(0), intPtr(100)), + "semantic": prop("boolean", "Include KNN vector stream and fuse with BM25 via RRF (default: true when an embedding provider is configured)"), + "knn_topk": integerProp("KNN pool size before fusion (default: 3× limit)", intPtr(0), intPtr(500)), }, "query"), }, Handler: searchMemoriesTool, @@ -173,9 +182,10 @@ func ToolSchemas() []ToolDefinition { Name: "bulk_store_memories", Description: `Store multiple memories in a single batch operation. -**When to use**: Efficiently store many memories at once (e.g., importing notes, bulk archival). Each memory in the array uses the same fields as store_memory. +**When to use**: Efficiently store many memories at once (e.g., importing notes, bulk archival). Each memory in the array uses the same fields as store_memory. Dedup parameters apply to the batch as a whole. 
-**Required**: memories (array of memory objects, each with at least a "content" field)`, +**Required**: memories (array of memory objects, each with at least a "content" field) +**Optional**: dedup_window_seconds, on_duplicate`, InputSchema: NewObjectSchema(map[string]any{ "memories": arrayPropSchema( NewObjectSchema(map[string]any{ @@ -195,6 +205,8 @@ func ToolSchemas() []ToolDefinition { }, "content"), "Array of memory objects to store", ), + "dedup_window_seconds": integerProp("Skip-if-duplicate window in seconds (default 300, 0 disables)", intPtr(0), intPtr(86400)), + "on_duplicate": prop("string", "Action when a duplicate is found: skip (default), merge (bump updated_at), force (always insert)"), }, "memories"), }, Handler: bulkStoreMemoriesTool, diff --git a/cmd/solr-mem-server/update_tool.go b/cmd/solr-mem-server/update_tool.go index a8e5b24..26d25af 100644 --- a/cmd/solr-mem-server/update_tool.go +++ b/cmd/solr-mem-server/update_tool.go @@ -4,6 +4,8 @@ import ( "context" "fmt" "time" + + "github.com/arreyder/solr-mem/internal/privacy" ) func updateMemoryTool(ctx context.Context, args map[string]any) (any, error) { @@ -13,12 +15,17 @@ func updateMemoryTool(ctx context.Context, args map[string]any) (any, error) { } fields := make(map[string]any) + scrubHits := map[string]int{} if v := getString(args, "content"); v != "" { - fields["content"] = v + scrubbed, hits := scrubString(v) + fields["content"] = scrubbed + scrubHits = privacy.MergeHits(scrubHits, hits) } if v := getString(args, "title"); v != "" { - fields["title"] = v + scrubbed, hits := scrubString(v) + fields["title"] = scrubbed + scrubHits = privacy.MergeHits(scrubHits, hits) } if v := getString(args, "memory_type"); v != "" { fields["memory_type"] = v @@ -35,6 +42,13 @@ func updateMemoryTool(ctx context.Context, args map[string]any) (any, error) { if v := getString(args, "metadata"); v != "" { fields["metadata"] = v } + // If secrets were scrubbed, merge the tally into metadata (existing + // 
metadata arg takes precedence over the stored doc's metadata, which + // matches atomic-update semantics). + if len(scrubHits) > 0 { + existing, _ := fields["metadata"].(string) + fields["metadata"] = privacy.MergeMetadata(existing, scrubHits) + } if v := getString(args, "session_id"); v != "" { fields["session_id"] = v } diff --git a/internal/contenthash/hash.go b/internal/contenthash/hash.go new file mode 100644 index 0000000..90c0546 --- /dev/null +++ b/internal/contenthash/hash.go @@ -0,0 +1,69 @@ +// Package contenthash produces stable SHA-256 hashes over memory content for +// dedup purposes. Normalization is deliberate: the same logical memory sent +// with different whitespace or tag ordering should collide. +package contenthash + +import ( + "crypto/sha256" + "encoding/hex" + "sort" + "strings" +) + +// Compute returns a hex-encoded SHA-256 over normalized inputs. +// +// Normalization: +// - title and content are trimmed and collapsed (runs of whitespace → single space) +// - tags are sorted and each tag is lowercased + trimmed +// - fields are joined with 0x1f (unit separator) to avoid ambiguity +// +// An empty content produces an empty hash (caller decides whether to dedup). +func Compute(title, content string, tags []string) string { + content = collapse(content) + if content == "" { + return "" + } + title = collapse(title) + + normTags := make([]string, 0, len(tags)) + for _, t := range tags { + t = strings.ToLower(strings.TrimSpace(t)) + if t != "" { + normTags = append(normTags, t) + } + } + sort.Strings(normTags) + + h := sha256.New() + h.Write([]byte(title)) + h.Write([]byte{0x1f}) + h.Write([]byte(content)) + h.Write([]byte{0x1f}) + // Join tags with 0x1e so a tag containing a literal delimiter can't + // collide with two tags split by it. + h.Write([]byte(strings.Join(normTags, "\x1e"))) + return hex.EncodeToString(h.Sum(nil)) +} + +// collapse trims and replaces runs of whitespace with a single space. 
+func collapse(s string) string { + s = strings.TrimSpace(s) + if s == "" { + return "" + } + var b strings.Builder + b.Grow(len(s)) + space := false + for _, r := range s { + if r == ' ' || r == '\t' || r == '\n' || r == '\r' { + if !space { + b.WriteByte(' ') + space = true + } + continue + } + b.WriteRune(r) + space = false + } + return b.String() +} diff --git a/internal/contenthash/hash_test.go b/internal/contenthash/hash_test.go new file mode 100644 index 0000000..548e0eb --- /dev/null +++ b/internal/contenthash/hash_test.go @@ -0,0 +1,72 @@ +package contenthash + +import "testing" + +func TestComputeStable(t *testing.T) { + h1 := Compute("title", "body", []string{"a", "b"}) + h2 := Compute("title", "body", []string{"a", "b"}) + if h1 != h2 { + t.Errorf("expected identical hash for identical inputs, got %q vs %q", h1, h2) + } + if len(h1) != 64 { + t.Errorf("expected 64-char hex sha-256, got %d chars", len(h1)) + } +} + +func TestComputeTagOrderAgnostic(t *testing.T) { + a := Compute("t", "c", []string{"x", "y", "z"}) + b := Compute("t", "c", []string{"z", "x", "y"}) + if a != b { + t.Errorf("tag order should not affect hash") + } +} + +func TestComputeTagCaseNormalized(t *testing.T) { + a := Compute("t", "c", []string{"Foo", "bar"}) + b := Compute("t", "c", []string{"foo", "BAR"}) + if a != b { + t.Errorf("tag case should not affect hash") + } +} + +func TestComputeWhitespaceNormalized(t *testing.T) { + a := Compute("my title", "hello world", nil) + b := Compute("my title", "hello\n\tworld\n", nil) + if a != b { + t.Errorf("whitespace differences should not affect hash: %q vs %q", a, b) + } +} + +func TestComputeDifferentContent(t *testing.T) { + a := Compute("t", "one", nil) + b := Compute("t", "two", nil) + if a == b { + t.Errorf("different content should hash differently") + } +} + +func TestComputeDifferentTitle(t *testing.T) { + a := Compute("title one", "c", nil) + b := Compute("title two", "c", nil) + if a == b { + t.Errorf("different title should 
hash differently") + } +} + +func TestComputeEmptyContent(t *testing.T) { + if h := Compute("t", "", []string{"x"}); h != "" { + t.Errorf("empty content should produce empty hash, got %q", h) + } + if h := Compute("", " ", nil); h != "" { + t.Errorf("whitespace-only content should produce empty hash") + } +} + +func TestComputeTagSeparatorNotAmbiguous(t *testing.T) { + // A tag containing unusual characters must not collide with a split version. + a := Compute("t", "c", []string{"a,b"}) + b := Compute("t", "c", []string{"a", "b"}) + if a == b { + t.Errorf("tags with delimiter-like chars must not collide with split tags") + } +} diff --git a/internal/embed/ollama.go b/internal/embed/ollama.go new file mode 100644 index 0000000..02bd283 --- /dev/null +++ b/internal/embed/ollama.go @@ -0,0 +1,98 @@ +package embed + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" +) + +// Ollama provides embeddings via a local (or LAN) Ollama server's +// /api/embeddings endpoint. The output dimension depends on the model — +// it's discovered from the first successful call rather than hardcoded. +type Ollama struct { + baseURL string + model string + dim int + client *http.Client +} + +// DefaultOllamaModel is a small, fast general-purpose embedding model. +// 768 dim, ~137 MB on disk. +const DefaultOllamaModel = "nomic-embed-text" + +// NewOllama constructs a client. baseURL should be the Ollama server root +// (e.g. "http://localhost:11434"). If model is empty, DefaultOllamaModel is +// used. +func NewOllama(baseURL, model string) *Ollama { + baseURL = strings.TrimRight(baseURL, "/") + if model == "" { + model = DefaultOllamaModel + } + return &Ollama{ + baseURL: baseURL, + model: model, + client: &http.Client{Timeout: 60 * time.Second}, + } +} + +// Name implements Provider. +func (o *Ollama) Name() string { return "ollama:" + o.model } + +// Dim returns 0 until the first successful call discovers the true +// dimension. 
Callers using Dim() to preconfigure the Solr schema must +// make at least one embed call first (or hardcode the dim from docs). +func (o *Ollama) Dim() int { return o.dim } + +// Embed implements Provider. On success, updates the cached dim with the +// actual response length so downstream consumers can introspect it. +func (o *Ollama) Embed(ctx context.Context, text string) ([]float32, error) { + if text == "" { + return nil, fmt.Errorf("embed: empty text") + } + + body, err := json.Marshal(map[string]any{ + "model": o.model, + "prompt": text, + }) + if err != nil { + return nil, fmt.Errorf("embed: marshal: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, o.baseURL+"/api/embeddings", bytes.NewReader(body)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + + resp, err := o.client.Do(req) + if err != nil { + return nil, fmt.Errorf("embed: http: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + snippet, _ := io.ReadAll(io.LimitReader(resp.Body, 512)) + return nil, fmt.Errorf("embed: status %d: %s", resp.StatusCode, string(snippet)) + } + + var out struct { + Embedding []float32 `json:"embedding"` + } + if err := json.NewDecoder(resp.Body).Decode(&out); err != nil { + return nil, fmt.Errorf("embed: decode: %w", err) + } + if len(out.Embedding) == 0 { + // Ollama returns 200 with empty embedding on some error cases + // (e.g. unknown model). Treat as an error. 
+ return nil, fmt.Errorf("embed: empty embedding returned (model %q may not be pulled)", o.model) + } + if o.dim == 0 || o.dim != len(out.Embedding) { + o.dim = len(out.Embedding) + } + return out.Embedding, nil +} diff --git a/internal/embed/ollama_test.go b/internal/embed/ollama_test.go new file mode 100644 index 0000000..4e6325c --- /dev/null +++ b/internal/embed/ollama_test.go @@ -0,0 +1,159 @@ +package embed + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestOllamaEmbedSuccess(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/embeddings" { + t.Errorf("wrong path: %s", r.URL.Path) + } + var body map[string]any + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Fatalf("bad body: %v", err) + } + if body["model"] != DefaultOllamaModel { + t.Errorf("wrong model: %v", body["model"]) + } + if body["prompt"] != "hello world" { + t.Errorf("wrong prompt: %v", body["prompt"]) + } + _ = json.NewEncoder(w).Encode(map[string]any{ + "embedding": []float32{0.1, 0.2, 0.3, 0.4}, + }) + })) + defer server.Close() + + p := NewOllama(server.URL, "") + v, err := p.Embed(context.Background(), "hello world") + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + if len(v) != 4 { + t.Errorf("expected 4 dims, got %d", len(v)) + } + // Dim should be learned from the response. 
+ if p.Dim() != 4 { + t.Errorf("dim should update after first call, got %d", p.Dim()) + } +} + +func TestOllamaCustomModel(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var body map[string]any + _ = json.NewDecoder(r.Body).Decode(&body) + if body["model"] != "mxbai-embed-large" { + t.Errorf("expected mxbai-embed-large, got %v", body["model"]) + } + _ = json.NewEncoder(w).Encode(map[string]any{"embedding": []float32{1.0}}) + })) + defer server.Close() + + p := NewOllama(server.URL, "mxbai-embed-large") + if _, err := p.Embed(context.Background(), "x"); err != nil { + t.Fatalf("unexpected err: %v", err) + } + if !strings.Contains(p.Name(), "mxbai-embed-large") { + t.Errorf("name should include model: %s", p.Name()) + } +} + +func TestOllamaTrailingSlashInURL(t *testing.T) { + // NewOllama should trim trailing slash so we don't double it. + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/embeddings" { + t.Errorf("unexpected path: %s", r.URL.Path) + } + _ = json.NewEncoder(w).Encode(map[string]any{"embedding": []float32{0.0}}) + })) + defer server.Close() + + p := NewOllama(server.URL+"/", "") + if _, err := p.Embed(context.Background(), "x"); err != nil { + t.Errorf("unexpected err: %v", err) + } +} + +func TestOllamaEmptyInput(t *testing.T) { + p := NewOllama("http://irrelevant", "") + if _, err := p.Embed(context.Background(), ""); err == nil { + t.Error("expected error for empty text") + } +} + +func TestOllamaErrorStatus(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "model not found", http.StatusNotFound) + })) + defer server.Close() + + p := NewOllama(server.URL, "") + _, err := p.Embed(context.Background(), "x") + if err == nil { + t.Fatal("expected error on 404") + } + if !strings.Contains(err.Error(), "404") { + t.Errorf("error should include 
status, got: %v", err) + } +} + +func TestOllamaEmptyEmbedding(t *testing.T) { + // Ollama returns 200 with an empty embedding when the model isn't pulled. + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(map[string]any{"embedding": []float32{}}) + })) + defer server.Close() + + p := NewOllama(server.URL, "") + _, err := p.Embed(context.Background(), "x") + if err == nil { + t.Fatal("expected error on empty embedding") + } + if !strings.Contains(err.Error(), "empty embedding") { + t.Errorf("error should name the failure mode, got: %v", err) + } +} + +func TestFromEnvPrefersOllama(t *testing.T) { + // When both Ollama and OpenAI are configured, Ollama wins (local/free beats remote/paid). + t.Setenv("OLLAMA_EMBEDDING_URL", "http://mac.local:11434") + t.Setenv("OLLAMA_EMBEDDING_MODEL", "mxbai-embed-large") + t.Setenv("OPENAI_API_KEY", "sk-test") + + p, err := FromEnv() + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + if p == nil { + t.Fatal("expected a provider") + } + if !strings.HasPrefix(p.Name(), "ollama:") { + t.Errorf("expected ollama provider to win, got %s", p.Name()) + } + if !strings.Contains(p.Name(), "mxbai-embed-large") { + t.Errorf("custom model should be reflected: %s", p.Name()) + } +} + +func TestFromEnvOllamaAloneUsesDefault(t *testing.T) { + t.Setenv("OLLAMA_EMBEDDING_URL", "http://localhost:11434") + t.Setenv("OLLAMA_EMBEDDING_MODEL", "") + t.Setenv("OPENAI_API_KEY", "") + + p, err := FromEnv() + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + if p == nil { + t.Fatal("expected a provider") + } + if !strings.Contains(p.Name(), DefaultOllamaModel) { + t.Errorf("should default to %s, got %s", DefaultOllamaModel, p.Name()) + } +} diff --git a/internal/embed/openai.go b/internal/embed/openai.go new file mode 100644 index 0000000..476e820 --- /dev/null +++ b/internal/embed/openai.go @@ -0,0 +1,102 @@ +package embed + +import ( + "bytes" + "context" + 
"encoding/json" + "fmt" + "io" + "net/http" + "time" +) + +// OpenAI provides embeddings via the OpenAI /v1/embeddings API. It is the +// reference provider; other providers follow the same interface. +type OpenAI struct { + apiKey string + model string + dim int + baseURL string + client *http.Client +} + +// DefaultOpenAIModel is OpenAI's current small text embedding model. It emits +// 1536-dimensional vectors, which the Solr schema must match. +const ( + DefaultOpenAIModel = "text-embedding-3-small" + DefaultOpenAIDim = 1536 +) + +// NewOpenAI constructs a client. If model is empty, DefaultOpenAIModel is used. +func NewOpenAI(apiKey, model string) *OpenAI { + if model == "" { + model = DefaultOpenAIModel + } + return &OpenAI{ + apiKey: apiKey, + model: model, + dim: DefaultOpenAIDim, + baseURL: "https://api.openai.com/v1", + client: &http.Client{Timeout: 30 * time.Second}, + } +} + +// Name implements Provider. +func (o *OpenAI) Name() string { return "openai:" + o.model } + +// Dim implements Provider. +func (o *OpenAI) Dim() int { return o.dim } + +// Embed implements Provider. A non-200 response or malformed payload +// returns an error with enough context for debugging. 
+func (o *OpenAI) Embed(ctx context.Context, text string) ([]float32, error) { + if text == "" { + return nil, fmt.Errorf("embed: empty text") + } + + body, err := json.Marshal(map[string]any{ + "model": o.model, + "input": text, + }) + if err != nil { + return nil, fmt.Errorf("embed: marshal: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, o.baseURL+"/embeddings", bytes.NewReader(body)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+o.apiKey) + + resp, err := o.client.Do(req) + if err != nil { + return nil, fmt.Errorf("embed: http: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + snippet, _ := io.ReadAll(io.LimitReader(resp.Body, 512)) + return nil, fmt.Errorf("embed: status %d: %s", resp.StatusCode, string(snippet)) + } + + var out struct { + Data []struct { + Embedding []float32 `json:"embedding"` + } `json:"data"` + } + if err := json.NewDecoder(resp.Body).Decode(&out); err != nil { + return nil, fmt.Errorf("embed: decode: %w", err) + } + if len(out.Data) == 0 || len(out.Data[0].Embedding) == 0 { + return nil, fmt.Errorf("embed: empty response") + } + v := out.Data[0].Embedding + if len(v) != o.dim { + // Update our advertised dim on first successful call so queries + // don't drift. OpenAI allows a `dimensions` override; we don't + // set one here, so the response should match DefaultOpenAIDim. 
+ o.dim = len(v) + } + return v, nil +} diff --git a/internal/embed/openai_test.go b/internal/embed/openai_test.go new file mode 100644 index 0000000..ab4a164 --- /dev/null +++ b/internal/embed/openai_test.go @@ -0,0 +1,137 @@ +package embed + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestOpenAIEmbedSuccess(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("Authorization"); got != "Bearer test-key" { + t.Errorf("missing/wrong Authorization header: %q", got) + } + if ct := r.Header.Get("Content-Type"); !strings.HasPrefix(ct, "application/json") { + t.Errorf("wrong content-type: %q", ct) + } + var body map[string]any + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Fatalf("bad body: %v", err) + } + if body["model"] != DefaultOpenAIModel { + t.Errorf("wrong model: %v", body["model"]) + } + if body["input"] != "hello" { + t.Errorf("wrong input: %v", body["input"]) + } + _ = json.NewEncoder(w).Encode(map[string]any{ + "data": []map[string]any{ + {"embedding": []float32{0.1, 0.2, 0.3}}, + }, + }) + })) + defer server.Close() + + p := NewOpenAI("test-key", "") + p.baseURL = server.URL + + v, err := p.Embed(context.Background(), "hello") + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + if len(v) != 3 { + t.Errorf("expected 3 dims, got %d", len(v)) + } + if v[0] != 0.1 || v[1] != 0.2 || v[2] != 0.3 { + t.Errorf("unexpected vector: %v", v) + } + // After a successful call with unexpected dim, Dim() should adapt. 
+ if p.Dim() != 3 { + t.Errorf("dim should update to actual response size: got %d", p.Dim()) + } +} + +func TestOpenAIEmbedEmptyInput(t *testing.T) { + p := NewOpenAI("k", "") + if _, err := p.Embed(context.Background(), ""); err == nil { + t.Error("expected error for empty text") + } +} + +func TestOpenAIEmbedErrorStatus(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "rate limit exceeded", http.StatusTooManyRequests) + })) + defer server.Close() + + p := NewOpenAI("k", "") + p.baseURL = server.URL + + _, err := p.Embed(context.Background(), "hi") + if err == nil { + t.Fatal("expected error on 429") + } + if !strings.Contains(err.Error(), "429") { + t.Errorf("error should include status code, got: %v", err) + } +} + +func TestOpenAIEmbedEmptyData(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(map[string]any{"data": []any{}}) + })) + defer server.Close() + + p := NewOpenAI("k", "") + p.baseURL = server.URL + + if _, err := p.Embed(context.Background(), "x"); err == nil { + t.Error("expected error on empty data array") + } +} + +func TestOpenAINameAndDim(t *testing.T) { + p := NewOpenAI("k", "") + if p.Name() != "openai:"+DefaultOpenAIModel { + t.Errorf("name: got %s", p.Name()) + } + if p.Dim() != DefaultOpenAIDim { + t.Errorf("dim: got %d", p.Dim()) + } + + p2 := NewOpenAI("k", "text-embedding-3-large") + if !strings.HasSuffix(p2.Name(), "text-embedding-3-large") { + t.Errorf("custom model not reflected in name: %s", p2.Name()) + } +} + +func TestFromEnvNoProvider(t *testing.T) { + t.Setenv("OLLAMA_EMBEDDING_URL", "") + t.Setenv("OPENAI_API_KEY", "") + p, err := FromEnv() + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + if p != nil { + t.Errorf("expected nil provider, got %v", p) + } +} + +func TestFromEnvOpenAI(t *testing.T) { + t.Setenv("OLLAMA_EMBEDDING_URL", "") + 
t.Setenv("OPENAI_API_KEY", "sk-test") + t.Setenv("OPENAI_EMBEDDING_MODEL", "custom-model") + p, err := FromEnv() + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + if p == nil { + t.Fatal("expected non-nil provider") + } + if !strings.Contains(p.Name(), "custom-model") { + t.Errorf("provider should use configured model: %s", p.Name()) + } +} diff --git a/internal/embed/provider.go b/internal/embed/provider.go new file mode 100644 index 0000000..c02ea11 --- /dev/null +++ b/internal/embed/provider.go @@ -0,0 +1,62 @@ +// Package embed turns text into dense vector embeddings for semantic search. +// +// Providers are pluggable. Users configure one via environment variables. +// FromEnv checks them in this order (first match wins): +// +// 1. Ollama — local / LAN, free: +// OLLAMA_EMBEDDING_URL e.g. http://localhost:11434 +// OLLAMA_EMBEDDING_MODEL e.g. nomic-embed-text (default) +// +// 2. OpenAI — managed, paid: +// OPENAI_API_KEY +// OPENAI_EMBEDDING_MODEL e.g. text-embedding-3-small (default) +// +// If no provider is configured, FromEnv returns (nil, nil) and the server +// runs in BM25-only mode. Callers MUST handle a nil provider as a signal to +// skip embedding rather than erroring. +// +// NOTE: The Solr schema's vectorDimension is static. Ollama's nomic-embed-text +// is 768, mxbai-embed-large is 1024, OpenAI text-embedding-3-small is 1536. +// When switching providers, update solr/managed-schema.xml's vectorDimension +// to match the chosen model and reload the configset before running. +package embed + +import ( + "context" + "os" + "strings" +) + +// Provider produces a dense vector for a piece of text. Implementations +// should be safe for concurrent use. +type Provider interface { + // Embed returns the embedding for text. The returned slice length must + // equal Dim() (after the first successful call) and must be stable + // across calls on the same provider. 
+ Embed(ctx context.Context, text string) ([]float32, error) + + // Dim is the output dimension. Some providers learn this on the first + // successful call (returning 0 until then); others know it upfront. + Dim() int + + // Name is a short identifier for logging / metrics. + Name() string +} + +// FromEnv constructs a provider from environment variables. Returns (nil, nil) +// if no provider is configured — callers should treat nil as "embeddings +// disabled" rather than an error. +// +// Ollama is checked first so a local server preempts an accidentally-present +// OpenAI key (free/private beats paid/remote by default). +func FromEnv() (Provider, error) { + if url := strings.TrimSpace(os.Getenv("OLLAMA_EMBEDDING_URL")); url != "" { + model := strings.TrimSpace(os.Getenv("OLLAMA_EMBEDDING_MODEL")) + return NewOllama(url, model), nil + } + if key := strings.TrimSpace(os.Getenv("OPENAI_API_KEY")); key != "" { + model := strings.TrimSpace(os.Getenv("OPENAI_EMBEDDING_MODEL")) + return NewOpenAI(key, model), nil + } + return nil, nil +} diff --git a/internal/privacy/scrub.go b/internal/privacy/scrub.go new file mode 100644 index 0000000..3fb38f9 --- /dev/null +++ b/internal/privacy/scrub.go @@ -0,0 +1,130 @@ +// Package privacy scrubs secrets from strings before they are stored as memory +// content. It is conservative: a missed redaction is preferred over mangling +// legitimate content, so patterns are anchored and stand-alone. +package privacy + +import ( + "encoding/json" + "regexp" + "sort" +) + +// Result is the outcome of a Scrub call. +type Result struct { + // Content is the input string with any matched secrets replaced. + Content string + // Hits maps pattern name -> number of matches redacted. + Hits map[string]int +} + +// Count returns the total number of redactions applied. +func (r Result) Count() int { + n := 0 + for _, v := range r.Hits { + n += v + } + return n +} + +// Kinds returns a sorted list of the pattern names that matched. 
+func (r Result) Kinds() []string { + out := make([]string, 0, len(r.Hits)) + for k := range r.Hits { + out = append(out, k) + } + sort.Strings(out) + return out +} + +type pattern struct { + name string + re *regexp.Regexp + // replace overrides the default "[REDACTED:]" template. + // Use Go regexp replacement syntax ($1, $2, ...) if needed. + replace string +} + +// Patterns are applied in order. Multi-line / block patterns run first so +// they subsume any single-line keys that would otherwise match inside the +// block. The sk-ant- pattern runs before the generic sk- OpenAI pattern so +// Anthropic keys aren't mislabeled. +var patterns = []pattern{ + {name: "private_key_block", re: regexp.MustCompile(`(?s)-----BEGIN (?:RSA |EC |DSA |OPENSSH |PGP )?PRIVATE KEY(?: BLOCK)?-----.*?-----END[^-]*-----`)}, + {name: "private_tag", re: regexp.MustCompile(`(?is).*?`)}, + {name: "secret_tag", re: regexp.MustCompile(`(?is).*?`)}, + {name: "anthropic_key", re: regexp.MustCompile(`sk-ant-[A-Za-z0-9_\-]{20,}`)}, + {name: "openai_key", re: regexp.MustCompile(`\bsk-[A-Za-z0-9]{48}\b`)}, + {name: "github_pat", re: regexp.MustCompile(`\bgithub_pat_[A-Za-z0-9_]{82}\b`)}, + {name: "github_token", re: regexp.MustCompile(`\bghp_[A-Za-z0-9]{36}\b`)}, + {name: "aws_access_key", re: regexp.MustCompile(`\bAKIA[0-9A-Z]{16}\b`)}, + {name: "aws_secret_key", re: regexp.MustCompile(`(?i)aws_secret_access_key\s*[=:]\s*[A-Za-z0-9/+=]{40}`)}, + {name: "slack_token", re: regexp.MustCompile(`\bxox[baprs]-[A-Za-z0-9\-]+`)}, + {name: "bearer_token", re: regexp.MustCompile(`(?i)bearer\s+[A-Za-z0-9_\-\.=]{20,}`)}, + {name: "url_creds", re: regexp.MustCompile(`(https?://)[^:/@\s]+:[^@\s]+@`), replace: "${1}[REDACTED:url_creds]@"}, +} + +// Scrub replaces recognized secrets in s with [REDACTED:] and returns +// the scrubbed content plus a tally of what was hit. 
+func Scrub(s string) Result { + res := Result{Content: s, Hits: map[string]int{}} + if s == "" { + return res + } + for _, p := range patterns { + matches := p.re.FindAllStringIndex(res.Content, -1) + if len(matches) == 0 { + continue + } + res.Hits[p.name] = len(matches) + replace := p.replace + if replace == "" { + replace = "[REDACTED:" + p.name + "]" + } + res.Content = p.re.ReplaceAllString(res.Content, replace) + } + return res +} + +// MergeHits combines two hit maps into a new one. +func MergeHits(a, b map[string]int) map[string]int { + out := make(map[string]int, len(a)+len(b)) + for k, v := range a { + out[k] += v + } + for k, v := range b { + out[k] += v + } + return out +} + +// MergeMetadata augments an existing metadata JSON string with scrub_count and +// scrub_kinds fields. If existing is empty or invalid JSON, a fresh object is +// produced. If the result has no hits, existing is returned unchanged. +func MergeMetadata(existing string, hits map[string]int) string { + total := 0 + kinds := make([]string, 0, len(hits)) + for k, v := range hits { + total += v + kinds = append(kinds, k) + } + if total == 0 { + return existing + } + sort.Strings(kinds) + + var obj map[string]any + if existing != "" { + _ = json.Unmarshal([]byte(existing), &obj) + } + if obj == nil { + obj = map[string]any{} + } + obj["scrub_count"] = total + obj["scrub_kinds"] = kinds + + b, err := json.Marshal(obj) + if err != nil { + return existing + } + return string(b) +} diff --git a/internal/privacy/scrub_test.go b/internal/privacy/scrub_test.go new file mode 100644 index 0000000..8248cf5 --- /dev/null +++ b/internal/privacy/scrub_test.go @@ -0,0 +1,249 @@ +package privacy + +import ( + "encoding/json" + "strings" + "testing" +) + +func TestScrubAnthropicKey(t *testing.T) { + // Anthropic keys start with sk-ant-. 
+ in := "My key is sk-ant-api03-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA done" + r := Scrub(in) + if strings.Contains(r.Content, "sk-ant-") { + t.Errorf("expected sk-ant- to be redacted, got: %s", r.Content) + } + if !strings.Contains(r.Content, "[REDACTED:anthropic_key]") { + t.Errorf("expected REDACTED marker, got: %s", r.Content) + } + if r.Hits["anthropic_key"] != 1 { + t.Errorf("expected 1 anthropic_key hit, got %d", r.Hits["anthropic_key"]) + } +} + +func TestScrubAnthropicBeforeOpenAI(t *testing.T) { + // sk-ant- keys must not be labeled as OpenAI keys. + in := "sk-ant-api03-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" + r := Scrub(in) + if r.Hits["openai_key"] != 0 { + t.Errorf("anthropic key should not match openai pattern, got %+v", r.Hits) + } + if r.Hits["anthropic_key"] != 1 { + t.Errorf("expected anthropic match, got %+v", r.Hits) + } +} + +func TestScrubOpenAIKey(t *testing.T) { + // OpenAI keys are sk- followed by exactly 48 alphanumeric chars. + in := "use sk-" + strings.Repeat("a", 48) + " done" + r := Scrub(in) + if !strings.Contains(r.Content, "[REDACTED:openai_key]") { + t.Errorf("expected openai redaction, got: %s", r.Content) + } +} + +func TestScrubGithubToken(t *testing.T) { + in := "token=ghp_1234567890abcdefghijKLMNOPQRSTUVwxyz extra" + r := Scrub(in) + if r.Hits["github_token"] != 1 { + t.Errorf("expected github_token hit, got %+v", r.Hits) + } + if !strings.Contains(r.Content, "[REDACTED:github_token]") { + t.Errorf("expected redaction, got: %s", r.Content) + } +} + +func TestScrubGithubPAT(t *testing.T) { + // github_pat_ has exactly 82 chars of [A-Za-z0-9_] after the prefix. 
+ suffix := strings.Repeat("a", 82) + in := "token: github_pat_" + suffix + r := Scrub(in) + if r.Hits["github_pat"] != 1 { + t.Errorf("expected github_pat hit, got %+v", r.Hits) + } +} + +func TestScrubAWSAccessKey(t *testing.T) { + in := "AWS_KEY=AKIAIOSFODNN7EXAMPLE here" + r := Scrub(in) + if r.Hits["aws_access_key"] != 1 { + t.Errorf("expected aws_access_key hit, got %+v", r.Hits) + } +} + +func TestScrubAWSSecretKey(t *testing.T) { + in := "aws_secret_access_key=wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY" + r := Scrub(in) + if r.Hits["aws_secret_key"] != 1 { + t.Errorf("expected aws_secret_key hit, got %+v", r.Hits) + } +} + +func TestScrubSlackToken(t *testing.T) { + in := "slack=xoxb-12345-67890-abcdef leaks" + r := Scrub(in) + if r.Hits["slack_token"] != 1 { + t.Errorf("expected slack_token hit, got %+v", r.Hits) + } +} + +func TestScrubBearerToken(t *testing.T) { + in := "Authorization: Bearer eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.abc" + r := Scrub(in) + if r.Hits["bearer_token"] != 1 { + t.Errorf("expected bearer_token hit, got %+v", r.Hits) + } + if strings.Contains(r.Content, "eyJ") { + t.Errorf("JWT should be redacted, got: %s", r.Content) + } +} + +func TestScrubPrivateKeyBlock(t *testing.T) { + in := "note:\n-----BEGIN RSA PRIVATE KEY-----\nMIIEowIBAAKCAQEAvZ\nnextline\n-----END RSA PRIVATE KEY-----\nend" + r := Scrub(in) + if r.Hits["private_key_block"] != 1 { + t.Errorf("expected private_key_block hit, got %+v", r.Hits) + } + if strings.Contains(r.Content, "MIIEowIBAAK") { + t.Errorf("key body should be gone, got: %s", r.Content) + } +} + +func TestScrubURLCreds(t *testing.T) { + in := "connect https://admin:s3cr3t@db.example.com:5432/foo" + r := Scrub(in) + if r.Hits["url_creds"] != 1 { + t.Errorf("expected url_creds hit, got %+v", r.Hits) + } + if !strings.Contains(r.Content, "https://[REDACTED:url_creds]@db.example.com") { + t.Errorf("expected host preserved with redacted creds, got: %s", r.Content) + } +} + +func TestScrubPrivateTag(t 
*testing.T) { + in := "before hidden stuff after" + r := Scrub(in) + if r.Hits["private_tag"] != 1 { + t.Errorf("expected private_tag hit, got %+v", r.Hits) + } + if strings.Contains(r.Content, "hidden stuff") { + t.Errorf("private content should be gone, got: %s", r.Content) + } +} + +func TestScrubSecretTag(t *testing.T) { + in := "foo bar" + r := Scrub(in) + if r.Hits["secret_tag"] != 1 { + t.Errorf("expected secret_tag hit, got %+v", r.Hits) + } +} + +func TestScrubEmpty(t *testing.T) { + r := Scrub("") + if r.Count() != 0 { + t.Errorf("empty input should have no hits") + } +} + +func TestScrubNoMatches(t *testing.T) { + in := "this is plain text with no secrets" + r := Scrub(in) + if r.Content != in { + t.Errorf("no matches should leave content unchanged") + } + if r.Count() != 0 { + t.Errorf("no matches should have zero count") + } +} + +func TestScrubMultipleKinds(t *testing.T) { + in := "AKIAIOSFODNN7EXAMPLE and ghp_1234567890abcdefghijKLMNOPQRSTUVwxyz" + r := Scrub(in) + if r.Hits["aws_access_key"] != 1 || r.Hits["github_token"] != 1 { + t.Errorf("expected both hits, got %+v", r.Hits) + } + if r.Count() != 2 { + t.Errorf("expected total count 2, got %d", r.Count()) + } + if got := r.Kinds(); len(got) != 2 || got[0] != "aws_access_key" || got[1] != "github_token" { + t.Errorf("expected sorted kinds, got %v", got) + } +} + +func TestScrubIdempotent(t *testing.T) { + // Scrubbing already-scrubbed content should be a no-op. 
+ in := "sk-ant-api03-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" + first := Scrub(in) + second := Scrub(first.Content) + if second.Count() != 0 { + t.Errorf("re-scrubbing should produce no hits, got %+v", second.Hits) + } + if second.Content != first.Content { + t.Errorf("re-scrubbing changed content: %q -> %q", first.Content, second.Content) + } +} + +func TestMergeMetadataFresh(t *testing.T) { + hits := map[string]int{"anthropic_key": 2, "github_token": 1} + got := MergeMetadata("", hits) + var obj map[string]any + if err := json.Unmarshal([]byte(got), &obj); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + if c, _ := obj["scrub_count"].(float64); int(c) != 3 { + t.Errorf("expected scrub_count=3, got %v", obj["scrub_count"]) + } + kinds, _ := obj["scrub_kinds"].([]any) + if len(kinds) != 2 { + t.Errorf("expected 2 kinds, got %v", kinds) + } +} + +func TestMergeMetadataExisting(t *testing.T) { + existing := `{"source":"chat","importance":0.7}` + hits := map[string]int{"slack_token": 1} + got := MergeMetadata(existing, hits) + var obj map[string]any + if err := json.Unmarshal([]byte(got), &obj); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + if obj["source"] != "chat" { + t.Errorf("existing keys should be preserved, got %v", obj) + } + if c, _ := obj["scrub_count"].(float64); int(c) != 1 { + t.Errorf("expected scrub_count=1, got %v", obj["scrub_count"]) + } +} + +func TestMergeMetadataNoHits(t *testing.T) { + existing := `{"foo":"bar"}` + got := MergeMetadata(existing, map[string]int{}) + if got != existing { + t.Errorf("no hits should leave metadata untouched, got %q", got) + } +} + +func TestMergeMetadataInvalidExisting(t *testing.T) { + // Garbage existing metadata should be replaced by a fresh object. 
+ got := MergeMetadata("not-json", map[string]int{"github_token": 1}) + var obj map[string]any + if err := json.Unmarshal([]byte(got), &obj); err != nil { + t.Fatalf("expected valid JSON: %v", err) + } + if _, ok := obj["scrub_count"]; !ok { + t.Errorf("expected scrub_count key, got %v", obj) + } +} + +func TestMergeHits(t *testing.T) { + a := map[string]int{"anthropic_key": 1, "github_token": 1} + b := map[string]int{"anthropic_key": 2, "slack_token": 1} + got := MergeHits(a, b) + if got["anthropic_key"] != 3 { + t.Errorf("expected sum for overlapping key, got %d", got["anthropic_key"]) + } + if got["github_token"] != 1 || got["slack_token"] != 1 { + t.Errorf("expected non-overlapping keys kept, got %+v", got) + } +} diff --git a/internal/retrieval/rrf.go b/internal/retrieval/rrf.go new file mode 100644 index 0000000..ca2c6b5 --- /dev/null +++ b/internal/retrieval/rrf.go @@ -0,0 +1,126 @@ +// Package retrieval contains the search-side primitives used by the +// solr-mem server to combine multiple retrieval signals. +package retrieval + +// Stream is one ranked list of document IDs from a single retrieval source +// (e.g. BM25 keyword match, dense vector KNN, graph walk). Order matters: +// position 0 is the top-ranked hit. +type Stream struct { + Name string + IDs []string +} + +// RRFOption configures the fuser. +type RRFOption func(*rrfConfig) + +type rrfConfig struct { + k int + topK int + weight map[string]float64 +} + +// RRFWithK sets the rank-damping constant. Larger k compresses scores across +// streams, making them count more evenly. The canonical default is 60 (per +// the original Cormack et al. paper). +func RRFWithK(k int) RRFOption { + return func(c *rrfConfig) { + if k > 0 { + c.k = k + } + } +} + +// RRFWithTopK truncates the returned fused list. 0 means "return everything +// that appeared in any stream". 
+func RRFWithTopK(n int) RRFOption { + return func(c *rrfConfig) { + c.topK = n + } +} + +// RRFWithWeight overrides the default 1.0 weight for a named stream. Useful +// if one signal is known to be higher-quality for a given query class. +// Streams not in the map keep weight 1.0. +func RRFWithWeight(weights map[string]float64) RRFOption { + return func(c *rrfConfig) { + c.weight = weights + } +} + +// RRFScored pairs an ID with its fused score. Sorted high-to-low on return +// from Fuse. +type RRFScored struct { + ID string + Score float64 +} + +// Fuse combines multiple ranked streams with Reciprocal Rank Fusion: +// +// score(d) = sum over streams i of weight_i / (k + rank_i(d)) +// +// rank is 1-indexed. Streams with fewer results than another don't penalize +// the missing IDs — they simply don't contribute. The returned list is +// sorted by fused score descending; ties break on lexicographic ID for +// deterministic output. +func Fuse(streams []Stream, opts ...RRFOption) []RRFScored { + cfg := rrfConfig{k: 60} + for _, opt := range opts { + opt(&cfg) + } + + scores := make(map[string]float64) + for _, s := range streams { + w := 1.0 + if cfg.weight != nil { + if v, ok := cfg.weight[s.Name]; ok { + w = v + } + } + for rank, id := range s.IDs { + if id == "" { + continue + } + scores[id] += w / float64(cfg.k+rank+1) + } + } + + out := make([]RRFScored, 0, len(scores)) + for id, sc := range scores { + out = append(out, RRFScored{ID: id, Score: sc}) + } + // Sort by score desc, then ID asc for determinism. + sortFused(out) + + if cfg.topK > 0 && len(out) > cfg.topK { + out = out[:cfg.topK] + } + return out +} + +// FuseIDs is a thin wrapper around Fuse that returns just the ranked IDs. +func FuseIDs(streams []Stream, opts ...RRFOption) []string { + scored := Fuse(streams, opts...) + ids := make([]string, len(scored)) + for i, s := range scored { + ids[i] = s.ID + } + return ids +} + +// sortFused sorts in-place: descending score, ascending ID for ties. 
+func sortFused(xs []RRFScored) { + // Simple selection sort is fine for small N; most memory queries return + // <= 50 hits. Using sort.Slice would pull in sort; inlined for zero-dep. + for i := 0; i < len(xs); i++ { + best := i + for j := i + 1; j < len(xs); j++ { + if xs[j].Score > xs[best].Score || + (xs[j].Score == xs[best].Score && xs[j].ID < xs[best].ID) { + best = j + } + } + if best != i { + xs[i], xs[best] = xs[best], xs[i] + } + } +} diff --git a/internal/retrieval/rrf_test.go b/internal/retrieval/rrf_test.go new file mode 100644 index 0000000..46a6f45 --- /dev/null +++ b/internal/retrieval/rrf_test.go @@ -0,0 +1,159 @@ +package retrieval + +import ( + "math" + "testing" +) + +func approx(a, b float64) bool { return math.Abs(a-b) < 1e-9 } + +func TestFuseSingleStreamPreservesOrder(t *testing.T) { + streams := []Stream{ + {Name: "bm25", IDs: []string{"a", "b", "c"}}, + } + got := FuseIDs(streams) + want := []string{"a", "b", "c"} + if !equal(got, want) { + t.Errorf("single stream order: got %v, want %v", got, want) + } +} + +func TestFuseBothStreamsAgreeTopGoesFirst(t *testing.T) { + streams := []Stream{ + {Name: "bm25", IDs: []string{"a", "b", "c"}}, + {Name: "vec", IDs: []string{"a", "b", "c"}}, + } + got := FuseIDs(streams) + if len(got) != 3 || got[0] != "a" { + t.Errorf("full agreement should put a first, got %v", got) + } +} + +func TestFuseCrossStreamBoostsHitsInBoth(t *testing.T) { + // c appears in both but never at the top; a appears once; b appears once. + // Canonical RRF with k=60: score(c) = 1/62 + 1/62 = ~0.0323 + // score(a) = 1/61 = ~0.0164 + // score(b) = 1/61 = ~0.0164 + // So c should rank first. + streams := []Stream{ + {Name: "bm25", IDs: []string{"a", "c"}}, + {Name: "vec", IDs: []string{"b", "c"}}, + } + got := FuseIDs(streams) + if got[0] != "c" { + t.Errorf("cross-stream hit should win: got %v", got) + } +} + +func TestFuseKnownScores(t *testing.T) { + // Verify exact RRF formula with k=10 for cleaner numbers. 
+ streams := []Stream{ + {Name: "s1", IDs: []string{"x"}}, + {Name: "s2", IDs: []string{"x"}}, + } + got := Fuse(streams, RRFWithK(10)) + if len(got) != 1 { + t.Fatalf("expected 1 result, got %d", len(got)) + } + // score(x) = 1/(10+1) + 1/(10+1) = 2/11 + want := 2.0 / 11.0 + if !approx(got[0].Score, want) { + t.Errorf("expected score %f, got %f", want, got[0].Score) + } +} + +func TestFuseWeights(t *testing.T) { + // Equal-rank hits should split by weight. + streams := []Stream{ + {Name: "high", IDs: []string{"a"}}, + {Name: "low", IDs: []string{"b"}}, + } + got := Fuse(streams, RRFWithK(10), RRFWithWeight(map[string]float64{"high": 2.0, "low": 0.5})) + var scoreA, scoreB float64 + for _, s := range got { + if s.ID == "a" { + scoreA = s.Score + } + if s.ID == "b" { + scoreB = s.Score + } + } + // score(a) = 2.0/11 = 0.1818..., score(b) = 0.5/11 = 0.0454... + if scoreA <= scoreB { + t.Errorf("higher weight should produce higher score: a=%f b=%f", scoreA, scoreB) + } + if !approx(scoreA, 2.0/11.0) { + t.Errorf("expected score(a)=2/11, got %f", scoreA) + } +} + +func TestFuseTopK(t *testing.T) { + streams := []Stream{ + {Name: "s", IDs: []string{"a", "b", "c", "d", "e"}}, + } + got := FuseIDs(streams, RRFWithTopK(3)) + if len(got) != 3 { + t.Errorf("expected 3, got %d", len(got)) + } +} + +func TestFuseEmptyInput(t *testing.T) { + if got := FuseIDs(nil); len(got) != 0 { + t.Errorf("nil input should produce empty, got %v", got) + } + if got := FuseIDs([]Stream{{Name: "s"}}); len(got) != 0 { + t.Errorf("empty stream should produce empty, got %v", got) + } +} + +func TestFuseIgnoresEmptyIDs(t *testing.T) { + streams := []Stream{ + {Name: "s", IDs: []string{"a", "", "b"}}, + } + got := FuseIDs(streams) + // Empty IDs are skipped; remaining ranks keep their original positions. + // So "a" is rank 0, "b" is rank 2 (even though "a" is next to it). 
+ if len(got) != 2 || got[0] != "a" || got[1] != "b" { + t.Errorf("empty IDs should be skipped: got %v", got) + } +} + +func TestFuseDeterministicTieBreak(t *testing.T) { + // Two IDs with identical scores should sort by ID ascending for stable output. + streams := []Stream{ + {Name: "s1", IDs: []string{"zebra"}}, + {Name: "s2", IDs: []string{"apple"}}, + } + got := FuseIDs(streams) + if len(got) != 2 || got[0] != "apple" { + t.Errorf("tie break should favor lexicographically-smaller ID: got %v", got) + } +} + +func TestFuseMissingFromOneStream(t *testing.T) { + // Asymmetric streams: "only-a" appears only in s1. Its score should + // come only from that stream's rank. + streams := []Stream{ + {Name: "s1", IDs: []string{"only-a", "shared"}}, + {Name: "s2", IDs: []string{"shared", "only-b"}}, + } + got := Fuse(streams, RRFWithK(10)) + // shared is rank 1 in both: 1/11 + 1/12 = ~0.174 + // only-a is rank 0 in s1: 1/11 = 0.0909 + // only-b is rank 1 in s2: 1/12 = 0.0833 + if got[0].ID != "shared" { + t.Errorf("shared should rank first: got %v", got) + } +} + +func equal(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} diff --git a/internal/solr/types.go b/internal/solr/types.go index c32de05..ae21e80 100644 --- a/internal/solr/types.go +++ b/internal/solr/types.go @@ -17,9 +17,11 @@ type Document struct { UpdatedAt time.Time `json:"updated_at"` ExpiresAt string `json:"expires_at,omitempty"` Lifetime string `json:"lifetime,omitempty"` - SessionID string `json:"session_id,omitempty"` - RelatedIDs []string `json:"related_ids,omitempty"` - Format string `json:"format,omitempty"` + SessionID string `json:"session_id,omitempty"` + RelatedIDs []string `json:"related_ids,omitempty"` + Format string `json:"format,omitempty"` + ContentHash string `json:"content_hash,omitempty"` + Embedding []float32 `json:"embedding,omitempty"` } // QueryParams holds parameters for a Solr search 
query. diff --git a/solr/managed-schema.xml b/solr/managed-schema.xml index 0de096e..fb706b6 100644 --- a/solr/managed-schema.xml +++ b/solr/managed-schema.xml @@ -9,6 +9,20 @@ + + + @@ -48,5 +62,7 @@ + +