From 4ebb8da5ac831c44c52e250981c6f5599d4dac6f Mon Sep 17 00:00:00 2001 From: arreyder Date: Sat, 18 Apr 2026 10:12:30 -0500 Subject: [PATCH 1/8] Cap results per session in broker packets and search_memories Adds a generic session-diversify helper and wires it into both the broker packet builder (hardcoded cap=2) and search_memories (session_cap arg, default 3, 0 disables). Prevents one chatty session from starving other relevant context from top-K results. Closes #13. Co-Authored-By: Claude Opus 4.7 (1M context) --- cmd/solr-mem-server/broker.go | 18 ++++- cmd/solr-mem-server/diversify.go | 28 ++++++++ cmd/solr-mem-server/diversify_test.go | 96 +++++++++++++++++++++++++++ cmd/solr-mem-server/search_tool.go | 10 +++ cmd/solr-mem-server/tools.go | 3 +- 5 files changed, 151 insertions(+), 4 deletions(-) create mode 100644 cmd/solr-mem-server/diversify.go create mode 100644 cmd/solr-mem-server/diversify_test.go diff --git a/cmd/solr-mem-server/broker.go b/cmd/solr-mem-server/broker.go index 7a55270..ad77584 100644 --- a/cmd/solr-mem-server/broker.go +++ b/cmd/solr-mem-server/broker.go @@ -36,6 +36,11 @@ const ( // defaultRunSweeperInterval is how often the stale-run sweeper runs. defaultRunSweeperInterval = 5 * time.Minute + + // brokerSessionCap is the max number of memory items per session_id in a + // single packet. Code items (no session) are unaffected. With maxItems=5 + // this leaves room for at least 3 distinct sessions (or a mix with code). + brokerSessionCap = 2 ) // WorkObservation is a single structured report from a worker agent. @@ -65,6 +70,7 @@ type MemoryPacketItem struct { MemoryType string `json:"memory_type,omitempty"` FilePath string `json:"file_path,omitempty"` SymbolName string `json:"symbol_name,omitempty"` + SessionID string `json:"session_id,omitempty"` // populated for memory items; used for session-diversified ranking } // MemoryPacket is a precomputed bundle of relevant context for a worker agent. 
@@ -320,13 +326,17 @@ func (b *Broker) doBuild(obs WorkObservation, buildSeq int) *MemoryPacket { // 3. Score and dedupe. scored := scoreCandidates(candidates, obs) - // 4. Pick top items (max 5). + // 4. Cap per session so one chatty session can't dominate the packet. + // Code items have empty SessionID and are unaffected. + scored = diversifyBySession(scored, func(i MemoryPacketItem) string { return i.SessionID }, brokerSessionCap) + + // 5. Pick top items (max 5). const maxItems = 5 if len(scored) > maxItems { scored = scored[:maxItems] } - // 5. Determine delivery class. + // 6. Determine delivery class. delivery := DeliveryCheckpoint for _, item := range scored { // Promote to interrupt if we found a high-relevance hazard or prior solution. @@ -336,7 +346,7 @@ func (b *Broker) doBuild(obs WorkObservation, buildSeq int) *MemoryPacket { } } - // 6. Build summary. + // 7. Build summary. summary := buildPacketSummary(scored, obs) return &MemoryPacket{ @@ -371,6 +381,7 @@ func (b *Broker) searchMemories(ctx context.Context, queryTerms string, obs Work title, _ := doc["title"].(string) content, _ := doc["content"].(string) memType, _ := doc["memory_type"].(string) + sessionID, _ := doc["session_id"].(string) tags := getStringSliceFromDoc(doc, "tags") // Truncate content for summary. @@ -388,6 +399,7 @@ func (b *Broker) searchMemories(ctx context.Context, queryTerms string, obs Work Reason: "memory search hit", Tags: tags, MemoryType: memType, + SessionID: sessionID, }) } return items diff --git a/cmd/solr-mem-server/diversify.go b/cmd/solr-mem-server/diversify.go new file mode 100644 index 0000000..2d4a71c --- /dev/null +++ b/cmd/solr-mem-server/diversify.go @@ -0,0 +1,28 @@ +package main + +// diversifyBySession caps the number of items per session in a ranked slice. +// Items with an empty group key are passed through without counting — callers +// use this for hits that have no natural grouping (e.g. code docs without a +// session_id). 
+// +// cap <= 0 disables diversification and returns the input unchanged. +func diversifyBySession[T any](items []T, key func(T) string, cap int) []T { + if cap <= 0 || len(items) == 0 { + return items + } + counts := make(map[string]int, len(items)) + out := make([]T, 0, len(items)) + for _, item := range items { + k := key(item) + if k == "" { + out = append(out, item) + continue + } + if counts[k] >= cap { + continue + } + counts[k]++ + out = append(out, item) + } + return out +} diff --git a/cmd/solr-mem-server/diversify_test.go b/cmd/solr-mem-server/diversify_test.go new file mode 100644 index 0000000..668efbc --- /dev/null +++ b/cmd/solr-mem-server/diversify_test.go @@ -0,0 +1,96 @@ +package main + +import "testing" + +type divItem struct { + id string + session string +} + +func ids(items []divItem) []string { + out := make([]string, len(items)) + for i, it := range items { + out[i] = it.id + } + return out +} + +func sessionKey(it divItem) string { return it.session } + +func TestDiversifyBySessionCapsPerSession(t *testing.T) { + in := []divItem{ + {"a", "s1"}, + {"b", "s1"}, + {"c", "s1"}, // should be dropped at cap=2 + {"d", "s2"}, + {"e", "s1"}, // dropped + {"f", "s2"}, + {"g", "s3"}, + } + got := diversifyBySession(in, sessionKey, 2) + want := []string{"a", "b", "d", "f", "g"} + if !equalStrings(ids(got), want) { + t.Errorf("cap=2: got %v, want %v", ids(got), want) + } +} + +func TestDiversifyBySessionPreservesOrder(t *testing.T) { + // Top-ranked items should be kept over later same-session items. + in := []divItem{ + {"top", "s1"}, + {"mid", "s2"}, + {"also1", "s1"}, + {"low1", "s1"}, // dropped at cap=2 + {"low2", "s1"}, // dropped + } + got := diversifyBySession(in, sessionKey, 2) + want := []string{"top", "mid", "also1"} + if !equalStrings(ids(got), want) { + t.Errorf("order: got %v, want %v", ids(got), want) + } +} + +func TestDiversifyBySessionEmptyKeyNotCapped(t *testing.T) { + // Items with empty key (e.g. 
code docs) pass through without counting. + in := []divItem{ + {"a", ""}, + {"b", ""}, + {"c", ""}, + {"d", "s1"}, + {"e", "s1"}, + {"f", "s1"}, // dropped + } + got := diversifyBySession(in, sessionKey, 2) + want := []string{"a", "b", "c", "d", "e"} + if !equalStrings(ids(got), want) { + t.Errorf("empty key: got %v, want %v", ids(got), want) + } +} + +func TestDiversifyBySessionDisabled(t *testing.T) { + in := []divItem{{"a", "s1"}, {"b", "s1"}, {"c", "s1"}} + if got := diversifyBySession(in, sessionKey, 0); len(got) != 3 { + t.Errorf("cap=0 should passthrough, got %d items", len(got)) + } + if got := diversifyBySession(in, sessionKey, -1); len(got) != 3 { + t.Errorf("cap<0 should passthrough, got %d items", len(got)) + } +} + +func TestDiversifyBySessionEmpty(t *testing.T) { + if got := diversifyBySession[divItem](nil, sessionKey, 2); got != nil { + t.Errorf("nil input: got %v, want nil", got) + } +} + +func equalStrings(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} diff --git a/cmd/solr-mem-server/search_tool.go b/cmd/solr-mem-server/search_tool.go index a18ae80..2efee40 100644 --- a/cmd/solr-mem-server/search_tool.go +++ b/cmd/solr-mem-server/search_tool.go @@ -74,6 +74,16 @@ func searchMemoriesTool(ctx context.Context, args map[string]any) (any, error) { return nil, fmt.Errorf("search failed: %w", err) } + // Cap per session so one chatty session can't dominate results. + // Default 3; pass 0 to disable. 
+ sessionCap := getInt(args, "session_cap", 3) + if sessionCap > 0 { + resp.Docs = diversifyBySession(resp.Docs, func(d map[string]any) string { + s, _ := d["session_id"].(string) + return s + }, sessionCap) + } + return ToolOutput{ Text: formatSearchResults(resp), Structured: resp, diff --git a/cmd/solr-mem-server/tools.go b/cmd/solr-mem-server/tools.go index 179eceb..3249bab 100644 --- a/cmd/solr-mem-server/tools.go +++ b/cmd/solr-mem-server/tools.go @@ -57,7 +57,7 @@ func ToolSchemas() []ToolDefinition { **When to use**: Find relevant memories by content, tags, type, agent, or time range. Uses edismax with field boosting (content^3, title^2, tags^1.5) and recency boost. **Required**: query (search text) -**Optional**: agent_id, memory_type, tags, source, importance_min, from, to, limit, highlight, facet, session_id, lifetime`, +**Optional**: agent_id, memory_type, tags, source, importance_min, from, to, limit, highlight, facet, session_id, lifetime, session_cap`, InputSchema: NewObjectSchema(map[string]any{ "query": prop("string", "Full-text search query (required)"), "agent_id": prop("string", "Filter by agent ID"), @@ -72,6 +72,7 @@ func ToolSchemas() []ToolDefinition { "facet": prop("boolean", "Include facet counts (default: false)"), "session_id": prop("string", "Filter by session ID"), "lifetime": prop("string", "Filter by lifetime (permanent, session, ephemeral, temporary)"), + "session_cap": integerProp("Max hits per session_id after ranking (default 3, 0 = disable)", intPtr(0), intPtr(100)), }, "query"), }, Handler: searchMemoriesTool, From a5a9fb58446118c199f15f0fffc20dfa2a7cca64 Mon Sep 17 00:00:00 2001 From: arreyder Date: Sat, 18 Apr 2026 10:20:05 -0500 Subject: [PATCH 2/8] Scrub secrets from memory content before storage Adds internal/privacy.Scrub with patterns for AWS keys, GitHub tokens, Anthropic/OpenAI keys, Slack tokens, bearer tokens, RSA/EC private-key blocks, URLs with embedded credentials, and <private>/<secret> tag blocks. 
Matches are replaced with [REDACTED:<kind>] markers and the tally is merged into the memory's metadata as scrub_count/scrub_kinds. Wired into store_memory, bulk_store_memories, update_memory, and observe_work. Can be disabled with SOLR_MEM_PRIVACY_SCRUB=off for trusted corpora. Closes #15. Co-Authored-By: Claude Opus 4.7 (1M context) --- cmd/solr-mem-server/broker_tool.go | 15 +- cmd/solr-mem-server/bulk_store_tool.go | 35 +++- cmd/solr-mem-server/scrub.go | 42 +++++ cmd/solr-mem-server/store_tool.go | 38 +++- cmd/solr-mem-server/update_tool.go | 18 +- internal/privacy/scrub.go | 130 +++++++++++++ internal/privacy/scrub_test.go | 249 +++++++++++++++++++++++++ 7 files changed, 503 insertions(+), 24 deletions(-) create mode 100644 cmd/solr-mem-server/scrub.go create mode 100644 internal/privacy/scrub.go create mode 100644 internal/privacy/scrub_test.go diff --git a/cmd/solr-mem-server/broker_tool.go b/cmd/solr-mem-server/broker_tool.go index 4b5de62..0f1ee69 100644 --- a/cmd/solr-mem-server/broker_tool.go +++ b/cmd/solr-mem-server/broker_tool.go @@ -14,17 +14,24 @@ func observeWorkTool(broker *Broker) ToolHandler { return nil, fmt.Errorf("run_id is required") } + // Scrub free-text fields; entities/code_refs/repo/phase are symbolic + // and unlikely to carry secrets. 
+ task, _ := scrubString(getString(args, "task")) + subgoal, _ := scrubString(getString(args, "subgoal")) + uncertainty, _ := scrubString(getString(args, "uncertainty")) + nextAction, _ := scrubString(getString(args, "next_action")) + obs := WorkObservation{ RunID: runID, AgentID: getString(args, "agent_id"), Repo: getString(args, "repo"), Phase: getString(args, "phase"), - Task: getString(args, "task"), - Subgoal: getString(args, "subgoal"), + Task: task, + Subgoal: subgoal, Entities: getStringSlice(args, "entities"), CodeRefs: getStringSlice(args, "code_refs"), - Uncertainty: getString(args, "uncertainty"), - NextAction: getString(args, "next_action"), + Uncertainty: uncertainty, + NextAction: nextAction, } result := broker.Observe(obs) diff --git a/cmd/solr-mem-server/bulk_store_tool.go b/cmd/solr-mem-server/bulk_store_tool.go index df8ebfb..6f64a75 100644 --- a/cmd/solr-mem-server/bulk_store_tool.go +++ b/cmd/solr-mem-server/bulk_store_tool.go @@ -5,6 +5,7 @@ import ( "fmt" "time" + "github.com/arreyder/solr-mem/internal/privacy" "github.com/arreyder/solr-mem/internal/solr" "github.com/google/uuid" ) @@ -26,6 +27,7 @@ func bulkStoreMemoriesTool(ctx context.Context, args map[string]any) (any, error now := time.Now().UTC() var docs []solr.Document var ids []string + totalScrubbed := 0 for i, raw := range memories { m, ok := raw.(map[string]any) @@ -49,16 +51,26 @@ func bulkStoreMemoriesTool(ctx context.Context, args map[string]any) (any, error format = "prose" } + title := getString(m, "title") + metadata := getString(m, "metadata") + scrubbedContent, contentHits := scrubString(content) + scrubbedTitle, titleHits := scrubString(title) + allHits := privacy.MergeHits(contentHits, titleHits) + metadata = privacy.MergeMetadata(metadata, allHits) + for _, v := range allHits { + totalScrubbed += v + } + docs = append(docs, solr.Document{ ID: id, AgentID: getString(m, "agent_id"), MemoryType: getString(m, "memory_type"), - Content: content, - Title: getString(m, 
"title"), + Content: scrubbedContent, + Title: scrubbedTitle, Tags: getStringSlice(m, "tags"), Source: getString(m, "source"), Importance: getFloat(m, "importance", 0.5), - Metadata: getString(m, "metadata"), + Metadata: metadata, CreatedAt: now, UpdatedAt: now, ExpiresAt: expiresAt, @@ -73,11 +85,18 @@ func bulkStoreMemoriesTool(ctx context.Context, args map[string]any) (any, error return nil, fmt.Errorf("failed to bulk store memories: %w", err) } + text := fmt.Sprintf("Successfully stored %d memories.\nIDs: %v", len(docs), ids) + structured := map[string]any{ + "count": len(docs), + "ids": ids, + } + if totalScrubbed > 0 { + text += fmt.Sprintf("\nPrivacy: redacted %d secret(s) across all memories", totalScrubbed) + structured["scrub_count"] = totalScrubbed + } + return ToolOutput{ - Text: fmt.Sprintf("Successfully stored %d memories.\nIDs: %v", len(docs), ids), - Structured: map[string]any{ - "count": len(docs), - "ids": ids, - }, + Text: text, + Structured: structured, }, nil } diff --git a/cmd/solr-mem-server/scrub.go b/cmd/solr-mem-server/scrub.go new file mode 100644 index 0000000..6d0695e --- /dev/null +++ b/cmd/solr-mem-server/scrub.go @@ -0,0 +1,42 @@ +package main + +import ( + "log" + "os" + "strings" + "sync" + + "github.com/arreyder/solr-mem/internal/privacy" +) + +// privacyScrubDisabled is resolved once from the environment. +var ( + privacyScrubOnce sync.Once + privacyScrubOff bool +) + +// privacyScrubEnabled reports whether secret scrubbing is on. Controlled by +// SOLR_MEM_PRIVACY_SCRUB=off (default: on). 
+func privacyScrubEnabled() bool { + privacyScrubOnce.Do(func() { + v := strings.ToLower(strings.TrimSpace(os.Getenv("SOLR_MEM_PRIVACY_SCRUB"))) + if v == "off" || v == "false" || v == "0" || v == "disabled" { + privacyScrubOff = true + log.Println("privacy: scrub disabled by SOLR_MEM_PRIVACY_SCRUB") + } + }) + return !privacyScrubOff +} + +// scrubString returns a scrubbed copy of s and the hit map; if scrubbing is +// disabled or s is empty, the hit map is nil. +func scrubString(s string) (string, map[string]int) { + if !privacyScrubEnabled() || s == "" { + return s, nil + } + r := privacy.Scrub(s) + if r.Count() == 0 { + return r.Content, nil + } + return r.Content, r.Hits +} diff --git a/cmd/solr-mem-server/store_tool.go b/cmd/solr-mem-server/store_tool.go index 88ddf47..c17459c 100644 --- a/cmd/solr-mem-server/store_tool.go +++ b/cmd/solr-mem-server/store_tool.go @@ -5,6 +5,7 @@ import ( "fmt" "time" + "github.com/arreyder/solr-mem/internal/privacy" "github.com/arreyder/solr-mem/internal/solr" "github.com/google/uuid" ) @@ -24,16 +25,23 @@ func storeMemoryTool(ctx context.Context, args map[string]any) (any, error) { format = "prose" } + title := getString(args, "title") + metadata := getString(args, "metadata") + scrubbedContent, contentHits := scrubString(content) + scrubbedTitle, titleHits := scrubString(title) + allHits := privacy.MergeHits(contentHits, titleHits) + metadata = privacy.MergeMetadata(metadata, allHits) + doc := solr.Document{ ID: uuid.New().String(), AgentID: getString(args, "agent_id"), MemoryType: getString(args, "memory_type"), - Content: content, - Title: getString(args, "title"), + Content: scrubbedContent, + Title: scrubbedTitle, Tags: getStringSlice(args, "tags"), Source: getString(args, "source"), Importance: getFloat(args, "importance", 0.5), - Metadata: getString(args, "metadata"), + Metadata: metadata, CreatedAt: now, UpdatedAt: now, ExpiresAt: expiresAt, @@ -53,13 +61,23 @@ func storeMemoryTool(ctx context.Context, args 
map[string]any) (any, error) { result += fmt.Sprintf("\nExpires: %s", expiresAt) } + structured := map[string]any{ + "id": doc.ID, + "lifetime": doc.Lifetime, + "expires_at": expiresAt, + "created_at": doc.CreatedAt.Format(time.RFC3339), + } + if n := len(allHits); n > 0 { + total := 0 + for _, v := range allHits { + total += v + } + structured["scrub_count"] = total + result += fmt.Sprintf("\nPrivacy: redacted %d secret(s) across %d kind(s)", total, n) + } + return ToolOutput{ - Text: result, - Structured: map[string]any{ - "id": doc.ID, - "lifetime": doc.Lifetime, - "expires_at": expiresAt, - "created_at": doc.CreatedAt.Format(time.RFC3339), - }, + Text: result, + Structured: structured, }, nil } diff --git a/cmd/solr-mem-server/update_tool.go b/cmd/solr-mem-server/update_tool.go index a8e5b24..26d25af 100644 --- a/cmd/solr-mem-server/update_tool.go +++ b/cmd/solr-mem-server/update_tool.go @@ -4,6 +4,8 @@ import ( "context" "fmt" "time" + + "github.com/arreyder/solr-mem/internal/privacy" ) func updateMemoryTool(ctx context.Context, args map[string]any) (any, error) { @@ -13,12 +15,17 @@ func updateMemoryTool(ctx context.Context, args map[string]any) (any, error) { } fields := make(map[string]any) + scrubHits := map[string]int{} if v := getString(args, "content"); v != "" { - fields["content"] = v + scrubbed, hits := scrubString(v) + fields["content"] = scrubbed + scrubHits = privacy.MergeHits(scrubHits, hits) } if v := getString(args, "title"); v != "" { - fields["title"] = v + scrubbed, hits := scrubString(v) + fields["title"] = scrubbed + scrubHits = privacy.MergeHits(scrubHits, hits) } if v := getString(args, "memory_type"); v != "" { fields["memory_type"] = v @@ -35,6 +42,13 @@ func updateMemoryTool(ctx context.Context, args map[string]any) (any, error) { if v := getString(args, "metadata"); v != "" { fields["metadata"] = v } + // If secrets were scrubbed, merge the tally into metadata (existing + // metadata arg takes precedence over the stored doc's 
metadata, which + // matches atomic-update semantics). + if len(scrubHits) > 0 { + existing, _ := fields["metadata"].(string) + fields["metadata"] = privacy.MergeMetadata(existing, scrubHits) + } if v := getString(args, "session_id"); v != "" { fields["session_id"] = v } diff --git a/internal/privacy/scrub.go b/internal/privacy/scrub.go new file mode 100644 index 0000000..3fb38f9 --- /dev/null +++ b/internal/privacy/scrub.go @@ -0,0 +1,130 @@ +// Package privacy scrubs secrets from strings before they are stored as memory +// content. It is conservative: a missed redaction is preferred over mangling +// legitimate content, so patterns are anchored and stand-alone. +package privacy + +import ( + "encoding/json" + "regexp" + "sort" +) + +// Result is the outcome of a Scrub call. +type Result struct { + // Content is the input string with any matched secrets replaced. + Content string + // Hits maps pattern name -> number of matches redacted. + Hits map[string]int +} + +// Count returns the total number of redactions applied. +func (r Result) Count() int { + n := 0 + for _, v := range r.Hits { + n += v + } + return n +} + +// Kinds returns a sorted list of the pattern names that matched. +func (r Result) Kinds() []string { + out := make([]string, 0, len(r.Hits)) + for k := range r.Hits { + out = append(out, k) + } + sort.Strings(out) + return out +} + +type pattern struct { + name string + re *regexp.Regexp + // replace overrides the default "[REDACTED:]" template. + // Use Go regexp replacement syntax ($1, $2, ...) if needed. + replace string +} + +// Patterns are applied in order. Multi-line / block patterns run first so +// they subsume any single-line keys that would otherwise match inside the +// block. The sk-ant- pattern runs before the generic sk- OpenAI pattern so +// Anthropic keys aren't mislabeled. 
+var patterns = []pattern{ + {name: "private_key_block", re: regexp.MustCompile(`(?s)-----BEGIN (?:RSA |EC |DSA |OPENSSH |PGP )?PRIVATE KEY(?: BLOCK)?-----.*?-----END[^-]*-----`)}, + {name: "private_tag", re: regexp.MustCompile(`(?is)<private>.*?</private>`)}, + {name: "secret_tag", re: regexp.MustCompile(`(?is)<secret>.*?</secret>`)}, + {name: "anthropic_key", re: regexp.MustCompile(`sk-ant-[A-Za-z0-9_\-]{20,}`)}, + {name: "openai_key", re: regexp.MustCompile(`\bsk-[A-Za-z0-9]{48}\b`)}, + {name: "github_pat", re: regexp.MustCompile(`\bgithub_pat_[A-Za-z0-9_]{82}\b`)}, + {name: "github_token", re: regexp.MustCompile(`\bghp_[A-Za-z0-9]{36}\b`)}, + {name: "aws_access_key", re: regexp.MustCompile(`\bAKIA[0-9A-Z]{16}\b`)}, + {name: "aws_secret_key", re: regexp.MustCompile(`(?i)aws_secret_access_key\s*[=:]\s*[A-Za-z0-9/+=]{40}`)}, + {name: "slack_token", re: regexp.MustCompile(`\bxox[baprs]-[A-Za-z0-9\-]+`)}, + {name: "bearer_token", re: regexp.MustCompile(`(?i)bearer\s+[A-Za-z0-9_\-\.=]{20,}`)}, + {name: "url_creds", re: regexp.MustCompile(`(https?://)[^:/@\s]+:[^@\s]+@`), replace: "${1}[REDACTED:url_creds]@"}, +} + +// Scrub replaces recognized secrets in s with [REDACTED:<name>] and returns +// the scrubbed content plus a tally of what was hit. +func Scrub(s string) Result { + res := Result{Content: s, Hits: map[string]int{}} + if s == "" { + return res + } + for _, p := range patterns { + matches := p.re.FindAllStringIndex(res.Content, -1) + if len(matches) == 0 { + continue + } + res.Hits[p.name] = len(matches) + replace := p.replace + if replace == "" { + replace = "[REDACTED:" + p.name + "]" + } + res.Content = p.re.ReplaceAllString(res.Content, replace) + } + return res +} + +// MergeHits combines two hit maps into a new one. 
+func MergeHits(a, b map[string]int) map[string]int { + out := make(map[string]int, len(a)+len(b)) + for k, v := range a { + out[k] += v + } + for k, v := range b { + out[k] += v + } + return out +} + +// MergeMetadata augments an existing metadata JSON string with scrub_count and +// scrub_kinds fields. If existing is empty or invalid JSON, a fresh object is +// produced. If the result has no hits, existing is returned unchanged. +func MergeMetadata(existing string, hits map[string]int) string { + total := 0 + kinds := make([]string, 0, len(hits)) + for k, v := range hits { + total += v + kinds = append(kinds, k) + } + if total == 0 { + return existing + } + sort.Strings(kinds) + + var obj map[string]any + if existing != "" { + _ = json.Unmarshal([]byte(existing), &obj) + } + if obj == nil { + obj = map[string]any{} + } + obj["scrub_count"] = total + obj["scrub_kinds"] = kinds + + b, err := json.Marshal(obj) + if err != nil { + return existing + } + return string(b) +} diff --git a/internal/privacy/scrub_test.go b/internal/privacy/scrub_test.go new file mode 100644 index 0000000..8248cf5 --- /dev/null +++ b/internal/privacy/scrub_test.go @@ -0,0 +1,249 @@ +package privacy + +import ( + "encoding/json" + "strings" + "testing" +) + +func TestScrubAnthropicKey(t *testing.T) { + // Anthropic keys start with sk-ant-. + in := "My key is sk-ant-api03-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA done" + r := Scrub(in) + if strings.Contains(r.Content, "sk-ant-") { + t.Errorf("expected sk-ant- to be redacted, got: %s", r.Content) + } + if !strings.Contains(r.Content, "[REDACTED:anthropic_key]") { + t.Errorf("expected REDACTED marker, got: %s", r.Content) + } + if r.Hits["anthropic_key"] != 1 { + t.Errorf("expected 1 anthropic_key hit, got %d", r.Hits["anthropic_key"]) + } +} + +func TestScrubAnthropicBeforeOpenAI(t *testing.T) { + // sk-ant- keys must not be labeled as OpenAI keys. 
+ in := "sk-ant-api03-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" + r := Scrub(in) + if r.Hits["openai_key"] != 0 { + t.Errorf("anthropic key should not match openai pattern, got %+v", r.Hits) + } + if r.Hits["anthropic_key"] != 1 { + t.Errorf("expected anthropic match, got %+v", r.Hits) + } +} + +func TestScrubOpenAIKey(t *testing.T) { + // OpenAI keys are sk- followed by exactly 48 alphanumeric chars. + in := "use sk-" + strings.Repeat("a", 48) + " done" + r := Scrub(in) + if !strings.Contains(r.Content, "[REDACTED:openai_key]") { + t.Errorf("expected openai redaction, got: %s", r.Content) + } +} + +func TestScrubGithubToken(t *testing.T) { + in := "token=ghp_1234567890abcdefghijKLMNOPQRSTUVwxyz extra" + r := Scrub(in) + if r.Hits["github_token"] != 1 { + t.Errorf("expected github_token hit, got %+v", r.Hits) + } + if !strings.Contains(r.Content, "[REDACTED:github_token]") { + t.Errorf("expected redaction, got: %s", r.Content) + } +} + +func TestScrubGithubPAT(t *testing.T) { + // github_pat_ has exactly 82 chars of [A-Za-z0-9_] after the prefix. 
+ suffix := strings.Repeat("a", 82) + in := "token: github_pat_" + suffix + r := Scrub(in) + if r.Hits["github_pat"] != 1 { + t.Errorf("expected github_pat hit, got %+v", r.Hits) + } +} + +func TestScrubAWSAccessKey(t *testing.T) { + in := "AWS_KEY=AKIAIOSFODNN7EXAMPLE here" + r := Scrub(in) + if r.Hits["aws_access_key"] != 1 { + t.Errorf("expected aws_access_key hit, got %+v", r.Hits) + } +} + +func TestScrubAWSSecretKey(t *testing.T) { + in := "aws_secret_access_key=wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY" + r := Scrub(in) + if r.Hits["aws_secret_key"] != 1 { + t.Errorf("expected aws_secret_key hit, got %+v", r.Hits) + } +} + +func TestScrubSlackToken(t *testing.T) { + in := "slack=xoxb-12345-67890-abcdef leaks" + r := Scrub(in) + if r.Hits["slack_token"] != 1 { + t.Errorf("expected slack_token hit, got %+v", r.Hits) + } +} + +func TestScrubBearerToken(t *testing.T) { + in := "Authorization: Bearer eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.abc" + r := Scrub(in) + if r.Hits["bearer_token"] != 1 { + t.Errorf("expected bearer_token hit, got %+v", r.Hits) + } + if strings.Contains(r.Content, "eyJ") { + t.Errorf("JWT should be redacted, got: %s", r.Content) + } +} + +func TestScrubPrivateKeyBlock(t *testing.T) { + in := "note:\n-----BEGIN RSA PRIVATE KEY-----\nMIIEowIBAAKCAQEAvZ\nnextline\n-----END RSA PRIVATE KEY-----\nend" + r := Scrub(in) + if r.Hits["private_key_block"] != 1 { + t.Errorf("expected private_key_block hit, got %+v", r.Hits) + } + if strings.Contains(r.Content, "MIIEowIBAAK") { + t.Errorf("key body should be gone, got: %s", r.Content) + } +} + +func TestScrubURLCreds(t *testing.T) { + in := "connect https://admin:s3cr3t@db.example.com:5432/foo" + r := Scrub(in) + if r.Hits["url_creds"] != 1 { + t.Errorf("expected url_creds hit, got %+v", r.Hits) + } + if !strings.Contains(r.Content, "https://[REDACTED:url_creds]@db.example.com") { + t.Errorf("expected host preserved with redacted creds, got: %s", r.Content) + } +} + +func TestScrubPrivateTag(t 
*testing.T) { + in := "before <private>hidden stuff</private> after" + r := Scrub(in) + if r.Hits["private_tag"] != 1 { + t.Errorf("expected private_tag hit, got %+v", r.Hits) + } + if strings.Contains(r.Content, "hidden stuff") { + t.Errorf("private content should be gone, got: %s", r.Content) + } +} + +func TestScrubSecretTag(t *testing.T) { + in := "foo <secret>hidden</secret> bar" + r := Scrub(in) + if r.Hits["secret_tag"] != 1 { + t.Errorf("expected secret_tag hit, got %+v", r.Hits) + } +} + +func TestScrubEmpty(t *testing.T) { + r := Scrub("") + if r.Count() != 0 { + t.Errorf("empty input should have no hits") + } +} + +func TestScrubNoMatches(t *testing.T) { + in := "this is plain text with no secrets" + r := Scrub(in) + if r.Content != in { + t.Errorf("no matches should leave content unchanged") + } + if r.Count() != 0 { + t.Errorf("no matches should have zero count") + } +} + +func TestScrubMultipleKinds(t *testing.T) { + in := "AKIAIOSFODNN7EXAMPLE and ghp_1234567890abcdefghijKLMNOPQRSTUVwxyz" + r := Scrub(in) + if r.Hits["aws_access_key"] != 1 || r.Hits["github_token"] != 1 { + t.Errorf("expected both hits, got %+v", r.Hits) + } + if r.Count() != 2 { + t.Errorf("expected total count 2, got %d", r.Count()) + } + if got := r.Kinds(); len(got) != 2 || got[0] != "aws_access_key" || got[1] != "github_token" { + t.Errorf("expected sorted kinds, got %v", got) + } +} + +func TestScrubIdempotent(t *testing.T) { + // Scrubbing already-scrubbed content should be a no-op. 
+ in := "sk-ant-api03-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" + first := Scrub(in) + second := Scrub(first.Content) + if second.Count() != 0 { + t.Errorf("re-scrubbing should produce no hits, got %+v", second.Hits) + } + if second.Content != first.Content { + t.Errorf("re-scrubbing changed content: %q -> %q", first.Content, second.Content) + } +} + +func TestMergeMetadataFresh(t *testing.T) { + hits := map[string]int{"anthropic_key": 2, "github_token": 1} + got := MergeMetadata("", hits) + var obj map[string]any + if err := json.Unmarshal([]byte(got), &obj); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + if c, _ := obj["scrub_count"].(float64); int(c) != 3 { + t.Errorf("expected scrub_count=3, got %v", obj["scrub_count"]) + } + kinds, _ := obj["scrub_kinds"].([]any) + if len(kinds) != 2 { + t.Errorf("expected 2 kinds, got %v", kinds) + } +} + +func TestMergeMetadataExisting(t *testing.T) { + existing := `{"source":"chat","importance":0.7}` + hits := map[string]int{"slack_token": 1} + got := MergeMetadata(existing, hits) + var obj map[string]any + if err := json.Unmarshal([]byte(got), &obj); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + if obj["source"] != "chat" { + t.Errorf("existing keys should be preserved, got %v", obj) + } + if c, _ := obj["scrub_count"].(float64); int(c) != 1 { + t.Errorf("expected scrub_count=1, got %v", obj["scrub_count"]) + } +} + +func TestMergeMetadataNoHits(t *testing.T) { + existing := `{"foo":"bar"}` + got := MergeMetadata(existing, map[string]int{}) + if got != existing { + t.Errorf("no hits should leave metadata untouched, got %q", got) + } +} + +func TestMergeMetadataInvalidExisting(t *testing.T) { + // Garbage existing metadata should be replaced by a fresh object. 
+ got := MergeMetadata("not-json", map[string]int{"github_token": 1}) + var obj map[string]any + if err := json.Unmarshal([]byte(got), &obj); err != nil { + t.Fatalf("expected valid JSON: %v", err) + } + if _, ok := obj["scrub_count"]; !ok { + t.Errorf("expected scrub_count key, got %v", obj) + } +} + +func TestMergeHits(t *testing.T) { + a := map[string]int{"anthropic_key": 1, "github_token": 1} + b := map[string]int{"anthropic_key": 2, "slack_token": 1} + got := MergeHits(a, b) + if got["anthropic_key"] != 3 { + t.Errorf("expected sum for overlapping key, got %d", got["anthropic_key"]) + } + if got["github_token"] != 1 || got["slack_token"] != 1 { + t.Errorf("expected non-overlapping keys kept, got %+v", got) + } +} From 5d38b3fe2ce253906a829549997a0d9ce62990ec Mon Sep 17 00:00:00 2001 From: arreyder Date: Sat, 18 Apr 2026 10:30:55 -0500 Subject: [PATCH 3/8] Dedup memories by content-hash within a time window MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a stable SHA-256 hash over normalized title+content+tags (whitespace-collapsed, tag-order-agnostic, case-normalized), stored in a new content_hash field on the memories collection. On store_memory and bulk_store_memories, a Solr lookup in the last N seconds (default 300) with the same hash causes the insert to be skipped and the existing ID returned. on_duplicate=merge bumps updated_at on the existing doc; on_duplicate=force bypasses the check. Existing memories without a content_hash are never matched, so no backfill is required — they simply can't be deduped against until touched again. Closes #14. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- cmd/solr-mem-server/bulk_store_tool.go | 83 +++++++++++++++++++------- cmd/solr-mem-server/dedup.go | 51 ++++++++++++++++ cmd/solr-mem-server/dedup_test.go | 32 ++++++++++ cmd/solr-mem-server/store_tool.go | 71 +++++++++++++++++----- cmd/solr-mem-server/tools.go | 39 +++++++----- internal/contenthash/hash.go | 69 +++++++++++++++++++++ internal/contenthash/hash_test.go | 72 ++++++++++++++++++++++ internal/solr/types.go | 7 ++- solr/managed-schema.xml | 1 + 9 files changed, 368 insertions(+), 57 deletions(-) create mode 100644 cmd/solr-mem-server/dedup.go create mode 100644 cmd/solr-mem-server/dedup_test.go create mode 100644 internal/contenthash/hash.go create mode 100644 internal/contenthash/hash_test.go diff --git a/cmd/solr-mem-server/bulk_store_tool.go b/cmd/solr-mem-server/bulk_store_tool.go index 6f64a75..18d7790 100644 --- a/cmd/solr-mem-server/bulk_store_tool.go +++ b/cmd/solr-mem-server/bulk_store_tool.go @@ -3,8 +3,10 @@ package main import ( "context" "fmt" + "log" "time" + "github.com/arreyder/solr-mem/internal/contenthash" "github.com/arreyder/solr-mem/internal/privacy" "github.com/arreyder/solr-mem/internal/solr" "github.com/google/uuid" @@ -27,8 +29,12 @@ func bulkStoreMemoriesTool(ctx context.Context, args map[string]any) (any, error now := time.Now().UTC() var docs []solr.Document var ids []string + var duplicates []map[string]any totalScrubbed := 0 + windowSec := getInt(args, "dedup_window_seconds", defaultDedupWindowSec) + onDup := normalizeOnDuplicate(getString(args, "on_duplicate")) + for i, raw := range memories { m, ok := raw.(map[string]any) if !ok { @@ -43,9 +49,6 @@ func bulkStoreMemoriesTool(ctx context.Context, args map[string]any) (any, error lifetime := normalizeLifetime(getString(m, "lifetime")) expiresAt := resolveExpiration(lifetime, getString(m, "expires_at")) - id := uuid.New().String() - ids = append(ids, id) - format := getString(m, "format") if format == "" { format = "prose" @@ 
-61,35 +64,71 @@ func bulkStoreMemoriesTool(ctx context.Context, args map[string]any) (any, error totalScrubbed += v } + tags := getStringSlice(m, "tags") + hash := contenthash.Compute(scrubbedTitle, scrubbedContent, tags) + + if hash != "" && windowSec > 0 && onDup != "force" { + existingID, err := findRecentByHash(ctx, solrClient, hash, windowSec) + if err != nil { + log.Printf("bulk_store: dedup lookup failed at index %d (continuing as new): %v", i, err) + } else if existingID != "" { + action := "skipped" + if onDup == "merge" { + upd := map[string]any{"updated_at": now.Format(time.RFC3339)} + if err := solrClient.Update(ctx, existingID, upd); err != nil { + log.Printf("bulk_store: merge update failed at index %d: %v", i, err) + } else { + action = "merged" + } + } + duplicates = append(duplicates, map[string]any{ + "index": i, + "id": existingID, + "action": action, + }) + continue + } + } + + id := uuid.New().String() + ids = append(ids, id) + docs = append(docs, solr.Document{ - ID: id, - AgentID: getString(m, "agent_id"), - MemoryType: getString(m, "memory_type"), - Content: scrubbedContent, - Title: scrubbedTitle, - Tags: getStringSlice(m, "tags"), - Source: getString(m, "source"), - Importance: getFloat(m, "importance", 0.5), - Metadata: metadata, - CreatedAt: now, - UpdatedAt: now, - ExpiresAt: expiresAt, - Lifetime: lifetime, - SessionID: getString(m, "session_id"), - RelatedIDs: getStringSlice(m, "related_ids"), - Format: format, + ID: id, + AgentID: getString(m, "agent_id"), + MemoryType: getString(m, "memory_type"), + Content: scrubbedContent, + Title: scrubbedTitle, + Tags: tags, + Source: getString(m, "source"), + Importance: getFloat(m, "importance", 0.5), + Metadata: metadata, + CreatedAt: now, + UpdatedAt: now, + ExpiresAt: expiresAt, + Lifetime: lifetime, + SessionID: getString(m, "session_id"), + RelatedIDs: getStringSlice(m, "related_ids"), + Format: format, + ContentHash: hash, }) } - if err := solrClient.Add(ctx, docs...); err != nil { - 
return nil, fmt.Errorf("failed to bulk store memories: %w", err) + if len(docs) > 0 { + if err := solrClient.Add(ctx, docs...); err != nil { + return nil, fmt.Errorf("failed to bulk store memories: %w", err) + } } - text := fmt.Sprintf("Successfully stored %d memories.\nIDs: %v", len(docs), ids) + text := fmt.Sprintf("Stored %d memories.\nIDs: %v", len(docs), ids) structured := map[string]any{ "count": len(docs), "ids": ids, } + if len(duplicates) > 0 { + text += fmt.Sprintf("\nDedup: %d duplicates within %ds window", len(duplicates), windowSec) + structured["duplicates"] = duplicates + } if totalScrubbed > 0 { text += fmt.Sprintf("\nPrivacy: redacted %d secret(s) across all memories", totalScrubbed) structured["scrub_count"] = totalScrubbed diff --git a/cmd/solr-mem-server/dedup.go b/cmd/solr-mem-server/dedup.go new file mode 100644 index 0000000..6ec9d1a --- /dev/null +++ b/cmd/solr-mem-server/dedup.go @@ -0,0 +1,51 @@ +package main + +import ( + "context" + "fmt" + + "github.com/arreyder/solr-mem/internal/solr" +) + +// defaultDedupWindowSec is the default window for content-hash dedup on writes. +// Callers can override per-call with dedup_window_seconds; 0 disables. +const defaultDedupWindowSec = 300 + +// findRecentByHash returns the ID of the most recent memory with content_hash +// equal to hash and created_at within the last windowSec seconds. Empty string +// means no match (or the lookup errored — callers treat it as "not a dup"). 
+func findRecentByHash(ctx context.Context, client *solr.Client, hash string, windowSec int) (string, error) { + if hash == "" || windowSec <= 0 || client == nil { + return "", nil + } + params := solr.QueryParams{ + Query: "*:*", + FilterQueries: []string{ + fmt.Sprintf("content_hash:%s", hash), + fmt.Sprintf("created_at:[NOW-%dSECONDS TO *]", windowSec), + }, + Fields: []string{"id"}, + Sort: "created_at desc", + Rows: 1, + } + resp, err := client.Query(ctx, params) + if err != nil { + return "", err + } + if len(resp.Docs) == 0 { + return "", nil + } + id, _ := resp.Docs[0]["id"].(string) + return id, nil +} + +// normalizeOnDuplicate canonicalizes the on_duplicate argument. +// "" / unrecognized → "skip". +func normalizeOnDuplicate(s string) string { + switch s { + case "skip", "merge", "force": + return s + default: + return "skip" + } +} diff --git a/cmd/solr-mem-server/dedup_test.go b/cmd/solr-mem-server/dedup_test.go new file mode 100644 index 0000000..5b61f19 --- /dev/null +++ b/cmd/solr-mem-server/dedup_test.go @@ -0,0 +1,32 @@ +package main + +import ( + "context" + "testing" +) + +func TestNormalizeOnDuplicate(t *testing.T) { + cases := []struct{ in, want string }{ + {"", "skip"}, + {"skip", "skip"}, + {"merge", "merge"}, + {"force", "force"}, + {"garbage", "skip"}, + {"SKIP", "skip"}, // case-sensitive — only lowercase forms are accepted + } + for _, c := range cases { + if got := normalizeOnDuplicate(c.in); got != c.want { + t.Errorf("normalizeOnDuplicate(%q) = %q, want %q", c.in, got, c.want) + } + } +} + +func TestFindRecentByHashNoopPaths(t *testing.T) { + // Empty hash, zero window, or nil client should all short-circuit without error. 
+ if id, err := findRecentByHash(context.Background(), nil, "", 0); id != "" || err != nil { + t.Errorf("nil inputs should noop, got id=%q err=%v", id, err) + } + if id, err := findRecentByHash(context.Background(), nil, "abc", 300); id != "" || err != nil { + t.Errorf("nil client should noop, got id=%q err=%v", id, err) + } +} diff --git a/cmd/solr-mem-server/store_tool.go b/cmd/solr-mem-server/store_tool.go index c17459c..9956823 100644 --- a/cmd/solr-mem-server/store_tool.go +++ b/cmd/solr-mem-server/store_tool.go @@ -3,8 +3,10 @@ package main import ( "context" "fmt" + "log" "time" + "github.com/arreyder/solr-mem/internal/contenthash" "github.com/arreyder/solr-mem/internal/privacy" "github.com/arreyder/solr-mem/internal/solr" "github.com/google/uuid" @@ -32,23 +34,60 @@ func storeMemoryTool(ctx context.Context, args map[string]any) (any, error) { allHits := privacy.MergeHits(contentHits, titleHits) metadata = privacy.MergeMetadata(metadata, allHits) + tags := getStringSlice(args, "tags") + hash := contenthash.Compute(scrubbedTitle, scrubbedContent, tags) + windowSec := getInt(args, "dedup_window_seconds", defaultDedupWindowSec) + onDup := normalizeOnDuplicate(getString(args, "on_duplicate")) + + if hash != "" && windowSec > 0 && onDup != "force" { + existingID, err := findRecentByHash(ctx, solrClient, hash, windowSec) + if err != nil { + log.Printf("store_memory: dedup lookup failed (continuing as new): %v", err) + } else if existingID != "" { + if onDup == "merge" { + upd := map[string]any{"updated_at": now.Format(time.RFC3339)} + if err := solrClient.Update(ctx, existingID, upd); err != nil { + log.Printf("store_memory: merge update failed: %v", err) + } + return ToolOutput{ + Text: fmt.Sprintf("Memory already stored within the last %ds. 
Merged into existing.\nID: %s", windowSec, existingID), + Structured: map[string]any{ + "id": existingID, + "duplicate": true, + "action": "merged", + }, + }, nil + } + // skip (default) + return ToolOutput{ + Text: fmt.Sprintf("Memory already stored within the last %ds. Skipped.\nID: %s", windowSec, existingID), + Structured: map[string]any{ + "id": existingID, + "duplicate": true, + "action": "skipped", + }, + }, nil + } + } + doc := solr.Document{ - ID: uuid.New().String(), - AgentID: getString(args, "agent_id"), - MemoryType: getString(args, "memory_type"), - Content: scrubbedContent, - Title: scrubbedTitle, - Tags: getStringSlice(args, "tags"), - Source: getString(args, "source"), - Importance: getFloat(args, "importance", 0.5), - Metadata: metadata, - CreatedAt: now, - UpdatedAt: now, - ExpiresAt: expiresAt, - Lifetime: lifetime, - SessionID: getString(args, "session_id"), - RelatedIDs: getStringSlice(args, "related_ids"), - Format: format, + ID: uuid.New().String(), + AgentID: getString(args, "agent_id"), + MemoryType: getString(args, "memory_type"), + Content: scrubbedContent, + Title: scrubbedTitle, + Tags: tags, + Source: getString(args, "source"), + Importance: getFloat(args, "importance", 0.5), + Metadata: metadata, + CreatedAt: now, + UpdatedAt: now, + ExpiresAt: expiresAt, + Lifetime: lifetime, + SessionID: getString(args, "session_id"), + RelatedIDs: getStringSlice(args, "related_ids"), + Format: format, + ContentHash: hash, } if err := solrClient.Add(ctx, doc); err != nil { diff --git a/cmd/solr-mem-server/tools.go b/cmd/solr-mem-server/tools.go index 3249bab..bd57b31 100644 --- a/cmd/solr-mem-server/tools.go +++ b/cmd/solr-mem-server/tools.go @@ -26,25 +26,29 @@ func ToolSchemas() []ToolDefinition { **When to use**: Save observations, reflections, facts, conversations, tasks, or decisions for later retrieval. 
**Required**: content (the memory text) -**Optional**: agent_id, memory_type, title, tags, source, importance, metadata, lifetime, session_id, related_ids, expires_at +**Optional**: agent_id, memory_type, title, tags, source, importance, metadata, lifetime, session_id, related_ids, expires_at, format, dedup_window_seconds, on_duplicate **Lifetime values**: permanent (default, never expires), session (cleaned up with session), ephemeral (1 hour TTL), temporary (7 day TTL) +**Dedup**: By default an identical memory (same title + content + tags) stored within 300 seconds is skipped and the existing ID is returned. Pass dedup_window_seconds=0 to disable, on_duplicate="merge" to bump updated_at on the existing doc, or on_duplicate="force" to always insert. + **Content format**: Use compact structured formats (YAML, key-value, tables) over prose. Put searchable summary in title, machine-readable data in metadata (JSON), and categorization in tags. See server instructions for examples.`, InputSchema: NewObjectSchema(map[string]any{ - "content": prop("string", "The main text content of the memory (required)"), - "agent_id": prop("string", "ID of the agent storing this memory"), - "memory_type": prop("string", "Type: observation, reflection, fact, conversation, task, decision"), - "title": prop("string", "Short title or summary of the memory"), - "tags": arrayPropSchema(prop("string", "Tag"), "Categorization tags"), - "source": prop("string", "Where this memory came from (e.g., conversation, tool, file)"), - "importance": numberProp("Importance score from 0.0 to 1.0", floatPtr(0), floatPtr(1)), - "metadata": prop("string", "JSON string of arbitrary metadata"), - "lifetime": prop("string", "Memory lifetime: permanent (default), session, ephemeral (1h), temporary (7d)"), - "session_id": prop("string", "Session/conversation ID to group memories"), - "related_ids": arrayPropSchema(prop("string", "ID"), "IDs of related memories"), - "expires_at": prop("string", "Explicit 
expiration date (ISO 8601). Overrides lifetime."), - "format": prop("string", "Content format: yaml, markdown, json, table, prose (default: prose). Helps agents choose parsing strategy."), + "content": prop("string", "The main text content of the memory (required)"), + "agent_id": prop("string", "ID of the agent storing this memory"), + "memory_type": prop("string", "Type: observation, reflection, fact, conversation, task, decision"), + "title": prop("string", "Short title or summary of the memory"), + "tags": arrayPropSchema(prop("string", "Tag"), "Categorization tags"), + "source": prop("string", "Where this memory came from (e.g., conversation, tool, file)"), + "importance": numberProp("Importance score from 0.0 to 1.0", floatPtr(0), floatPtr(1)), + "metadata": prop("string", "JSON string of arbitrary metadata"), + "lifetime": prop("string", "Memory lifetime: permanent (default), session, ephemeral (1h), temporary (7d)"), + "session_id": prop("string", "Session/conversation ID to group memories"), + "related_ids": arrayPropSchema(prop("string", "ID"), "IDs of related memories"), + "expires_at": prop("string", "Explicit expiration date (ISO 8601). Overrides lifetime."), + "format": prop("string", "Content format: yaml, markdown, json, table, prose (default: prose). Helps agents choose parsing strategy."), + "dedup_window_seconds": integerProp("Skip-if-duplicate window in seconds (default 300, 0 disables)", intPtr(0), intPtr(86400)), + "on_duplicate": prop("string", "Action when a duplicate is found: skip (default), merge (bump updated_at), force (always insert)"), }, "content"), }, Handler: storeMemoryTool, @@ -174,9 +178,10 @@ func ToolSchemas() []ToolDefinition { Name: "bulk_store_memories", Description: `Store multiple memories in a single batch operation. -**When to use**: Efficiently store many memories at once (e.g., importing notes, bulk archival). Each memory in the array uses the same fields as store_memory. 
+**When to use**: Efficiently store many memories at once (e.g., importing notes, bulk archival). Each memory in the array uses the same fields as store_memory. Dedup parameters apply to the batch as a whole. -**Required**: memories (array of memory objects, each with at least a "content" field)`, +**Required**: memories (array of memory objects, each with at least a "content" field) +**Optional**: dedup_window_seconds, on_duplicate`, InputSchema: NewObjectSchema(map[string]any{ "memories": arrayPropSchema( NewObjectSchema(map[string]any{ @@ -196,6 +201,8 @@ func ToolSchemas() []ToolDefinition { }, "content"), "Array of memory objects to store", ), + "dedup_window_seconds": integerProp("Skip-if-duplicate window in seconds (default 300, 0 disables)", intPtr(0), intPtr(86400)), + "on_duplicate": prop("string", "Action when a duplicate is found: skip (default), merge (bump updated_at), force (always insert)"), }, "memories"), }, Handler: bulkStoreMemoriesTool, diff --git a/internal/contenthash/hash.go b/internal/contenthash/hash.go new file mode 100644 index 0000000..90c0546 --- /dev/null +++ b/internal/contenthash/hash.go @@ -0,0 +1,69 @@ +// Package contenthash produces stable SHA-256 hashes over memory content for +// dedup purposes. Normalization is deliberate: the same logical memory sent +// with different whitespace or tag ordering should collide. +package contenthash + +import ( + "crypto/sha256" + "encoding/hex" + "sort" + "strings" +) + +// Compute returns a hex-encoded SHA-256 over normalized inputs. +// +// Normalization: +// - title and content are trimmed and collapsed (runs of whitespace → single space) +// - tags are sorted and each tag is lowercased + trimmed +// - fields are joined with 0x1f (unit separator) to avoid ambiguity +// +// An empty content produces an empty hash (caller decides whether to dedup). 
+func Compute(title, content string, tags []string) string { + content = collapse(content) + if content == "" { + return "" + } + title = collapse(title) + + normTags := make([]string, 0, len(tags)) + for _, t := range tags { + t = strings.ToLower(strings.TrimSpace(t)) + if t != "" { + normTags = append(normTags, t) + } + } + sort.Strings(normTags) + + h := sha256.New() + h.Write([]byte(title)) + h.Write([]byte{0x1f}) + h.Write([]byte(content)) + h.Write([]byte{0x1f}) + // Join tags with 0x1e so a tag containing a literal delimiter can't + // collide with two tags split by it. + h.Write([]byte(strings.Join(normTags, "\x1e"))) + return hex.EncodeToString(h.Sum(nil)) +} + +// collapse trims and replaces runs of whitespace with a single space. +func collapse(s string) string { + s = strings.TrimSpace(s) + if s == "" { + return "" + } + var b strings.Builder + b.Grow(len(s)) + space := false + for _, r := range s { + if r == ' ' || r == '\t' || r == '\n' || r == '\r' { + if !space { + b.WriteByte(' ') + space = true + } + continue + } + b.WriteRune(r) + space = false + } + return b.String() +} diff --git a/internal/contenthash/hash_test.go b/internal/contenthash/hash_test.go new file mode 100644 index 0000000..548e0eb --- /dev/null +++ b/internal/contenthash/hash_test.go @@ -0,0 +1,72 @@ +package contenthash + +import "testing" + +func TestComputeStable(t *testing.T) { + h1 := Compute("title", "body", []string{"a", "b"}) + h2 := Compute("title", "body", []string{"a", "b"}) + if h1 != h2 { + t.Errorf("expected identical hash for identical inputs, got %q vs %q", h1, h2) + } + if len(h1) != 64 { + t.Errorf("expected 64-char hex sha-256, got %d chars", len(h1)) + } +} + +func TestComputeTagOrderAgnostic(t *testing.T) { + a := Compute("t", "c", []string{"x", "y", "z"}) + b := Compute("t", "c", []string{"z", "x", "y"}) + if a != b { + t.Errorf("tag order should not affect hash") + } +} + +func TestComputeTagCaseNormalized(t *testing.T) { + a := Compute("t", "c", 
[]string{"Foo", "bar"}) + b := Compute("t", "c", []string{"foo", "BAR"}) + if a != b { + t.Errorf("tag case should not affect hash") + } +} + +func TestComputeWhitespaceNormalized(t *testing.T) { + a := Compute("my title", "hello world", nil) + b := Compute("my title", "hello\n\tworld\n", nil) + if a != b { + t.Errorf("whitespace differences should not affect hash: %q vs %q", a, b) + } +} + +func TestComputeDifferentContent(t *testing.T) { + a := Compute("t", "one", nil) + b := Compute("t", "two", nil) + if a == b { + t.Errorf("different content should hash differently") + } +} + +func TestComputeDifferentTitle(t *testing.T) { + a := Compute("title one", "c", nil) + b := Compute("title two", "c", nil) + if a == b { + t.Errorf("different title should hash differently") + } +} + +func TestComputeEmptyContent(t *testing.T) { + if h := Compute("t", "", []string{"x"}); h != "" { + t.Errorf("empty content should produce empty hash, got %q", h) + } + if h := Compute("", " ", nil); h != "" { + t.Errorf("whitespace-only content should produce empty hash") + } +} + +func TestComputeTagSeparatorNotAmbiguous(t *testing.T) { + // A tag containing unusual characters must not collide with a split version. 
+ a := Compute("t", "c", []string{"a,b"}) + b := Compute("t", "c", []string{"a", "b"}) + if a == b { + t.Errorf("tags with delimiter-like chars must not collide with split tags") + } +} diff --git a/internal/solr/types.go b/internal/solr/types.go index c32de05..78cc4ab 100644 --- a/internal/solr/types.go +++ b/internal/solr/types.go @@ -17,9 +17,10 @@ type Document struct { UpdatedAt time.Time `json:"updated_at"` ExpiresAt string `json:"expires_at,omitempty"` Lifetime string `json:"lifetime,omitempty"` - SessionID string `json:"session_id,omitempty"` - RelatedIDs []string `json:"related_ids,omitempty"` - Format string `json:"format,omitempty"` + SessionID string `json:"session_id,omitempty"` + RelatedIDs []string `json:"related_ids,omitempty"` + Format string `json:"format,omitempty"` + ContentHash string `json:"content_hash,omitempty"` } // QueryParams holds parameters for a Solr search query. diff --git a/solr/managed-schema.xml b/solr/managed-schema.xml index 0de096e..f6c9a6e 100644 --- a/solr/managed-schema.xml +++ b/solr/managed-schema.xml @@ -48,5 +48,6 @@ + From d3f92b9d10d9bee24b83a6740f26abecfea8f428 Mon Sep 17 00:00:00 2001 From: arreyder Date: Sat, 18 Apr 2026 11:14:04 -0500 Subject: [PATCH 4/8] Add retrieval benchmark harness (solr-mem-bench) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit New binary at cmd/solr-mem-bench that seeds a namespaced bench-* corpus into any memories collection (safe to run against a live one — only touches bench-* IDs), runs a shipped query set with gold labels, and reports R@1/R@3/R@5/R@10 and MRR plus a per-query breakdown as Markdown. Ships 30-doc synthetic corpus + 25 queries covering easy keyword lookups and harder semantic paraphrases. The paraphrase queries are where BM25 is expected to struggle, giving us a baseline to measure hybrid retrieval (#16 + #17) against. New Makefile target: make bench. 
11 unit tests cover R@K, MRR, and Aggregate; no live Solr required for the test suite. Closes #20. Co-Authored-By: Claude Opus 4.7 (1M context) --- .gitignore | 1 + Makefile | 16 +- cmd/solr-mem-bench/main.go | 230 ++++++++++++++++++++++ cmd/solr-mem-bench/metrics.go | 78 ++++++++ cmd/solr-mem-bench/metrics_test.go | 111 +++++++++++ cmd/solr-mem-bench/testdata/corpus.jsonl | 30 +++ cmd/solr-mem-bench/testdata/queries.jsonl | 25 +++ 7 files changed, 489 insertions(+), 2 deletions(-) create mode 100644 cmd/solr-mem-bench/main.go create mode 100644 cmd/solr-mem-bench/metrics.go create mode 100644 cmd/solr-mem-bench/metrics_test.go create mode 100644 cmd/solr-mem-bench/testdata/corpus.jsonl create mode 100644 cmd/solr-mem-bench/testdata/queries.jsonl diff --git a/.gitignore b/.gitignore index 8ebb177..21f8f37 100644 --- a/.gitignore +++ b/.gitignore @@ -26,3 +26,4 @@ Thumbs.db # Solr data (managed by Docker volume) solr-data/ /solr-mem-indexer +/solr-mem-bench diff --git a/Makefile b/Makefile index e776d44..11f7bd6 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: build build-indexer install install-indexer test tidy run dev up down logs reset docker-build docker-up config systemd-install systemd-uninstall indexer-install indexer-uninstall launchd-install launchd-install-server launchd-install-indexer launchd-uninstall skill-install skill-uninstall +.PHONY: build build-indexer build-bench install install-indexer test tidy run dev up down logs reset docker-build docker-up config systemd-install systemd-uninstall indexer-install indexer-uninstall launchd-install launchd-install-server launchd-install-indexer launchd-uninstall skill-install skill-uninstall bench # Go targets build: @@ -7,7 +7,19 @@ build: build-indexer: go build -o bin/solr-mem-indexer ./cmd/solr-mem-indexer -build-all: build build-indexer +build-all: build build-indexer build-bench + +build-bench: + go build -o bin/solr-mem-bench ./cmd/solr-mem-bench + +# Retrieval benchmark. 
Seeds the memories collection with namespaced bench-* +# docs (safe to run against a live collection — only touches bench-* IDs) and +# runs the shipped query set. Override BENCH_URL to point elsewhere. +BENCH_URL ?= http://pax89.local:8983/solr/memories +bench: build-bench + ./bin/solr-mem-bench -solr-url $(BENCH_URL) -seed \ + -corpus cmd/solr-mem-bench/testdata/corpus.jsonl \ + -queries cmd/solr-mem-bench/testdata/queries.jsonl install: go install ./cmd/solr-mem-server diff --git a/cmd/solr-mem-bench/main.go b/cmd/solr-mem-bench/main.go new file mode 100644 index 0000000..d08f52f --- /dev/null +++ b/cmd/solr-mem-bench/main.go @@ -0,0 +1,230 @@ +// solr-mem-bench is a small harness for measuring memory retrieval quality. +// +// Usage: +// +// solr-mem-bench -solr-url http://localhost:8983/solr -collection memories_bench \ +// -corpus testdata/corpus.jsonl -queries testdata/queries.jsonl -seed +// +// The -seed flag clears the target collection and writes the corpus before +// running queries. Without -seed, the harness assumes the collection is +// already populated with the gold IDs referenced by queries. 
+// +// Gold format (JSONL, one per line): +// +// corpus: {"id":"mem-001","title":"...","content":"...","tags":["..."]} +// queries: {"id":"q-001","text":"database N+1 fix","gold":["mem-007"]} +package main + +import ( + "bufio" + "context" + "encoding/json" + "flag" + "fmt" + "log" + "os" + "time" + + "github.com/arreyder/solr-mem/internal/solr" +) + +type corpusItem struct { + ID string `json:"id"` + Title string `json:"title"` + Content string `json:"content"` + Tags []string `json:"tags,omitempty"` +} + +type queryItem struct { + ID string `json:"id"` + Text string `json:"text"` + Gold []string `json:"gold"` +} + +func main() { + var ( + solrURL = flag.String("solr-url", "http://localhost:8983/solr/memories", "Base URL of the memories collection") + corpusPath = flag.String("corpus", "cmd/solr-mem-bench/testdata/corpus.jsonl", "Path to corpus JSONL") + queryPath = flag.String("queries", "cmd/solr-mem-bench/testdata/queries.jsonl", "Path to queries JSONL") + seed = flag.Bool("seed", false, "Clear the collection and write the corpus before running queries") + topK = flag.Int("topk", 10, "Number of hits to retrieve per query") + variant = flag.String("variant", "bm25", "Retrieval variant name (for labeling output)") + ) + flag.Parse() + + corpus, err := loadCorpus(*corpusPath) + if err != nil { + log.Fatalf("load corpus: %v", err) + } + queries, err := loadQueries(*queryPath) + if err != nil { + log.Fatalf("load queries: %v", err) + } + log.Printf("loaded %d corpus items, %d queries", len(corpus), len(queries)) + + client := solr.NewClient(*solrURL) + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + + if *seed { + log.Printf("seeding collection at %s (clearing first)", *solrURL) + if err := seedCorpus(ctx, client, corpus); err != nil { + log.Fatalf("seed: %v", err) + } + // Give Solr a moment to commit. 
+ time.Sleep(1 * time.Second) + } + + agg := NewAggregate(1, 3, 5, 10) + perQuery := make([]queryResult, 0, len(queries)) + + start := time.Now() + for _, q := range queries { + hits, err := runQuery(ctx, client, q.Text, *topK) + if err != nil { + log.Printf("query %s failed: %v", q.ID, err) + continue + } + agg.Add(q.Gold, hits) + perQuery = append(perQuery, queryResult{ID: q.ID, Text: q.Text, Hits: hits, Gold: q.Gold}) + } + elapsed := time.Since(start) + + agg.Finalize() + printReport(*variant, agg, perQuery, elapsed) +} + +type queryResult struct { + ID string + Text string + Hits []string + Gold []string +} + +func loadCorpus(path string) ([]corpusItem, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + var out []corpusItem + s := bufio.NewScanner(f) + s.Buffer(make([]byte, 0, 64*1024), 1024*1024) + for s.Scan() { + line := s.Bytes() + if len(line) == 0 { + continue + } + var item corpusItem + if err := json.Unmarshal(line, &item); err != nil { + return nil, fmt.Errorf("parse corpus line: %w", err) + } + out = append(out, item) + } + return out, s.Err() +} + +func loadQueries(path string) ([]queryItem, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + var out []queryItem + s := bufio.NewScanner(f) + s.Buffer(make([]byte, 0, 64*1024), 1024*1024) + for s.Scan() { + line := s.Bytes() + if len(line) == 0 { + continue + } + var item queryItem + if err := json.Unmarshal(line, &item); err != nil { + return nil, fmt.Errorf("parse query line: %w", err) + } + out = append(out, item) + } + return out, s.Err() +} + +// seedCorpus writes the bench corpus to Solr. It scopes its cleanup to +// IDs starting with "bench-" so it is safe to run against a live memories +// collection — it will only ever touch its own namespaced docs. 
+func seedCorpus(ctx context.Context, client *solr.Client, items []corpusItem) error { + for _, it := range items { + if !isBenchID(it.ID) { + return fmt.Errorf("corpus id %q must start with 'bench-' for safe seeding", it.ID) + } + } + if err := client.DeleteByQuery(ctx, "id:bench-*"); err != nil { + return fmt.Errorf("clear bench docs: %w", err) + } + docs := make([]solr.Document, 0, len(items)) + now := time.Now().UTC() + for _, it := range items { + docs = append(docs, solr.Document{ + ID: it.ID, + Title: it.Title, + Content: it.Content, + Tags: it.Tags, + CreatedAt: now, + UpdatedAt: now, + Lifetime: "permanent", + Format: "prose", + }) + } + return client.Add(ctx, docs...) +} + +func isBenchID(s string) bool { + return len(s) > 6 && s[:6] == "bench-" +} + +func runQuery(ctx context.Context, client *solr.Client, text string, topK int) ([]string, error) { + params := solr.QueryParams{ + Query: text, + Rows: topK, + Fields: []string{"id"}, + } + resp, err := client.Query(ctx, params) + if err != nil { + return nil, err + } + out := make([]string, 0, len(resp.Docs)) + for _, d := range resp.Docs { + if id, ok := d["id"].(string); ok { + out = append(out, id) + } + } + return out, nil +} + +func printReport(variant string, agg *Aggregate, results []queryResult, elapsed time.Duration) { + fmt.Printf("# solr-mem retrieval benchmark\n\n") + fmt.Printf("Variant: `%s` | Queries: %d | Elapsed: %s\n\n", variant, agg.Queries, elapsed.Round(time.Millisecond)) + fmt.Printf("| Metric | Value |\n|---|---|\n") + for _, k := range []int{1, 3, 5, 10} { + if v, ok := agg.RecallAt[k]; ok { + fmt.Printf("| R@%d | %.3f |\n", k, v) + } + } + fmt.Printf("| MRR | %.3f |\n\n", agg.MRR) + + fmt.Printf("## Per-query breakdown\n\n") + fmt.Printf("| Query | Gold | R@5 | MRR | Top hits |\n|---|---|---|---|---|\n") + for _, r := range results { + top := r.Hits + if len(top) > 5 { + top = top[:5] + } + fmt.Printf("| %s (%s) | %v | %.2f | %.2f | %v |\n", + r.ID, truncate(r.Text, 40), r.Gold, 
RecallAtK(r.Gold, r.Hits, 5), MRR(r.Gold, r.Hits), top) + } +} + +func truncate(s string, n int) string { + if len(s) <= n { + return s + } + return s[:n-1] + "…" +} diff --git a/cmd/solr-mem-bench/metrics.go b/cmd/solr-mem-bench/metrics.go new file mode 100644 index 0000000..6853dcc --- /dev/null +++ b/cmd/solr-mem-bench/metrics.go @@ -0,0 +1,78 @@ +package main + +// RecallAtK returns the fraction of gold IDs that appear in the top-K hits. +// gold is the set of relevant IDs for a query; hits is the ranked result list. +func RecallAtK(gold, hits []string, k int) float64 { + if len(gold) == 0 { + return 0 + } + if k <= 0 || k > len(hits) { + k = len(hits) + } + goldSet := make(map[string]struct{}, len(gold)) + for _, g := range gold { + goldSet[g] = struct{}{} + } + found := 0 + for i := 0; i < k; i++ { + if _, ok := goldSet[hits[i]]; ok { + found++ + } + } + return float64(found) / float64(len(gold)) +} + +// MRR returns the reciprocal rank of the first gold hit in the ranked list, +// or 0 if no gold item was retrieved. Ranks are 1-indexed. +func MRR(gold, hits []string) float64 { + if len(gold) == 0 || len(hits) == 0 { + return 0 + } + goldSet := make(map[string]struct{}, len(gold)) + for _, g := range gold { + goldSet[g] = struct{}{} + } + for i, h := range hits { + if _, ok := goldSet[h]; ok { + return 1.0 / float64(i+1) + } + } + return 0 +} + +// Aggregate holds per-run averaged metrics. +type Aggregate struct { + Queries int + RecallAt map[int]float64 // k -> mean recall + MRR float64 +} + +// NewAggregate accumulates per-query metrics over a set of queries. +func NewAggregate(ks ...int) *Aggregate { + m := make(map[int]float64, len(ks)) + for _, k := range ks { + m[k] = 0 + } + return &Aggregate{RecallAt: m} +} + +// Add records one query's hits against its gold. 
+func (a *Aggregate) Add(gold, hits []string) { + a.Queries++ + for k := range a.RecallAt { + a.RecallAt[k] += RecallAtK(gold, hits, k) + } + a.MRR += MRR(gold, hits) +} + +// Finalize divides sums by query count, producing mean metrics. +func (a *Aggregate) Finalize() { + if a.Queries == 0 { + return + } + n := float64(a.Queries) + for k, v := range a.RecallAt { + a.RecallAt[k] = v / n + } + a.MRR /= n +} diff --git a/cmd/solr-mem-bench/metrics_test.go b/cmd/solr-mem-bench/metrics_test.go new file mode 100644 index 0000000..a447ee5 --- /dev/null +++ b/cmd/solr-mem-bench/metrics_test.go @@ -0,0 +1,111 @@ +package main + +import ( + "math" + "testing" +) + +func approx(a, b float64) bool { return math.Abs(a-b) < 1e-9 } + +func TestRecallAtKPerfect(t *testing.T) { + gold := []string{"a", "b"} + hits := []string{"a", "b", "c", "d", "e"} + if got := RecallAtK(gold, hits, 5); !approx(got, 1.0) { + t.Errorf("R@5 should be 1.0, got %f", got) + } +} + +func TestRecallAtKPartial(t *testing.T) { + gold := []string{"a", "b", "c"} + hits := []string{"a", "x", "y", "z", "b"} + // 2 of 3 gold found in top 5. + if got := RecallAtK(gold, hits, 5); !approx(got, 2.0/3.0) { + t.Errorf("expected 2/3, got %f", got) + } + // Only 1 in top 2. 
+ if got := RecallAtK(gold, hits, 2); !approx(got, 1.0/3.0) { + t.Errorf("expected 1/3, got %f", got) + } +} + +func TestRecallAtKNoHits(t *testing.T) { + gold := []string{"a"} + hits := []string{"x", "y", "z"} + if got := RecallAtK(gold, hits, 3); !approx(got, 0) { + t.Errorf("expected 0, got %f", got) + } +} + +func TestRecallAtKKLargerThanHits(t *testing.T) { + gold := []string{"a"} + hits := []string{"a", "b"} + if got := RecallAtK(gold, hits, 10); !approx(got, 1.0) { + t.Errorf("k > len(hits) should not error, got %f", got) + } +} + +func TestRecallAtKEmptyGold(t *testing.T) { + if got := RecallAtK(nil, []string{"a"}, 5); got != 0 { + t.Errorf("no gold → 0, got %f", got) + } +} + +func TestMRRFirstPosition(t *testing.T) { + gold := []string{"a"} + hits := []string{"a", "b", "c"} + if got := MRR(gold, hits); !approx(got, 1.0) { + t.Errorf("first-rank gold → 1.0, got %f", got) + } +} + +func TestMRRThirdPosition(t *testing.T) { + gold := []string{"c"} + hits := []string{"a", "b", "c"} + if got := MRR(gold, hits); !approx(got, 1.0/3.0) { + t.Errorf("rank 3 → 1/3, got %f", got) + } +} + +func TestMRRNoMatch(t *testing.T) { + gold := []string{"z"} + hits := []string{"a", "b", "c"} + if got := MRR(gold, hits); got != 0 { + t.Errorf("no match → 0, got %f", got) + } +} + +func TestMRRMultipleGoldTakesBest(t *testing.T) { + // Two gold items; MRR uses the best (earliest) rank. 
+ gold := []string{"c", "a"} + hits := []string{"a", "b", "c"} + if got := MRR(gold, hits); !approx(got, 1.0) { + t.Errorf("best rank should win: expected 1.0, got %f", got) + } +} + +func TestAggregateEndToEnd(t *testing.T) { + agg := NewAggregate(5, 10) + agg.Add([]string{"a"}, []string{"a", "b", "c"}) // R@5=1, R@10=1, MRR=1 + agg.Add([]string{"x"}, []string{"a", "b", "c"}) // all zero + agg.Add([]string{"c"}, []string{"a", "b", "c"}) // R@5=1, R@10=1, MRR=1/3 + agg.Finalize() + + // Means across 3 queries: R@5 = 2/3, R@10 = 2/3, MRR = (1 + 0 + 1/3)/3 = 4/9 + if !approx(agg.RecallAt[5], 2.0/3.0) { + t.Errorf("R@5 mean expected 2/3, got %f", agg.RecallAt[5]) + } + if !approx(agg.RecallAt[10], 2.0/3.0) { + t.Errorf("R@10 mean expected 2/3, got %f", agg.RecallAt[10]) + } + if !approx(agg.MRR, 4.0/9.0) { + t.Errorf("MRR mean expected 4/9, got %f", agg.MRR) + } +} + +func TestAggregateEmpty(t *testing.T) { + agg := NewAggregate(5) + agg.Finalize() // should not panic on zero queries + if agg.MRR != 0 || agg.RecallAt[5] != 0 { + t.Errorf("empty aggregate should stay zero") + } +} diff --git a/cmd/solr-mem-bench/testdata/corpus.jsonl b/cmd/solr-mem-bench/testdata/corpus.jsonl new file mode 100644 index 0000000..b71784f --- /dev/null +++ b/cmd/solr-mem-bench/testdata/corpus.jsonl @@ -0,0 +1,30 @@ +{"id":"bench-001","title":"JWT auth via jose middleware","content":"Project uses jose (not jsonwebtoken) for Edge runtime compatibility. Middleware in src/middleware/auth.ts validates Authorization: Bearer tokens, decodes JWT claims, attaches user to request context. Token signing uses RS256 with keys rotated weekly.","tags":["auth","jwt","middleware","edge"]} +{"id":"bench-002","title":"N+1 query fix on orders list","content":"OrdersController.list was fetching each order's customer in a separate SELECT, causing N+1. Switched to a single JOIN with customers table. p95 latency dropped from 1200ms to 140ms on the orders index page. 
Related to ticket ORD-413.","tags":["database","performance","n+1","orders"]} +{"id":"bench-003","title":"Rate limiting with Redis token bucket","content":"API gateway uses a sliding-window token-bucket algorithm backed by Redis. 100 requests per minute per API key, 1000 per org. Implementation in pkg/ratelimit/bucket.go. Tokens refill in 600ms increments to avoid thundering herd at minute boundaries.","tags":["rate-limit","redis","api","gateway"]} +{"id":"bench-004","title":"Postgres connection pool sizing","content":"Production pool maxed at 100 connections per service instance, 5 instances → 500 total. PgBouncer in transaction mode in front. Lowering pool size to 50 per instance reduced idle connection overhead by 30% with no throughput loss.","tags":["postgres","connection-pool","pgbouncer","performance"]} +{"id":"bench-005","title":"React context provider pattern for feature flags","content":"FeatureFlagProvider wraps the app root, hydrates flags once from /api/flags at mount, exposes useFeatureFlag(name) hook. Flag changes from admin UI push via WebSocket and trigger provider re-render. Stale-while-revalidate fallback to localStorage for offline.","tags":["react","feature-flags","websocket","frontend"]} +{"id":"bench-006","title":"Memory leak in WebSocket reconnect handler","content":"The reconnect exponential-backoff timer kept a closure reference to the old socket, preventing GC. Fix: explicitly null socket.onmessage and socket.onclose in cleanup, store the current timer ID outside the closure so successive retries replace it.","tags":["websocket","memory-leak","javascript","frontend"]} +{"id":"bench-007","title":"Slow Solr range facet on timestamp field","content":"facet.range on created_at with gap=+1DAY over 90 days was taking 3s. Switched to docValues=true on the pdate field and enabled facet.method=dv. Latency dropped to 80ms. 
Applies to any range faceting on high-cardinality date fields.","tags":["solr","facet","performance","date"]} +{"id":"bench-008","title":"Docker Compose healthcheck for Solr","content":"solr service in docker-compose.yml needs healthcheck: curl -sf http://localhost:8983/solr/admin/ping. Without it, dependent services start before the configset is loaded, causing schema-not-found errors on first run. 10s interval, 5 retries.","tags":["docker","solr","healthcheck","devops"]} +{"id":"bench-009","title":"TypeScript strict mode migration","content":"Migrated tsconfig with strict: true in stages: strictNullChecks first (revealed 180 null-deref bugs), then noImplicitAny (caught 40 places passing any to third-party APIs). Suppress with // @ts-expect-error initially, burn down incrementally.","tags":["typescript","migration","strict","types"]} +{"id":"bench-010","title":"Kafka consumer lag alerting","content":"Alert fires when consumer group lag exceeds 10k for 5 minutes on any partition. Check by: kafka-consumer-groups --describe --group X. Most common root cause: slow downstream DB writes. Fix is usually raising batch size, not scaling consumers.","tags":["kafka","alerting","lag","monitoring"]} +{"id":"bench-011","title":"GraphQL N+1 via DataLoader","content":"Nested resolvers for User.posts and Post.comments were firing one query per row. Wrapped each resolver in a DataLoader that batches IDs per request. Cut DB round-trips on the feed query from 300 to 3.","tags":["graphql","n+1","dataloader","database"]} +{"id":"bench-012","title":"AWS S3 presigned URL expiry","content":"Upload URLs signed with 15-min expiry. Users on slow networks occasionally fail right at the boundary. Bumped to 60 min for uploads, kept 5 min for downloads. 
Monitoring the RequestExpired 4xx count on the bucket metrics dashboard.","tags":["aws","s3","presigned-url","upload"]} +{"id":"bench-013","title":"Go generics for repository pattern","content":"Replaced a dozen Find/Create/Delete methods across 6 repos with one generic[T any] Repository struct. Used constraints.Ordered for sortable fields. Reduced LOC by ~40% and caught two places where type mismatches were silently compiling.","tags":["go","generics","repository","refactor"]} +{"id":"bench-014","title":"Flaky integration test on CI","content":"auth_integration_test.go failed 1-in-20 on GitHub Actions, always passed locally. Root cause: test reused a Redis key across parallel goroutines; local runs happened to serialize. Fix: use t.Name() as key prefix to isolate.","tags":["flaky","test","redis","ci"]} +{"id":"bench-015","title":"CORS preflight cache tuning","content":"Access-Control-Max-Age defaulted to 5s; browsers were firing OPTIONS every request. Raised to 86400 (1 day). Preflight traffic dropped 98%. Works well because our CORS config is stable.","tags":["cors","http","performance","browser"]} +{"id":"bench-016","title":"Pulumi stack for ephemeral preview envs","content":"Each PR gets a preview environment: Pulumi stack named pr-, S3 bucket, RDS serverless, domain cnamed to preview.example.com. Cleanup lambda removes stacks after PR close or 7 days inactivity. Cost per stack: ~$2/day.","tags":["pulumi","infra","preview","aws"]} +{"id":"bench-017","title":"Prometheus histogram bucket tuning","content":"Default buckets are tuned for web latency (le=0.005..10). For background job duration we used le=1,5,10,30,60,120,300. Too-coarse buckets hide p99 changes; too-fine explodes time-series cardinality. Revisit quarterly.","tags":["prometheus","metrics","histogram","monitoring"]} +{"id":"bench-018","title":"ElasticSearch to Solr migration","content":"Migrated search from ES 7.x to Solr 9. 
Key differences: Solr uses managed-schema.xml vs ES mappings, edismax is the edismax parser (no dis_max clause), field boosting syntax is title^2 not title:(query)^2. Faceting syntax maps cleanly 1:1.","tags":["elasticsearch","solr","migration","search"]} +{"id":"bench-019","title":"Stripe webhook idempotency","content":"Store the webhook event ID in a processed_events table with a unique constraint. Reject duplicates with 200 (not 4xx, Stripe will retry). 90-day retention then expire. Saw duplicate deliveries ~0.3% of the time.","tags":["stripe","webhook","idempotency","payments"]} +{"id":"bench-020","title":"pgx vs database/sql connection handling","content":"Switched from lib/pq to pgx/v5. pgx.Conn is stateful per-connection (prepared statements cached locally) so it doesn't play with PgBouncer in transaction mode. Use pgx.Pool with simple_protocol=true or stay with database/sql.","tags":["postgres","pgx","go","connection"]} +{"id":"bench-021","title":"nginx buffer size for large POST bodies","content":"Uploads >1MB were hitting client_body_buffer_size limits and spilling to /tmp on nginx hosts. Raised to 10m and moved temp dir to tmpfs. 413 errors went to zero. Also bumped client_max_body_size to 100m for the upload endpoint.","tags":["nginx","upload","buffer","http"]} +{"id":"bench-022","title":"Redis key prefix conventions","content":"We use :: (e.g. auth:session:abc123). Makes KEYS/SCAN across a namespace easy and lets us purge per-service with a single pattern. Avoid ::: double-colons; Redis pattern matching is happier with single separators.","tags":["redis","conventions","cache","keys"]} +{"id":"bench-023","title":"Go pprof on production pods","content":"Enabled net/http/pprof on /debug/pprof, bound to localhost only, exposed via kubectl port-forward on demand. Used pprof -http to analyze heap dumps. 
Found a string-append loop in JSON marshaling that was 40% of allocations.","tags":["go","pprof","performance","profiling"]} +{"id":"bench-024","title":"Datadog log pattern cardinality","content":"Log messages containing user-supplied UUIDs blew up pattern aggregation. Now we scrub UUIDs to in log formatter before shipping. Cardinality on the /patterns query dropped 100x, making anomaly detection useful again.","tags":["datadog","logs","cardinality","observability"]} +{"id":"bench-025","title":"Linear API rate limit","content":"Linear's GraphQL API caps at 1500 requests per hour per API key. Issue sync job batches 50 issues per request. Retry-After header is respected. Hit the cap once during a full re-import — now we paginate with 1s jitter between pages.","tags":["linear","api","rate-limit","integration"]} +{"id":"bench-026","title":"macOS launchd user agent for background service","content":"Created ~/Library/LaunchAgents/com.example.service.plist with KeepAlive and RunAtLoad. Logs to /tmp, environment vars set under EnvironmentVariables dict. launchctl load starts immediately and on login. Beats writing a custom systemd wrapper.","tags":["macos","launchd","service","daemon"]} +{"id":"bench-027","title":"Frontend bundle size budget enforcement","content":"bundle-analyzer in CI fails if main.js > 400KB gzipped. Broke twice this quarter — once from an accidental import of all of lodash, once from a moment.js locale bundle. Now we lint for barrel imports and use dayjs exclusively.","tags":["frontend","bundle","webpack","ci"]} +{"id":"bench-028","title":"PostgreSQL partial index for sparse column","content":"created_at IS NOT NULL on a 50M row table; only ~200k rows have it set. Regular btree took 800MB; partial index (WHERE created_at IS NOT NULL) is 12MB. 
Query planner picks it correctly after ANALYZE.","tags":["postgres","index","partial","performance"]} +{"id":"bench-029","title":"Claude Code hooks for auto-capture","content":"settings.json PostToolUse hook runs a shell script that parses the tool event JSON, filters for file-touching tools, and POSTs to a local observer endpoint. UserPromptSubmit mirrors prompts to a journaling log. Zero-effort session recording.","tags":["claude-code","hooks","observability","agent"]} +{"id":"bench-030","title":"Solr more-like-this for related memories","content":"MLT on title,content with mintf=2 mindf=2 gives decent recall for 'memories similar to X'. Filter out the source doc id. For memory system: raise mintf for short memories so we don't match on single common tokens.","tags":["solr","mlt","similarity","search"]} diff --git a/cmd/solr-mem-bench/testdata/queries.jsonl b/cmd/solr-mem-bench/testdata/queries.jsonl new file mode 100644 index 0000000..df1a2b5 --- /dev/null +++ b/cmd/solr-mem-bench/testdata/queries.jsonl @@ -0,0 +1,25 @@ +{"id":"q-01","text":"JWT authentication middleware","gold":["bench-001"]} +{"id":"q-02","text":"rate limiting API gateway","gold":["bench-003"]} +{"id":"q-03","text":"Postgres connection pool sizing","gold":["bench-004"]} +{"id":"q-04","text":"feature flag React provider","gold":["bench-005"]} +{"id":"q-05","text":"Stripe webhook idempotency","gold":["bench-019"]} +{"id":"q-06","text":"Kafka consumer lag alerts","gold":["bench-010"]} +{"id":"q-07","text":"database performance optimization","gold":["bench-002","bench-011","bench-028","bench-007"]} +{"id":"q-08","text":"secure token validation edge runtime","gold":["bench-001"]} +{"id":"q-09","text":"fixing slow search facets on dates","gold":["bench-007"]} +{"id":"q-10","text":"preventing duplicate payment processing","gold":["bench-019"]} +{"id":"q-11","text":"browser preflight caching","gold":["bench-015"]} +{"id":"q-12","text":"monitoring job duration histograms","gold":["bench-017"]} 
+{"id":"q-13","text":"find related records efficiently","gold":["bench-011","bench-030"]} +{"id":"q-14","text":"fixing a memory leak in reconnection logic","gold":["bench-006"]} +{"id":"q-15","text":"ephemeral environments per pull request","gold":["bench-016"]} +{"id":"q-16","text":"upload large files HTTP body limit","gold":["bench-021","bench-012"]} +{"id":"q-17","text":"type safety migration staged rollout","gold":["bench-009"]} +{"id":"q-18","text":"flaky tests on continuous integration","gold":["bench-014"]} +{"id":"q-19","text":"moving from Elasticsearch","gold":["bench-018"]} +{"id":"q-20","text":"launchd plist for a background service","gold":["bench-026"]} +{"id":"q-21","text":"Go profiling heap allocations","gold":["bench-023"]} +{"id":"q-22","text":"N+1 query with GraphQL resolvers","gold":["bench-011","bench-002"]} +{"id":"q-23","text":"automatic observation capture from agent hooks","gold":["bench-029"]} +{"id":"q-24","text":"generic repository pattern in Go","gold":["bench-013"]} +{"id":"q-25","text":"scrubbing high-cardinality identifiers from logs","gold":["bench-024"]} From 7182a056d6a224d026e0bccc80ab3e88584a337a Mon Sep 17 00:00:00 2001 From: arreyder Date: Sat, 18 Apr 2026 11:38:40 -0500 Subject: [PATCH 5/8] Add RRF fuser, embedding provider, and vector schema field MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Foundation for hybrid retrieval: - internal/retrieval: Reciprocal Rank Fusion over ranked ID streams with configurable k (default 60), per-stream weights, and top-K cap. 10 unit tests covering agreement boost, weighting, tie-break, and cross-stream behavior. - internal/embed: pluggable embedding Provider interface with an OpenAI text-embedding-3-small implementation. FromEnv() returns nil when no API key is set; callers treat nil as "embeddings disabled" rather than erroring. 7 unit tests (httptest-based, no network). 
- solr/managed-schema.xml: new knn_vector field type (1536 dim cosine, matching OpenAI default) and an embedding field on memories. Existing docs don't need reindexing — they simply won't show up in vector search. BM25 still works for all docs. - store_memory / bulk_store_memories compute embeddings at write time when a provider is configured. Failures log and fall through to a vector-less write; they never fail the overall store. Query-side KNN integration and RRF wiring in search_memories are scoped to a follow-up commit on this branch. This commit ships only the pieces that are safely inert without the query-side changes. Addresses #16 (pt. 1) and #17. Co-Authored-By: Claude Opus 4.7 (1M context) --- cmd/solr-mem-server/bulk_store_tool.go | 1 + cmd/solr-mem-server/embed_helper.go | 39 ++++++ cmd/solr-mem-server/main.go | 13 ++ cmd/solr-mem-server/store_tool.go | 1 + internal/embed/openai.go | 102 ++++++++++++++++ internal/embed/openai_test.go | 135 +++++++++++++++++++++ internal/embed/provider.go | 42 +++++++ internal/retrieval/rrf.go | 126 ++++++++++++++++++++ internal/retrieval/rrf_test.go | 159 +++++++++++++++++++++++++ internal/solr/types.go | 1 + solr/managed-schema.xml | 8 ++ 11 files changed, 627 insertions(+) create mode 100644 cmd/solr-mem-server/embed_helper.go create mode 100644 internal/embed/openai.go create mode 100644 internal/embed/openai_test.go create mode 100644 internal/embed/provider.go create mode 100644 internal/retrieval/rrf.go create mode 100644 internal/retrieval/rrf_test.go diff --git a/cmd/solr-mem-server/bulk_store_tool.go b/cmd/solr-mem-server/bulk_store_tool.go index 18d7790..7af99cf 100644 --- a/cmd/solr-mem-server/bulk_store_tool.go +++ b/cmd/solr-mem-server/bulk_store_tool.go @@ -111,6 +111,7 @@ func bulkStoreMemoriesTool(ctx context.Context, args map[string]any) (any, error RelatedIDs: getStringSlice(m, "related_ids"), Format: format, ContentHash: hash, + Embedding: embedForStore(ctx, scrubbedTitle, scrubbedContent), }) } diff 
--git a/cmd/solr-mem-server/embed_helper.go b/cmd/solr-mem-server/embed_helper.go new file mode 100644 index 0000000..ca77b4a --- /dev/null +++ b/cmd/solr-mem-server/embed_helper.go @@ -0,0 +1,39 @@ +package main + +import ( + "context" + "log" + "strings" + "time" +) + +// embedForStore produces an embedding for a memory write. It combines title +// and content (title first, it's usually the most informative signal) and +// calls the configured provider with a short timeout. On error or when no +// provider is configured, returns nil so callers can still store the doc. +func embedForStore(ctx context.Context, title, content string) []float32 { + if embedProvider == nil { + return nil + } + text := strings.TrimSpace(title) + if content != "" { + if text != "" { + text += "\n\n" + } + text += content + } + if text == "" { + return nil + } + + // Short timeout: a slow provider shouldn't block a memory write. + cctx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + v, err := embedProvider.Embed(cctx, text) + if err != nil { + log.Printf("embed: failed (continuing without vector): %v", err) + return nil + } + return v +} diff --git a/cmd/solr-mem-server/main.go b/cmd/solr-mem-server/main.go index e9a4c64..abd0b19 100644 --- a/cmd/solr-mem-server/main.go +++ b/cmd/solr-mem-server/main.go @@ -8,12 +8,14 @@ import ( "net/http" "os" + "github.com/arreyder/solr-mem/internal/embed" "github.com/arreyder/solr-mem/internal/solr" "github.com/modelcontextprotocol/go-sdk/mcp" ) var solrClient *solr.Client var codeClient *solr.Client +var embedProvider embed.Provider func main() { solrURL := os.Getenv("SOLR_URL") @@ -28,6 +30,17 @@ func main() { } codeClient = solr.NewClient(codeURL) + // Optional embedding provider. If no API key is set, embedProvider stays + // nil and the server runs in BM25-only mode. 
+ if p, err := embed.FromEnv(); err != nil { + log.Printf("embedding provider init failed: %v (continuing BM25-only)", err) + } else if p != nil { + embedProvider = p + log.Printf("embedding provider: %s (dim=%d)", p.Name(), p.Dim()) + } else { + log.Printf("no embedding provider configured (set OPENAI_API_KEY to enable)") + } + // Start expiration sweeper ctx, cancel := context.WithCancel(context.Background()) defer cancel() diff --git a/cmd/solr-mem-server/store_tool.go b/cmd/solr-mem-server/store_tool.go index 9956823..3f1f577 100644 --- a/cmd/solr-mem-server/store_tool.go +++ b/cmd/solr-mem-server/store_tool.go @@ -88,6 +88,7 @@ func storeMemoryTool(ctx context.Context, args map[string]any) (any, error) { RelatedIDs: getStringSlice(args, "related_ids"), Format: format, ContentHash: hash, + Embedding: embedForStore(ctx, scrubbedTitle, scrubbedContent), } if err := solrClient.Add(ctx, doc); err != nil { diff --git a/internal/embed/openai.go b/internal/embed/openai.go new file mode 100644 index 0000000..476e820 --- /dev/null +++ b/internal/embed/openai.go @@ -0,0 +1,102 @@ +package embed + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "time" +) + +// OpenAI provides embeddings via the OpenAI /v1/embeddings API. It is the +// reference provider; other providers follow the same interface. +type OpenAI struct { + apiKey string + model string + dim int + baseURL string + client *http.Client +} + +// DefaultOpenAIModel is OpenAI's current small text embedding model. It emits +// 1536-dimensional vectors, which the Solr schema must match. +const ( + DefaultOpenAIModel = "text-embedding-3-small" + DefaultOpenAIDim = 1536 +) + +// NewOpenAI constructs a client. If model is empty, DefaultOpenAIModel is used. 
+func NewOpenAI(apiKey, model string) *OpenAI { + if model == "" { + model = DefaultOpenAIModel + } + return &OpenAI{ + apiKey: apiKey, + model: model, + dim: DefaultOpenAIDim, + baseURL: "https://api.openai.com/v1", + client: &http.Client{Timeout: 30 * time.Second}, + } +} + +// Name implements Provider. +func (o *OpenAI) Name() string { return "openai:" + o.model } + +// Dim implements Provider. +func (o *OpenAI) Dim() int { return o.dim } + +// Embed implements Provider. A non-200 response or malformed payload +// returns an error with enough context for debugging. +func (o *OpenAI) Embed(ctx context.Context, text string) ([]float32, error) { + if text == "" { + return nil, fmt.Errorf("embed: empty text") + } + + body, err := json.Marshal(map[string]any{ + "model": o.model, + "input": text, + }) + if err != nil { + return nil, fmt.Errorf("embed: marshal: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, o.baseURL+"/embeddings", bytes.NewReader(body)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+o.apiKey) + + resp, err := o.client.Do(req) + if err != nil { + return nil, fmt.Errorf("embed: http: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + snippet, _ := io.ReadAll(io.LimitReader(resp.Body, 512)) + return nil, fmt.Errorf("embed: status %d: %s", resp.StatusCode, string(snippet)) + } + + var out struct { + Data []struct { + Embedding []float32 `json:"embedding"` + } `json:"data"` + } + if err := json.NewDecoder(resp.Body).Decode(&out); err != nil { + return nil, fmt.Errorf("embed: decode: %w", err) + } + if len(out.Data) == 0 || len(out.Data[0].Embedding) == 0 { + return nil, fmt.Errorf("embed: empty response") + } + v := out.Data[0].Embedding + if len(v) != o.dim { + // Update our advertised dim on first successful call so queries + // don't drift. 
OpenAI allows a `dimensions` override; we don't + // set one here, so the response should match DefaultOpenAIDim. + o.dim = len(v) + } + return v, nil +} diff --git a/internal/embed/openai_test.go b/internal/embed/openai_test.go new file mode 100644 index 0000000..df2889b --- /dev/null +++ b/internal/embed/openai_test.go @@ -0,0 +1,135 @@ +package embed + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestOpenAIEmbedSuccess(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("Authorization"); got != "Bearer test-key" { + t.Errorf("missing/wrong Authorization header: %q", got) + } + if ct := r.Header.Get("Content-Type"); !strings.HasPrefix(ct, "application/json") { + t.Errorf("wrong content-type: %q", ct) + } + var body map[string]any + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Fatalf("bad body: %v", err) + } + if body["model"] != DefaultOpenAIModel { + t.Errorf("wrong model: %v", body["model"]) + } + if body["input"] != "hello" { + t.Errorf("wrong input: %v", body["input"]) + } + _ = json.NewEncoder(w).Encode(map[string]any{ + "data": []map[string]any{ + {"embedding": []float32{0.1, 0.2, 0.3}}, + }, + }) + })) + defer server.Close() + + p := NewOpenAI("test-key", "") + p.baseURL = server.URL + + v, err := p.Embed(context.Background(), "hello") + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + if len(v) != 3 { + t.Errorf("expected 3 dims, got %d", len(v)) + } + if v[0] != 0.1 || v[1] != 0.2 || v[2] != 0.3 { + t.Errorf("unexpected vector: %v", v) + } + // After a successful call with unexpected dim, Dim() should adapt. 
+ if p.Dim() != 3 { + t.Errorf("dim should update to actual response size: got %d", p.Dim()) + } +} + +func TestOpenAIEmbedEmptyInput(t *testing.T) { + p := NewOpenAI("k", "") + if _, err := p.Embed(context.Background(), ""); err == nil { + t.Error("expected error for empty text") + } +} + +func TestOpenAIEmbedErrorStatus(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "rate limit exceeded", http.StatusTooManyRequests) + })) + defer server.Close() + + p := NewOpenAI("k", "") + p.baseURL = server.URL + + _, err := p.Embed(context.Background(), "hi") + if err == nil { + t.Fatal("expected error on 429") + } + if !strings.Contains(err.Error(), "429") { + t.Errorf("error should include status code, got: %v", err) + } +} + +func TestOpenAIEmbedEmptyData(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(map[string]any{"data": []any{}}) + })) + defer server.Close() + + p := NewOpenAI("k", "") + p.baseURL = server.URL + + if _, err := p.Embed(context.Background(), "x"); err == nil { + t.Error("expected error on empty data array") + } +} + +func TestOpenAINameAndDim(t *testing.T) { + p := NewOpenAI("k", "") + if p.Name() != "openai:"+DefaultOpenAIModel { + t.Errorf("name: got %s", p.Name()) + } + if p.Dim() != DefaultOpenAIDim { + t.Errorf("dim: got %d", p.Dim()) + } + + p2 := NewOpenAI("k", "text-embedding-3-large") + if !strings.HasSuffix(p2.Name(), "text-embedding-3-large") { + t.Errorf("custom model not reflected in name: %s", p2.Name()) + } +} + +func TestFromEnvNoProvider(t *testing.T) { + t.Setenv("OPENAI_API_KEY", "") + p, err := FromEnv() + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + if p != nil { + t.Errorf("expected nil provider, got %v", p) + } +} + +func TestFromEnvOpenAI(t *testing.T) { + t.Setenv("OPENAI_API_KEY", "sk-test") + t.Setenv("OPENAI_EMBEDDING_MODEL", 
"custom-model") + p, err := FromEnv() + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + if p == nil { + t.Fatal("expected non-nil provider") + } + if !strings.Contains(p.Name(), "custom-model") { + t.Errorf("provider should use configured model: %s", p.Name()) + } +} diff --git a/internal/embed/provider.go b/internal/embed/provider.go new file mode 100644 index 0000000..27ff7ef --- /dev/null +++ b/internal/embed/provider.go @@ -0,0 +1,42 @@ +// Package embed turns text into dense vector embeddings for semantic search. +// +// Providers are pluggable. Users configure one via environment variables: +// +// OPENAI_API_KEY -> OpenAI text-embedding-3-small (1536 dim by default) +// OPENAI_EMBEDDING_MODEL -> override model (optional) +// +// If no provider is configured, FromEnv returns (nil, nil) and the server +// runs in BM25-only mode. Callers MUST handle a nil provider as a signal to +// skip embedding rather than erroring. +package embed + +import ( + "context" + "os" + "strings" +) + +// Provider produces a dense vector for a piece of text. Implementations +// should be safe for concurrent use. +type Provider interface { + // Embed returns the embedding for text. The returned slice length must + // equal Dim() and must be stable across calls on the same provider. + Embed(ctx context.Context, text string) ([]float32, error) + + // Dim is the output dimension. Used to validate schema configuration. + Dim() int + + // Name is a short identifier for logging / metrics. + Name() string +} + +// FromEnv constructs a provider from environment variables. Returns (nil, nil) +// if no provider is configured — callers should treat nil as "embeddings +// disabled" rather than an error. 
+func FromEnv() (Provider, error) { + if key := strings.TrimSpace(os.Getenv("OPENAI_API_KEY")); key != "" { + model := strings.TrimSpace(os.Getenv("OPENAI_EMBEDDING_MODEL")) + return NewOpenAI(key, model), nil + } + return nil, nil +} diff --git a/internal/retrieval/rrf.go b/internal/retrieval/rrf.go new file mode 100644 index 0000000..ca2c6b5 --- /dev/null +++ b/internal/retrieval/rrf.go @@ -0,0 +1,126 @@ +// Package retrieval contains the search-side primitives used by the +// solr-mem server to combine multiple retrieval signals. +package retrieval + +// Stream is one ranked list of document IDs from a single retrieval source +// (e.g. BM25 keyword match, dense vector KNN, graph walk). Order matters: +// position 0 is the top-ranked hit. +type Stream struct { + Name string + IDs []string +} + +// RRFOption configures the fuser. +type RRFOption func(*rrfConfig) + +type rrfConfig struct { + k int + topK int + weight map[string]float64 +} + +// RRFWithK sets the rank-damping constant. Larger k compresses scores across +// streams, making them count more evenly. The canonical default is 60 (per +// the original Cormack et al. paper). +func RRFWithK(k int) RRFOption { + return func(c *rrfConfig) { + if k > 0 { + c.k = k + } + } +} + +// RRFWithTopK truncates the returned fused list. 0 means "return everything +// that appeared in any stream". +func RRFWithTopK(n int) RRFOption { + return func(c *rrfConfig) { + c.topK = n + } +} + +// RRFWithWeight overrides the default 1.0 weight for a named stream. Useful +// if one signal is known to be higher-quality for a given query class. +// Streams not in the map keep weight 1.0. +func RRFWithWeight(weights map[string]float64) RRFOption { + return func(c *rrfConfig) { + c.weight = weights + } +} + +// RRFScored pairs an ID with its fused score. Sorted high-to-low on return +// from Fuse. 
+type RRFScored struct { + ID string + Score float64 +} + +// Fuse combines multiple ranked streams with Reciprocal Rank Fusion: +// +// score(d) = sum over streams i of weight_i / (k + rank_i(d)) +// +// rank is 1-indexed. Streams with fewer results than another don't penalize +// the missing IDs — they simply don't contribute. The returned list is +// sorted by fused score descending; ties break on lexicographic ID for +// deterministic output. +func Fuse(streams []Stream, opts ...RRFOption) []RRFScored { + cfg := rrfConfig{k: 60} + for _, opt := range opts { + opt(&cfg) + } + + scores := make(map[string]float64) + for _, s := range streams { + w := 1.0 + if cfg.weight != nil { + if v, ok := cfg.weight[s.Name]; ok { + w = v + } + } + for rank, id := range s.IDs { + if id == "" { + continue + } + scores[id] += w / float64(cfg.k+rank+1) + } + } + + out := make([]RRFScored, 0, len(scores)) + for id, sc := range scores { + out = append(out, RRFScored{ID: id, Score: sc}) + } + // Sort by score desc, then ID asc for determinism. + sortFused(out) + + if cfg.topK > 0 && len(out) > cfg.topK { + out = out[:cfg.topK] + } + return out +} + +// FuseIDs is a thin wrapper around Fuse that returns just the ranked IDs. +func FuseIDs(streams []Stream, opts ...RRFOption) []string { + scored := Fuse(streams, opts...) + ids := make([]string, len(scored)) + for i, s := range scored { + ids[i] = s.ID + } + return ids +} + +// sortFused sorts in-place: descending score, ascending ID for ties. +func sortFused(xs []RRFScored) { + // Simple selection sort is fine for small N; most memory queries return + // <= 50 hits. Using sort.Slice would pull in sort; inlined for zero-dep. 
+ for i := 0; i < len(xs); i++ { + best := i + for j := i + 1; j < len(xs); j++ { + if xs[j].Score > xs[best].Score || + (xs[j].Score == xs[best].Score && xs[j].ID < xs[best].ID) { + best = j + } + } + if best != i { + xs[i], xs[best] = xs[best], xs[i] + } + } +} diff --git a/internal/retrieval/rrf_test.go b/internal/retrieval/rrf_test.go new file mode 100644 index 0000000..46a6f45 --- /dev/null +++ b/internal/retrieval/rrf_test.go @@ -0,0 +1,159 @@ +package retrieval + +import ( + "math" + "testing" +) + +func approx(a, b float64) bool { return math.Abs(a-b) < 1e-9 } + +func TestFuseSingleStreamPreservesOrder(t *testing.T) { + streams := []Stream{ + {Name: "bm25", IDs: []string{"a", "b", "c"}}, + } + got := FuseIDs(streams) + want := []string{"a", "b", "c"} + if !equal(got, want) { + t.Errorf("single stream order: got %v, want %v", got, want) + } +} + +func TestFuseBothStreamsAgreeTopGoesFirst(t *testing.T) { + streams := []Stream{ + {Name: "bm25", IDs: []string{"a", "b", "c"}}, + {Name: "vec", IDs: []string{"a", "b", "c"}}, + } + got := FuseIDs(streams) + if len(got) != 3 || got[0] != "a" { + t.Errorf("full agreement should put a first, got %v", got) + } +} + +func TestFuseCrossStreamBoostsHitsInBoth(t *testing.T) { + // c appears in both but never at the top; a appears once; b appears once. + // Canonical RRF with k=60: score(c) = 1/62 + 1/62 = ~0.0323 + // score(a) = 1/61 = ~0.0164 + // score(b) = 1/61 = ~0.0164 + // So c should rank first. + streams := []Stream{ + {Name: "bm25", IDs: []string{"a", "c"}}, + {Name: "vec", IDs: []string{"b", "c"}}, + } + got := FuseIDs(streams) + if got[0] != "c" { + t.Errorf("cross-stream hit should win: got %v", got) + } +} + +func TestFuseKnownScores(t *testing.T) { + // Verify exact RRF formula with k=10 for cleaner numbers. 
+ streams := []Stream{ + {Name: "s1", IDs: []string{"x"}}, + {Name: "s2", IDs: []string{"x"}}, + } + got := Fuse(streams, RRFWithK(10)) + if len(got) != 1 { + t.Fatalf("expected 1 result, got %d", len(got)) + } + // score(x) = 1/(10+1) + 1/(10+1) = 2/11 + want := 2.0 / 11.0 + if !approx(got[0].Score, want) { + t.Errorf("expected score %f, got %f", want, got[0].Score) + } +} + +func TestFuseWeights(t *testing.T) { + // Equal-rank hits should split by weight. + streams := []Stream{ + {Name: "high", IDs: []string{"a"}}, + {Name: "low", IDs: []string{"b"}}, + } + got := Fuse(streams, RRFWithK(10), RRFWithWeight(map[string]float64{"high": 2.0, "low": 0.5})) + var scoreA, scoreB float64 + for _, s := range got { + if s.ID == "a" { + scoreA = s.Score + } + if s.ID == "b" { + scoreB = s.Score + } + } + // score(a) = 2.0/11 = 0.1818..., score(b) = 0.5/11 = 0.0454... + if scoreA <= scoreB { + t.Errorf("higher weight should produce higher score: a=%f b=%f", scoreA, scoreB) + } + if !approx(scoreA, 2.0/11.0) { + t.Errorf("expected score(a)=2/11, got %f", scoreA) + } +} + +func TestFuseTopK(t *testing.T) { + streams := []Stream{ + {Name: "s", IDs: []string{"a", "b", "c", "d", "e"}}, + } + got := FuseIDs(streams, RRFWithTopK(3)) + if len(got) != 3 { + t.Errorf("expected 3, got %d", len(got)) + } +} + +func TestFuseEmptyInput(t *testing.T) { + if got := FuseIDs(nil); len(got) != 0 { + t.Errorf("nil input should produce empty, got %v", got) + } + if got := FuseIDs([]Stream{{Name: "s"}}); len(got) != 0 { + t.Errorf("empty stream should produce empty, got %v", got) + } +} + +func TestFuseIgnoresEmptyIDs(t *testing.T) { + streams := []Stream{ + {Name: "s", IDs: []string{"a", "", "b"}}, + } + got := FuseIDs(streams) + // Empty IDs are skipped; remaining ranks keep their original positions. + // So "a" is rank 0 and "b" stays at rank 2 (its slice index), not rank 1.
+ if len(got) != 2 || got[0] != "a" || got[1] != "b" { + t.Errorf("empty IDs should be skipped: got %v", got) + } +} + +func TestFuseDeterministicTieBreak(t *testing.T) { + // Two IDs with identical scores should sort by ID ascending for stable output. + streams := []Stream{ + {Name: "s1", IDs: []string{"zebra"}}, + {Name: "s2", IDs: []string{"apple"}}, + } + got := FuseIDs(streams) + if len(got) != 2 || got[0] != "apple" { + t.Errorf("tie break should favor lexicographically-smaller ID: got %v", got) + } +} + +func TestFuseMissingFromOneStream(t *testing.T) { + // Asymmetric streams: "only-a" appears only in s1. Its score should + // come only from that stream's rank. + streams := []Stream{ + {Name: "s1", IDs: []string{"only-a", "shared"}}, + {Name: "s2", IDs: []string{"shared", "only-b"}}, + } + got := Fuse(streams, RRFWithK(10)) + // shared is rank 1 in s1 and rank 0 in s2: 1/12 + 1/11 = ~0.174 + // only-a is rank 0 in s1: 1/11 = 0.0909 + // only-b is rank 1 in s2: 1/12 = 0.0833 + if got[0].ID != "shared" { + t.Errorf("shared should rank first: got %v", got) + } +} + +func equal(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} diff --git a/internal/solr/types.go b/internal/solr/types.go index 78cc4ab..ae21e80 100644 --- a/internal/solr/types.go +++ b/internal/solr/types.go @@ -21,6 +21,7 @@ type Document struct { RelatedIDs []string `json:"related_ids,omitempty"` Format string `json:"format,omitempty"` ContentHash string `json:"content_hash,omitempty"` + Embedding []float32 `json:"embedding,omitempty"` } // QueryParams holds parameters for a Solr search query.
diff --git a/solr/managed-schema.xml b/solr/managed-schema.xml index f6c9a6e..580e254 100644 --- a/solr/managed-schema.xml +++ b/solr/managed-schema.xml @@ -9,6 +9,13 @@ + + + @@ -49,5 +56,6 @@ + From 7b34458d1af13833385543a4eabd7a90baf16169 Mon Sep 17 00:00:00 2001 From: arreyder Date: Sat, 18 Apr 2026 11:44:45 -0500 Subject: [PATCH 6/8] Wire KNN + RRF into search_memories MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit search_memories now runs a BM25 query followed by a KNN query when an embedding provider is configured, and fuses the two with RRF (default k=60). Filters (agent_id, tags, session_id, etc.) are shared across both streams so semantic hits respect the same scoping as keyword hits. The KNN pool is widened to 3× the user limit by default to give fusion more overlap. New args on search_memories: - semantic: bool, default true when a provider is configured; false forces BM25-only - knn_topk: int, size of the KNN pool before fusion Session-diversification (cap N per session_id) runs after fusion, so the post-RRF output still respects the per-session cap. Existing BM25-only behavior is unchanged when no provider is configured or semantic=false. Embed failures are logged and fall through to BM25 — semantic search is purely additive, never a regression risk. Existing docs without an embedding won't appear in KNN hits but still appear in BM25. Closes #17. Completes the query-side half of #16 (backfill tooling for existing memories is a separate issue).
Co-Authored-By: Claude Opus 4.7 (1M context) --- cmd/solr-mem-server/knn.go | 32 ++++++ cmd/solr-mem-server/search_tool.go | 156 ++++++++++++++++++++++------- cmd/solr-mem-server/tools.go | 6 +- 3 files changed, 157 insertions(+), 37 deletions(-) create mode 100644 cmd/solr-mem-server/knn.go diff --git a/cmd/solr-mem-server/knn.go b/cmd/solr-mem-server/knn.go new file mode 100644 index 0000000..15a7812 --- /dev/null +++ b/cmd/solr-mem-server/knn.go @@ -0,0 +1,32 @@ +package main + +import ( + "context" + "fmt" + "strconv" + "strings" + + "github.com/arreyder/solr-mem/internal/solr" +) + +// knnQueryString formats a Solr KNN local-params query for the embedding field. +// Produces: {!knn f=embedding topK=N}[v1,v2,...] +func knnQueryString(vec []float32, topK int) string { + parts := make([]string, len(vec)) + for i, v := range vec { + parts[i] = strconv.FormatFloat(float64(v), 'f', -1, 32) + } + return fmt.Sprintf("{!knn f=embedding topK=%d}[%s]", topK, strings.Join(parts, ",")) +} + +// runKNN performs a KNN search against the embedding field. Filters from the +// caller (agent_id, tags, etc.) are passed through as fq so semantic hits +// respect the same scoping as BM25 hits. 
+func runKNN(ctx context.Context, client *solr.Client, vec []float32, filters []string, topK int) (*solr.QueryResponse, error) { + params := solr.QueryParams{ + Query: knnQueryString(vec, topK), + FilterQueries: filters, + Rows: topK, + } + return client.Query(ctx, params) +} diff --git a/cmd/solr-mem-server/search_tool.go b/cmd/solr-mem-server/search_tool.go index 2efee40..cc23c95 100644 --- a/cmd/solr-mem-server/search_tool.go +++ b/cmd/solr-mem-server/search_tool.go @@ -4,8 +4,10 @@ import ( "context" "encoding/json" "fmt" + "log" "strings" + "github.com/arreyder/solr-mem/internal/retrieval" "github.com/arreyder/solr-mem/internal/solr" ) @@ -15,42 +17,85 @@ func searchMemoriesTool(ctx context.Context, args map[string]any) (any, error) { return nil, fmt.Errorf("query is required") } - params := solr.QueryParams{ - Query: query, - Rows: getInt(args, "limit", 10), - Highlight: getBool(args, "highlight", true), - Facet: getBool(args, "facet", false), + rows := getInt(args, "limit", 10) + + // Build filter queries shared by both BM25 and KNN passes. This keeps + // semantic hits scoped to the same agent / session / tag set as keyword + // hits. + filters := buildSearchFilters(args) + + bm25Params := solr.QueryParams{ + Query: query, + Rows: rows, + Highlight: getBool(args, "highlight", true), + Facet: getBool(args, "facet", false), + FilterQueries: filters, + } + if bm25Params.Facet { + bm25Params.FacetFields = []string{"memory_type", "tags", "agent_id", "source"} + } + + bm25Resp, err := solrClient.Query(ctx, bm25Params) + if err != nil { + return nil, fmt.Errorf("search failed: %w", err) + } + + // Semantic (KNN) stream: only runs when the server has an embed provider + // configured and the caller hasn't opted out via semantic=false. 
+ semantic := getBool(args, "semantic", true) && embedProvider != nil + knnTopK := getInt(args, "knn_topk", rows*3) // fetch wider so fusion has something to cross-pollinate + + resp := bm25Resp + if semantic { + if fused, err := runHybrid(ctx, query, filters, bm25Resp, knnTopK, rows); err == nil && fused != nil { + resp = fused + } else if err != nil { + log.Printf("search: semantic fallback to BM25-only: %v", err) + } + } + + // Cap per session so one chatty session can't dominate results. + // Default 3; pass 0 to disable. + sessionCap := getInt(args, "session_cap", 3) + if sessionCap > 0 { + resp.Docs = diversifyBySession(resp.Docs, func(d map[string]any) string { + s, _ := d["session_id"].(string) + return s + }, sessionCap) } - // Build filter queries + return ToolOutput{ + Text: formatSearchResults(resp), + Structured: resp, + }, nil +} + +// buildSearchFilters translates the search_memories args into Solr fq clauses. +func buildSearchFilters(args map[string]any) []string { + var filters []string if v := getString(args, "agent_id"); v != "" { - params.FilterQueries = append(params.FilterQueries, fmt.Sprintf("agent_id:%q", v)) + filters = append(filters, fmt.Sprintf("agent_id:%q", v)) } if v := getString(args, "memory_type"); v != "" { - params.FilterQueries = append(params.FilterQueries, fmt.Sprintf("memory_type:%q", v)) + filters = append(filters, fmt.Sprintf("memory_type:%q", v)) } if v := getString(args, "source"); v != "" { - params.FilterQueries = append(params.FilterQueries, fmt.Sprintf("source:%q", v)) + filters = append(filters, fmt.Sprintf("source:%q", v)) } if tags := getStringSlice(args, "tags"); len(tags) > 0 { for _, tag := range tags { - params.FilterQueries = append(params.FilterQueries, fmt.Sprintf("tags:%q", tag)) + filters = append(filters, fmt.Sprintf("tags:%q", tag)) } } if v := getString(args, "session_id"); v != "" { - params.FilterQueries = append(params.FilterQueries, fmt.Sprintf("session_id:%q", v)) + filters = append(filters, 
fmt.Sprintf("session_id:%q", v)) } if v := getString(args, "lifetime"); v != "" { - params.FilterQueries = append(params.FilterQueries, fmt.Sprintf("lifetime:%q", v)) + filters = append(filters, fmt.Sprintf("lifetime:%q", v)) } - - // Importance filter - impMin := getString(args, "importance_min") - if impMin != "" { - params.FilterQueries = append(params.FilterQueries, fmt.Sprintf("importance:[%s TO *]", impMin)) + if impMin := getString(args, "importance_min"); impMin != "" { + filters = append(filters, fmt.Sprintf("importance:[%s TO *]", impMin)) } - - // Date range filters from := getString(args, "from") to := getString(args, "to") if from != "" || to != "" { @@ -62,34 +107,73 @@ func searchMemoriesTool(ctx context.Context, args map[string]any) (any, error) { if to != "" { toVal = to } - params.FilterQueries = append(params.FilterQueries, fmt.Sprintf("created_at:[%s TO %s]", fromVal, toVal)) + filters = append(filters, fmt.Sprintf("created_at:[%s TO %s]", fromVal, toVal)) } + return filters +} - if params.Facet { - params.FacetFields = []string{"memory_type", "tags", "agent_id", "source"} +// runHybrid embeds the query, runs a KNN search alongside the BM25 hits +// already in bm25Resp, and fuses them with RRF. Returns a synthetic +// QueryResponse whose Docs are in fused order and whose Facets/Highlighting +// are inherited from the BM25 response. Returns (nil, nil) if the embed +// failed softly (KNN skipped) — caller stays on BM25. 
+func runHybrid(ctx context.Context, query string, filters []string, bm25Resp *solr.QueryResponse, knnTopK, finalRows int) (*solr.QueryResponse, error) { + vec, err := embedProvider.Embed(ctx, query) + if err != nil { + return nil, fmt.Errorf("embed query: %w", err) } - - resp, err := solrClient.Query(ctx, params) + knnResp, err := runKNN(ctx, solrClient, vec, filters, knnTopK) if err != nil { - return nil, fmt.Errorf("search failed: %w", err) + return nil, fmt.Errorf("knn query: %w", err) } - // Cap per session so one chatty session can't dominate results. - // Default 3; pass 0 to disable. - sessionCap := getInt(args, "session_cap", 3) - if sessionCap > 0 { - resp.Docs = diversifyBySession(resp.Docs, func(d map[string]any) string { - s, _ := d["session_id"].(string) - return s - }, sessionCap) + bm25IDs := idsFromDocs(bm25Resp.Docs) + knnIDs := idsFromDocs(knnResp.Docs) + + fused := retrieval.FuseIDs([]retrieval.Stream{ + {Name: "bm25", IDs: bm25IDs}, + {Name: "vec", IDs: knnIDs}, + }, retrieval.RRFWithTopK(finalRows)) + + // Hydrate: prefer the BM25 doc (has highlighting) over the KNN doc for + // any ID present in both. 
+ byID := make(map[string]map[string]any, len(bm25Resp.Docs)+len(knnResp.Docs)) + for _, d := range knnResp.Docs { + if id, ok := d["id"].(string); ok { + byID[id] = d + } + } + for _, d := range bm25Resp.Docs { + if id, ok := d["id"].(string); ok { + byID[id] = d + } } - return ToolOutput{ - Text: formatSearchResults(resp), - Structured: resp, + ordered := make([]map[string]any, 0, len(fused)) + for _, id := range fused { + if d, ok := byID[id]; ok { + ordered = append(ordered, d) + } + } + + return &solr.QueryResponse{ + NumFound: len(ordered), + Docs: ordered, + Highlighting: bm25Resp.Highlighting, + Facets: bm25Resp.Facets, }, nil } +func idsFromDocs(docs []map[string]any) []string { + out := make([]string, 0, len(docs)) + for _, d := range docs { + if id, ok := d["id"].(string); ok { + out = append(out, id) + } + } + return out +} + func formatSearchResults(resp *solr.QueryResponse) string { var sb strings.Builder sb.WriteString(fmt.Sprintf("Found %d memories.\n\n", resp.NumFound)) diff --git a/cmd/solr-mem-server/tools.go b/cmd/solr-mem-server/tools.go index bd57b31..9b87d11 100644 --- a/cmd/solr-mem-server/tools.go +++ b/cmd/solr-mem-server/tools.go @@ -61,7 +61,9 @@ func ToolSchemas() []ToolDefinition { **When to use**: Find relevant memories by content, tags, type, agent, or time range. Uses edismax with field boosting (content^3, title^2, tags^1.5) and recency boost. **Required**: query (search text) -**Optional**: agent_id, memory_type, tags, source, importance_min, from, to, limit, highlight, facet, session_id, lifetime, session_cap`, +**Optional**: agent_id, memory_type, tags, source, importance_min, from, to, limit, highlight, facet, session_id, lifetime, session_cap, semantic, knn_topk + +**Semantic search**: When the server has an embedding provider configured (OPENAI_API_KEY set), results combine BM25 keyword hits with KNN vector hits via Reciprocal Rank Fusion. Pass semantic=false to force BM25-only. 
knn_topk widens the KNN pool before fusion (default: 3× limit).`, InputSchema: NewObjectSchema(map[string]any{ "query": prop("string", "Full-text search query (required)"), "agent_id": prop("string", "Filter by agent ID"), @@ -77,6 +79,8 @@ func ToolSchemas() []ToolDefinition { "session_id": prop("string", "Filter by session ID"), "lifetime": prop("string", "Filter by lifetime (permanent, session, ephemeral, temporary)"), "session_cap": integerProp("Max hits per session_id after ranking (default 3, 0 = disable)", intPtr(0), intPtr(100)), + "semantic": prop("boolean", "Include KNN vector stream and fuse with BM25 via RRF (default: true when an embedding provider is configured)"), + "knn_topk": integerProp("KNN pool size before fusion (default: 3× limit)", intPtr(0), intPtr(500)), }, "query"), }, Handler: searchMemoriesTool, From 8280ea3dfca942b7463a38c9fab77c2a3c1986e8 Mon Sep 17 00:00:00 2001 From: arreyder Date: Sat, 18 Apr 2026 12:39:14 -0500 Subject: [PATCH 7/8] Backfill existing memories with embeddings New cmd/solr-mem-backfill binary that scans memories missing the embedding field and writes vectors back via atomic update. Needed so existing memories (written before this branch's vector support) show up in KNN hits after the PR merges; without it, semantic search silently ignores pre-existing docs. Flow: - Query "-_exists_:embedding" (paginated, default 50 per batch) - Embed title+content via the configured provider (OPENAI_API_KEY) - Atomic-update the doc with the new embedding field - Track a per-run "seen" set so failed embeds can't cause infinite loops (failed docs never get the field written, so they'd otherwise keep re-appearing in the query) Options: -batch-size, -concurrency (default 4 parallel embed calls), -dry-run, -force (re-embed existing), -max-docs, -pause-ms (rate headroom for bursty providers). Makefile: make backfill / make backfill-dry targets, BACKFILL_URL override. 
Tests: 7 unit tests using httptest as a fake Solr and a stub embed provider. Cover happy path, dry-run, max-docs cap, error resilience (proves the seen-set prevents infinite loops on persistent failures), and early termination. No live Solr or API key required. --- .gitignore | 1 + Makefile | 16 +- cmd/solr-mem-backfill/backfill.go | 240 ++++++++++++++++++++ cmd/solr-mem-backfill/backfill_test.go | 296 +++++++++++++++++++++++++ cmd/solr-mem-backfill/main.go | 62 ++++++ 5 files changed, 613 insertions(+), 2 deletions(-) create mode 100644 cmd/solr-mem-backfill/backfill.go create mode 100644 cmd/solr-mem-backfill/backfill_test.go create mode 100644 cmd/solr-mem-backfill/main.go diff --git a/.gitignore b/.gitignore index 21f8f37..c97762e 100644 --- a/.gitignore +++ b/.gitignore @@ -27,3 +27,4 @@ Thumbs.db solr-data/ /solr-mem-indexer /solr-mem-bench +/solr-mem-backfill diff --git a/Makefile b/Makefile index 11f7bd6..36b9e79 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: build build-indexer build-bench install install-indexer test tidy run dev up down logs reset docker-build docker-up config systemd-install systemd-uninstall indexer-install indexer-uninstall launchd-install launchd-install-server launchd-install-indexer launchd-uninstall skill-install skill-uninstall bench +.PHONY: build build-indexer build-bench build-backfill install install-indexer test tidy run dev up down logs reset docker-build docker-up config systemd-install systemd-uninstall indexer-install indexer-uninstall launchd-install launchd-install-server launchd-install-indexer launchd-uninstall skill-install skill-uninstall bench backfill backfill-dry # Go targets build: @@ -7,11 +7,23 @@ build: build-indexer: go build -o bin/solr-mem-indexer ./cmd/solr-mem-indexer -build-all: build build-indexer build-bench +build-all: build build-indexer build-bench build-backfill build-bench: go build -o bin/solr-mem-bench ./cmd/solr-mem-bench +build-backfill: + go build -o bin/solr-mem-backfill 
./cmd/solr-mem-backfill + +# Backfill embeddings on existing memories. Requires OPENAI_API_KEY. +# Use backfill-dry to preview without writing. +BACKFILL_URL ?= http://pax89.local:8983/solr/memories +backfill: build-backfill + ./bin/solr-mem-backfill -solr-url $(BACKFILL_URL) + +backfill-dry: build-backfill + ./bin/solr-mem-backfill -solr-url $(BACKFILL_URL) -dry-run + # Retrieval benchmark. Seeds the memories collection with namespaced bench-* # docs (safe to run against a live collection — only touches bench-* IDs) and # runs the shipped query set. Override BENCH_URL to point elsewhere. diff --git a/cmd/solr-mem-backfill/backfill.go b/cmd/solr-mem-backfill/backfill.go new file mode 100644 index 0000000..4e00631 --- /dev/null +++ b/cmd/solr-mem-backfill/backfill.go @@ -0,0 +1,240 @@ +// Package main is the solr-mem-backfill tool: it computes embeddings for +// existing memories that don't have one yet and writes them back via atomic +// update. Safe to run incrementally — it always queries "missing embedding" +// first, so a partial run can be resumed just by running it again. +package main + +import ( + "context" + "fmt" + "log" + "sync" + "time" + + "github.com/arreyder/solr-mem/internal/embed" + "github.com/arreyder/solr-mem/internal/solr" +) + +// Options controls a backfill run. +type Options struct { + BatchSize int // docs per Solr page (default 50) + Concurrency int // parallel embed calls (default 4) + DryRun bool // compute but don't write + MaxDocs int // cap total processed (0 = unlimited) + Force bool // re-embed docs even if they already have an embedding + Pause time.Duration // sleep between batches (rate-limit headroom) +} + +// Stats summarize a run. 
+type Stats struct { + Scanned int // docs returned from Solr + Skipped int // already had an embedding (and !Force) + Embedded int // docs we computed vectors for + Written int // docs we actually wrote back (0 when DryRun) + Errors int // embed failures; these get logged but don't abort +} + +// Run scans the memories collection for docs needing embeddings, computes +// them, and writes back. It returns when the backlog is empty, MaxDocs is +// hit, or ctx is canceled. +func Run(ctx context.Context, client *solr.Client, provider embed.Provider, opts Options) (Stats, error) { + if provider == nil { + return Stats{}, fmt.Errorf("backfill: no embedding provider configured") + } + if opts.BatchSize <= 0 { + opts.BatchSize = 50 + } + if opts.Concurrency <= 0 { + opts.Concurrency = 4 + } + + // Track IDs we've already attempted this run. Without this, a doc that + // fails to embed (and therefore never gets its field written) would keep + // showing up in every "-_exists_:embedding" query and loop forever. + seen := make(map[string]bool) + + var stats Stats + for { + if opts.MaxDocs > 0 && stats.Embedded >= opts.MaxDocs { + break + } + if ctx.Err() != nil { + return stats, ctx.Err() + } + + batch, err := fetchBatch(ctx, client, opts.Force, opts.BatchSize) + if err != nil { + return stats, fmt.Errorf("fetch batch: %w", err) + } + if len(batch) == 0 { + break + } + stats.Scanned += len(batch) + + // Build the work list: skip docs already attempted this run, and + // skip docs with existing embeddings unless Force. + work := make([]memoryDoc, 0, len(batch)) + for _, d := range batch { + if seen[d.ID] { + continue + } + seen[d.ID] = true + if !opts.Force && d.HasEmbedding { + stats.Skipped++ + continue + } + work = append(work, d) + } + if opts.MaxDocs > 0 && stats.Embedded+len(work) > opts.MaxDocs { + work = work[:opts.MaxDocs-stats.Embedded] + } + + // If every doc in this page was already seen, we've exhausted the + // backlog that this run can make progress on. 
+ if len(work) == 0 { + break + } + + updates := embedBatch(ctx, provider, work, opts.Concurrency, &stats) + + if len(updates) > 0 && !opts.DryRun { + if err := client.BulkUpdate(ctx, updates); err != nil { + return stats, fmt.Errorf("bulk update: %w", err) + } + stats.Written += len(updates) + } + + log.Printf("backfill: scanned=%d skipped=%d embedded=%d written=%d errors=%d", + stats.Scanned, stats.Skipped, stats.Embedded, stats.Written, stats.Errors) + + if opts.Pause > 0 { + select { + case <-ctx.Done(): + return stats, ctx.Err() + case <-time.After(opts.Pause): + } + } + } + return stats, nil +} + +type memoryDoc struct { + ID string + Title string + Content string + HasEmbedding bool +} + +// fetchBatch returns up to batchSize memory docs. When !force, scopes to +// docs missing the embedding field so subsequent calls make progress as we +// write back. +func fetchBatch(ctx context.Context, client *solr.Client, force bool, batchSize int) ([]memoryDoc, error) { + params := solr.QueryParams{ + Query: "*:*", + Rows: batchSize, + Fields: []string{"id", "title", "content", "embedding"}, + Sort: "created_at asc", + } + if !force { + // "-_exists_:embedding" is Solr's canonical missing-field filter + // and works for any field type, including DenseVectorField. + params.FilterQueries = append(params.FilterQueries, "-_exists_:embedding") + } + resp, err := client.Query(ctx, params) + if err != nil { + return nil, err + } + out := make([]memoryDoc, 0, len(resp.Docs)) + for _, d := range resp.Docs { + id, _ := d["id"].(string) + title, _ := d["title"].(string) + content, _ := d["content"].(string) + has := false + if v, ok := d["embedding"]; ok { + if arr, ok := v.([]any); ok && len(arr) > 0 { + has = true + } + } + out = append(out, memoryDoc{ID: id, Title: title, Content: content, HasEmbedding: has}) + } + return out, nil +} + +// embedBatch runs opts.Concurrency embed calls in parallel and returns the +// atomic-update payloads for docs that succeeded. 
Failures are logged and +// counted but don't stop the run. +func embedBatch(ctx context.Context, provider embed.Provider, docs []memoryDoc, concurrency int, stats *Stats) []map[string]any { + if len(docs) == 0 { + return nil + } + + type result struct { + ID string + Embedding []float32 + Err error + } + + in := make(chan memoryDoc) + out := make(chan result) + + var wg sync.WaitGroup + for i := 0; i < concurrency; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for d := range in { + text := buildEmbedText(d.Title, d.Content) + if text == "" { + out <- result{ID: d.ID, Err: fmt.Errorf("empty text")} + continue + } + cctx, cancel := context.WithTimeout(ctx, 30*time.Second) + v, err := provider.Embed(cctx, text) + cancel() + out <- result{ID: d.ID, Embedding: v, Err: err} + } + }() + } + + go func() { + for _, d := range docs { + select { + case <-ctx.Done(): + break + case in <- d: + } + } + close(in) + wg.Wait() + close(out) + }() + + updates := make([]map[string]any, 0, len(docs)) + for r := range out { + if r.Err != nil { + stats.Errors++ + log.Printf("backfill: embed %s failed: %v", r.ID, r.Err) + continue + } + stats.Embedded++ + updates = append(updates, map[string]any{ + "id": r.ID, + "embedding": map[string]any{"set": r.Embedding}, + }) + } + return updates +} + +// buildEmbedText joins title and content with a blank line. Mirrors the +// shape used at write time in the server (embed_helper.go) so backfilled +// vectors live in the same space as live writes. 
+func buildEmbedText(title, content string) string { + t := title + if t != "" && content != "" { + return t + "\n\n" + content + } + if t != "" { + return t + } + return content +} diff --git a/cmd/solr-mem-backfill/backfill_test.go b/cmd/solr-mem-backfill/backfill_test.go new file mode 100644 index 0000000..b255767 --- /dev/null +++ b/cmd/solr-mem-backfill/backfill_test.go @@ -0,0 +1,296 @@ +package main + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + + "github.com/arreyder/solr-mem/internal/solr" +) + +// fakeProvider returns a deterministic 3-dim vector for any input, or errors +// when Fail is non-nil. +type fakeProvider struct { + mu sync.Mutex + calls []string + Fail map[string]error // text -> err +} + +func (f *fakeProvider) Embed(ctx context.Context, text string) ([]float32, error) { + f.mu.Lock() + f.calls = append(f.calls, text) + f.mu.Unlock() + if err, ok := f.Fail[text]; ok { + return nil, err + } + return []float32{0.1, 0.2, 0.3}, nil +} +func (f *fakeProvider) Dim() int { return 3 } +func (f *fakeProvider) Name() string { return "fake" } + +// fakeSolr is a minimal httptest.Server standing in for the memories +// collection. It supports /select (returning the configured docs, shrunk +// after each update call) and /update (accepting atomic updates and +// removing the target doc from the "missing embedding" set). 
+type fakeSolr struct { + mu sync.Mutex + docs []map[string]any // mutable "collection" + queries int + updates [][]map[string]any // captured update payloads per request +} + +func newFakeSolr(initial []map[string]any) *fakeSolr { + return &fakeSolr{docs: initial} +} + +func (s *fakeSolr) handler() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.HasSuffix(r.URL.Path, "/select"): + s.handleSelect(w, r) + case strings.Contains(r.URL.Path, "/update"): + s.handleUpdate(w, r) + default: + http.Error(w, "not found: "+r.URL.Path, 404) + } + }) +} + +func (s *fakeSolr) handleSelect(w http.ResponseWriter, r *http.Request) { + s.mu.Lock() + defer s.mu.Unlock() + s.queries++ + + fqs := r.URL.Query()["fq"] + missingOnly := false + for _, fq := range fqs { + if fq == "-_exists_:embedding" { + missingOnly = true + break + } + } + + var out []map[string]any + for _, d := range s.docs { + if missingOnly { + if v, ok := d["embedding"]; ok && v != nil { + if arr, ok := v.([]any); ok && len(arr) > 0 { + continue + } + if _, ok := v.([]float32); ok { + continue + } + } + } + out = append(out, d) + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "response": map[string]any{ + "numFound": len(out), + "start": 0, + "docs": out, + }, + }) +} + +func (s *fakeSolr) handleUpdate(w http.ResponseWriter, r *http.Request) { + s.mu.Lock() + defer s.mu.Unlock() + body, _ := io.ReadAll(r.Body) + var payload []map[string]any + if err := json.Unmarshal(body, &payload); err != nil { + http.Error(w, err.Error(), 400) + return + } + s.updates = append(s.updates, payload) + + // Apply atomic updates to our in-memory corpus. 
+ for _, upd := range payload { + id, _ := upd["id"].(string) + for i, d := range s.docs { + if d["id"] != id { + continue + } + for k, v := range upd { + if k == "id" { + continue + } + if setter, ok := v.(map[string]any); ok { + if val, has := setter["set"]; has { + // Normalize to []any for consistent downstream checks. + if vec, ok := val.([]float32); ok { + arr := make([]any, len(vec)) + for j, f := range vec { + arr[j] = float64(f) + } + val = arr + } + s.docs[i][k] = val + } + } + } + } + } + _, _ = w.Write([]byte(`{"responseHeader":{"status":0}}`)) +} + +func TestRunEmbedsMissingDocs(t *testing.T) { + fs := newFakeSolr([]map[string]any{ + {"id": "a", "title": "A", "content": "alpha"}, + {"id": "b", "title": "B", "content": "beta"}, + }) + srv := httptest.NewServer(fs.handler()) + defer srv.Close() + + client := solr.NewClient(srv.URL) + provider := &fakeProvider{} + + stats, err := Run(context.Background(), client, provider, Options{BatchSize: 10, Concurrency: 2}) + if err != nil { + t.Fatalf("run: %v", err) + } + if stats.Embedded != 2 { + t.Errorf("expected 2 embedded, got %d", stats.Embedded) + } + if stats.Written != 2 { + t.Errorf("expected 2 written, got %d", stats.Written) + } + if stats.Errors != 0 { + t.Errorf("expected 0 errors, got %d", stats.Errors) + } + if len(provider.calls) != 2 { + t.Errorf("expected 2 provider calls, got %d", len(provider.calls)) + } +} + +func TestRunDryRunSkipsWrite(t *testing.T) { + fs := newFakeSolr([]map[string]any{ + {"id": "a", "title": "A", "content": "alpha"}, + }) + srv := httptest.NewServer(fs.handler()) + defer srv.Close() + + client := solr.NewClient(srv.URL) + provider := &fakeProvider{} + + stats, err := Run(context.Background(), client, provider, Options{DryRun: true}) + if err != nil { + t.Fatalf("run: %v", err) + } + if stats.Embedded != 1 { + t.Errorf("expected 1 embedded, got %d", stats.Embedded) + } + if stats.Written != 0 { + t.Errorf("dry-run should write nothing, got %d", stats.Written) + } + if 
len(fs.updates) != 0 { + t.Errorf("dry-run should not POST updates, got %d payloads", len(fs.updates)) + } +} + +func TestRunRespectsMaxDocs(t *testing.T) { + var docs []map[string]any + for i := 0; i < 10; i++ { + docs = append(docs, map[string]any{ + "id": fmt.Sprintf("d-%d", i), + "title": "t", + "content": "c", + }) + } + fs := newFakeSolr(docs) + srv := httptest.NewServer(fs.handler()) + defer srv.Close() + + client := solr.NewClient(srv.URL) + provider := &fakeProvider{} + + stats, err := Run(context.Background(), client, provider, Options{BatchSize: 3, MaxDocs: 5}) + if err != nil { + t.Fatalf("run: %v", err) + } + if stats.Embedded != 5 { + t.Errorf("MaxDocs=5: expected 5 embedded, got %d", stats.Embedded) + } +} + +func TestRunContinuesThroughEmbedErrors(t *testing.T) { + fs := newFakeSolr([]map[string]any{ + {"id": "good", "title": "G", "content": "g"}, + {"id": "bad", "title": "B", "content": "b"}, + }) + srv := httptest.NewServer(fs.handler()) + defer srv.Close() + + client := solr.NewClient(srv.URL) + provider := &fakeProvider{ + Fail: map[string]error{"B\n\nb": fmt.Errorf("simulated failure")}, + } + + stats, err := Run(context.Background(), client, provider, Options{Concurrency: 1}) + if err != nil { + t.Fatalf("run: %v", err) + } + if stats.Errors != 1 { + t.Errorf("expected 1 error, got %d", stats.Errors) + } + if stats.Embedded != 1 { + t.Errorf("expected 1 successful embed, got %d", stats.Embedded) + } + if stats.Written != 1 { + t.Errorf("expected 1 write, got %d", stats.Written) + } +} + +func TestRunStopsWhenNoMissingRemain(t *testing.T) { + // After the first batch updates every doc, the next fetch should return + // zero and we terminate. 
+ fs := newFakeSolr([]map[string]any{ + {"id": "a", "title": "A", "content": "a"}, + }) + srv := httptest.NewServer(fs.handler()) + defer srv.Close() + + client := solr.NewClient(srv.URL) + provider := &fakeProvider{} + + _, err := Run(context.Background(), client, provider, Options{BatchSize: 10}) + if err != nil { + t.Fatalf("run: %v", err) + } + // Expect exactly 2 selects: first returns the doc, second returns empty. + if fs.queries != 2 { + t.Errorf("expected 2 queries (work + empty probe), got %d", fs.queries) + } +} + +func TestRunRefusesWithoutProvider(t *testing.T) { + fs := newFakeSolr(nil) + srv := httptest.NewServer(fs.handler()) + defer srv.Close() + client := solr.NewClient(srv.URL) + + if _, err := Run(context.Background(), client, nil, Options{}); err == nil { + t.Error("expected error when provider is nil") + } +} + +func TestBuildEmbedText(t *testing.T) { + cases := []struct{ title, content, want string }{ + {"t", "c", "t\n\nc"}, + {"", "c", "c"}, + {"t", "", "t"}, + {"", "", ""}, + } + for _, c := range cases { + if got := buildEmbedText(c.title, c.content); got != c.want { + t.Errorf("buildEmbedText(%q,%q)=%q, want %q", c.title, c.content, got, c.want) + } + } +} diff --git a/cmd/solr-mem-backfill/main.go b/cmd/solr-mem-backfill/main.go new file mode 100644 index 0000000..a8468e1 --- /dev/null +++ b/cmd/solr-mem-backfill/main.go @@ -0,0 +1,62 @@ +package main + +import ( + "context" + "flag" + "log" + "os" + "os/signal" + "syscall" + "time" + + "github.com/arreyder/solr-mem/internal/embed" + "github.com/arreyder/solr-mem/internal/solr" +) + +func main() { + var ( + solrURL = flag.String("solr-url", getenv("SOLR_URL", "http://localhost:8983/solr/memories"), "Memories collection URL") + batchSize = flag.Int("batch-size", 50, "Docs per Solr page") + concurrency = flag.Int("concurrency", 4, "Parallel embed calls") + dryRun = flag.Bool("dry-run", false, "Compute embeddings but do not write") + maxDocs = flag.Int("max-docs", 0, "Cap total docs 
embedded (0 = unlimited)") + force = flag.Bool("force", false, "Re-embed docs that already have an embedding") + pauseMS = flag.Int("pause-ms", 0, "Sleep between batches in milliseconds") + ) + flag.Parse() + + provider, err := embed.FromEnv() + if err != nil { + log.Fatalf("embed provider init: %v", err) + } + if provider == nil { + log.Fatalf("no embedding provider configured — set OPENAI_API_KEY") + } + log.Printf("backfill: provider=%s dim=%d", provider.Name(), provider.Dim()) + + client := solr.NewClient(*solrURL) + + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer cancel() + + stats, err := Run(ctx, client, provider, Options{ + BatchSize: *batchSize, + Concurrency: *concurrency, + DryRun: *dryRun, + MaxDocs: *maxDocs, + Force: *force, + Pause: time.Duration(*pauseMS) * time.Millisecond, + }) + if err != nil { + log.Fatalf("backfill: %v", err) + } + log.Printf("backfill done: scanned=%d skipped=%d embedded=%d written=%d errors=%d", + stats.Scanned, stats.Skipped, stats.Embedded, stats.Written, stats.Errors) +} + +func getenv(k, def string) string { + if v := os.Getenv(k); v != "" { + return v + } + return def +} From 255750ddf252e02a04cadd20db33fc103846ceb3 Mon Sep 17 00:00:00 2001 From: arreyder Date: Sat, 18 Apr 2026 13:38:05 -0500 Subject: [PATCH 8/8] Add Ollama embedding provider MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit New provider hitting Ollama's /api/embeddings endpoint — local, free, runs well on Apple Silicon. Discovers its output dimension from the first successful call rather than hardcoding (different models produce different dims: nomic-embed-text=768, mxbai-embed-large=1024, all-minilm=384). FromEnv now detects OLLAMA_EMBEDDING_URL first, then OPENAI_API_KEY. Local/free wins over remote/paid when both are configured so an accidentally-present OpenAI key doesn't silently incur charges. 
Server startup log distinguishes known-dim (OpenAI) from inferred-dim (Ollama) cases. Schema comment now lists the common dims per provider so users know which vectorDimension to set before deploying — this is the one static choice in the whole pipeline. 8 new unit tests (httptest-based): happy path, model override, trailing-slash normalization, empty input, error status, the quirky-200-with-empty-embedding case (Ollama returns this when a model isn't pulled), and FromEnv precedence. --- cmd/solr-mem-server/main.go | 14 ++- internal/embed/ollama.go | 98 +++++++++++++++++++++ internal/embed/ollama_test.go | 159 ++++++++++++++++++++++++++++++++++ internal/embed/openai_test.go | 2 + internal/embed/provider.go | 30 +++++-- solr/managed-schema.xml | 17 ++-- 6 files changed, 306 insertions(+), 14 deletions(-) create mode 100644 internal/embed/ollama.go create mode 100644 internal/embed/ollama_test.go diff --git a/cmd/solr-mem-server/main.go b/cmd/solr-mem-server/main.go index abd0b19..13942cf 100644 --- a/cmd/solr-mem-server/main.go +++ b/cmd/solr-mem-server/main.go @@ -30,15 +30,21 @@ func main() { } codeClient = solr.NewClient(codeURL) - // Optional embedding provider. If no API key is set, embedProvider stays - // nil and the server runs in BM25-only mode. + // Optional embedding provider. If neither OLLAMA_EMBEDDING_URL nor + // OPENAI_API_KEY is set, embedProvider stays nil and the server runs in + // BM25-only mode. if p, err := embed.FromEnv(); err != nil { log.Printf("embedding provider init failed: %v (continuing BM25-only)", err) } else if p != nil { embedProvider = p - log.Printf("embedding provider: %s (dim=%d)", p.Name(), p.Dim()) + if p.Dim() > 0 { + log.Printf("embedding provider: %s (dim=%d)", p.Name(), p.Dim()) + } else { + // Ollama's dim is learned on the first successful call. 
// Ollama provides embeddings via a local (or LAN) Ollama server's
// /api/embeddings endpoint. The output dimension depends on the model, so it
// is discovered from the first successful call rather than hardcoded.
//
// Embed may be called from several goroutines at once, so the learned
// dimension is kept in an atomic value. An Ollama must not be copied after
// construction; always use the pointer returned by NewOllama.
type Ollama struct {
	baseURL string
	model   string
	dim     atomic.Int64 // 0 until the first successful Embed call
	client  *http.Client
}

// DefaultOllamaModel is a small, fast general-purpose embedding model.
// 768 dim, ~137 MB on disk.
const DefaultOllamaModel = "nomic-embed-text"

// ollamaDrainLimit bounds how many unread response bytes Embed discards
// before closing the body.
const ollamaDrainLimit = 4 << 10

// NewOllama constructs a client. baseURL should be the Ollama server root
// (e.g. "http://localhost:11434"); trailing slashes are trimmed so the
// endpoint path is not doubled. If model is empty, DefaultOllamaModel is
// used.
func NewOllama(baseURL, model string) *Ollama {
	if model == "" {
		model = DefaultOllamaModel
	}
	return &Ollama{
		baseURL: strings.TrimRight(baseURL, "/"),
		model:   model,
		client:  &http.Client{Timeout: 60 * time.Second},
	}
}

// Name implements Provider. It returns "ollama:" followed by the model name.
func (o *Ollama) Name() string { return "ollama:" + o.model }

// Dim returns 0 until the first successful call discovers the true
// dimension. Callers using Dim() to preconfigure the Solr schema must
// make at least one embed call first (or hardcode the dim from docs).
// It is safe to call concurrently with Embed.
func (o *Ollama) Dim() int { return int(o.dim.Load()) }

// Embed implements Provider. It POSTs the model name and text to
// <baseURL>/api/embeddings and returns the "embedding" array from the JSON
// response. On success it records the length of that array so Dim reports
// the dimension actually produced by the server.
//
// An error is returned for empty text, a request that cannot be built or
// sent, a non-2xx status (the message includes up to 512 bytes of the
// response body), an undecodable response, or a response whose embedding is
// empty.
func (o *Ollama) Embed(ctx context.Context, text string) ([]float32, error) {
	if text == "" {
		return nil, fmt.Errorf("embed: empty text")
	}

	body, err := json.Marshal(map[string]any{
		"model":  o.model,
		"prompt": text,
	})
	if err != nil {
		return nil, fmt.Errorf("embed: marshal: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, o.baseURL+"/api/embeddings", bytes.NewReader(body))
	if err != nil {
		return nil, fmt.Errorf("embed: build request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := o.client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("embed: http: %w", err)
	}
	defer func() {
		// Discard a bounded amount of any unread body before closing so the
		// transport can reuse the connection.
		_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, ollamaDrainLimit))
		_ = resp.Body.Close()
	}()

	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		// Best-effort snippet for diagnostics; a read error just yields a
		// shorter message.
		snippet, _ := io.ReadAll(io.LimitReader(resp.Body, 512))
		return nil, fmt.Errorf("embed: status %d: %s", resp.StatusCode, string(snippet))
	}

	var out struct {
		Embedding []float32 `json:"embedding"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		return nil, fmt.Errorf("embed: decode: %w", err)
	}
	if len(out.Embedding) == 0 {
		// Ollama returns 200 with empty embedding on some error cases
		// (e.g. unknown model). Treat as an error.
		return nil, fmt.Errorf("embed: empty embedding returned (model %q may not be pulled)", o.model)
	}

	// Atomic store: concurrent Embed calls and Dim readers must not race.
	o.dim.Store(int64(len(out.Embedding)))
	return out.Embedding, nil
}
+ if p.Dim() != 4 { + t.Errorf("dim should update after first call, got %d", p.Dim()) + } +} + +func TestOllamaCustomModel(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var body map[string]any + _ = json.NewDecoder(r.Body).Decode(&body) + if body["model"] != "mxbai-embed-large" { + t.Errorf("expected mxbai-embed-large, got %v", body["model"]) + } + _ = json.NewEncoder(w).Encode(map[string]any{"embedding": []float32{1.0}}) + })) + defer server.Close() + + p := NewOllama(server.URL, "mxbai-embed-large") + if _, err := p.Embed(context.Background(), "x"); err != nil { + t.Fatalf("unexpected err: %v", err) + } + if !strings.Contains(p.Name(), "mxbai-embed-large") { + t.Errorf("name should include model: %s", p.Name()) + } +} + +func TestOllamaTrailingSlashInURL(t *testing.T) { + // NewOllama should trim trailing slash so we don't double it. + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/embeddings" { + t.Errorf("unexpected path: %s", r.URL.Path) + } + _ = json.NewEncoder(w).Encode(map[string]any{"embedding": []float32{0.0}}) + })) + defer server.Close() + + p := NewOllama(server.URL+"/", "") + if _, err := p.Embed(context.Background(), "x"); err != nil { + t.Errorf("unexpected err: %v", err) + } +} + +func TestOllamaEmptyInput(t *testing.T) { + p := NewOllama("http://irrelevant", "") + if _, err := p.Embed(context.Background(), ""); err == nil { + t.Error("expected error for empty text") + } +} + +func TestOllamaErrorStatus(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "model not found", http.StatusNotFound) + })) + defer server.Close() + + p := NewOllama(server.URL, "") + _, err := p.Embed(context.Background(), "x") + if err == nil { + t.Fatal("expected error on 404") + } + if !strings.Contains(err.Error(), "404") { + t.Errorf("error should include 
status, got: %v", err) + } +} + +func TestOllamaEmptyEmbedding(t *testing.T) { + // Ollama returns 200 with an empty embedding when the model isn't pulled. + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(map[string]any{"embedding": []float32{}}) + })) + defer server.Close() + + p := NewOllama(server.URL, "") + _, err := p.Embed(context.Background(), "x") + if err == nil { + t.Fatal("expected error on empty embedding") + } + if !strings.Contains(err.Error(), "empty embedding") { + t.Errorf("error should name the failure mode, got: %v", err) + } +} + +func TestFromEnvPrefersOllama(t *testing.T) { + // When both Ollama and OpenAI are configured, Ollama wins (local/free beats remote/paid). + t.Setenv("OLLAMA_EMBEDDING_URL", "http://mac.local:11434") + t.Setenv("OLLAMA_EMBEDDING_MODEL", "mxbai-embed-large") + t.Setenv("OPENAI_API_KEY", "sk-test") + + p, err := FromEnv() + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + if p == nil { + t.Fatal("expected a provider") + } + if !strings.HasPrefix(p.Name(), "ollama:") { + t.Errorf("expected ollama provider to win, got %s", p.Name()) + } + if !strings.Contains(p.Name(), "mxbai-embed-large") { + t.Errorf("custom model should be reflected: %s", p.Name()) + } +} + +func TestFromEnvOllamaAloneUsesDefault(t *testing.T) { + t.Setenv("OLLAMA_EMBEDDING_URL", "http://localhost:11434") + t.Setenv("OLLAMA_EMBEDDING_MODEL", "") + t.Setenv("OPENAI_API_KEY", "") + + p, err := FromEnv() + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + if p == nil { + t.Fatal("expected a provider") + } + if !strings.Contains(p.Name(), DefaultOllamaModel) { + t.Errorf("should default to %s, got %s", DefaultOllamaModel, p.Name()) + } +} diff --git a/internal/embed/openai_test.go b/internal/embed/openai_test.go index df2889b..ab4a164 100644 --- a/internal/embed/openai_test.go +++ b/internal/embed/openai_test.go @@ -109,6 +109,7 @@ func TestOpenAINameAndDim(t 
*testing.T) { } func TestFromEnvNoProvider(t *testing.T) { + t.Setenv("OLLAMA_EMBEDDING_URL", "") t.Setenv("OPENAI_API_KEY", "") p, err := FromEnv() if err != nil { @@ -120,6 +121,7 @@ func TestFromEnvNoProvider(t *testing.T) { } func TestFromEnvOpenAI(t *testing.T) { + t.Setenv("OLLAMA_EMBEDDING_URL", "") t.Setenv("OPENAI_API_KEY", "sk-test") t.Setenv("OPENAI_EMBEDDING_MODEL", "custom-model") p, err := FromEnv() diff --git a/internal/embed/provider.go b/internal/embed/provider.go index 27ff7ef..c02ea11 100644 --- a/internal/embed/provider.go +++ b/internal/embed/provider.go @@ -1,13 +1,24 @@ // Package embed turns text into dense vector embeddings for semantic search. // -// Providers are pluggable. Users configure one via environment variables: +// Providers are pluggable. Users configure one via environment variables. +// FromEnv checks them in this order (first match wins): // -// OPENAI_API_KEY -> OpenAI text-embedding-3-small (1536 dim by default) -// OPENAI_EMBEDDING_MODEL -> override model (optional) +// 1. Ollama — local / LAN, free: +// OLLAMA_EMBEDDING_URL e.g. http://localhost:11434 +// OLLAMA_EMBEDDING_MODEL e.g. nomic-embed-text (default) +// +// 2. OpenAI — managed, paid: +// OPENAI_API_KEY +// OPENAI_EMBEDDING_MODEL e.g. text-embedding-3-small (default) // // If no provider is configured, FromEnv returns (nil, nil) and the server // runs in BM25-only mode. Callers MUST handle a nil provider as a signal to // skip embedding rather than erroring. +// +// NOTE: The Solr schema's vectorDimension is static. Ollama's nomic-embed-text +// is 768, mxbai-embed-large is 1024, OpenAI text-embedding-3-small is 1536. +// When switching providers, update solr/managed-schema.xml's vectorDimension +// to match the chosen model and reload the configset before running. package embed import ( @@ -20,10 +31,12 @@ import ( // should be safe for concurrent use. type Provider interface { // Embed returns the embedding for text. 
The returned slice length must - // equal Dim() and must be stable across calls on the same provider. + // equal Dim() (after the first successful call) and must be stable + // across calls on the same provider. Embed(ctx context.Context, text string) ([]float32, error) - // Dim is the output dimension. Used to validate schema configuration. + // Dim is the output dimension. Some providers learn this on the first + // successful call (returning 0 until then); others know it upfront. Dim() int // Name is a short identifier for logging / metrics. @@ -33,7 +46,14 @@ type Provider interface { // FromEnv constructs a provider from environment variables. Returns (nil, nil) // if no provider is configured — callers should treat nil as "embeddings // disabled" rather than an error. +// +// Ollama is checked first so a local server preempts an accidentally-present +// OpenAI key (free/private beats paid/remote by default). func FromEnv() (Provider, error) { + if url := strings.TrimSpace(os.Getenv("OLLAMA_EMBEDDING_URL")); url != "" { + model := strings.TrimSpace(os.Getenv("OLLAMA_EMBEDDING_MODEL")) + return NewOllama(url, model), nil + } if key := strings.TrimSpace(os.Getenv("OPENAI_API_KEY")); key != "" { model := strings.TrimSpace(os.Getenv("OPENAI_EMBEDDING_MODEL")) return NewOpenAI(key, model), nil diff --git a/solr/managed-schema.xml b/solr/managed-schema.xml index 580e254..fb706b6 100644 --- a/solr/managed-schema.xml +++ b/solr/managed-schema.xml @@ -9,11 +9,18 @@ - +