diff --git a/internal/agent/agent.go b/internal/agent/agent.go index 2c80d718..eef5fe41 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -1105,6 +1105,18 @@ func (a *Agent) GetExtensionToolCount() int { return len(a.extraTools) } +// GetExtraTools returns the agent's current extra tools (e.g. +// extension-registered tools). The returned slice is a copy so callers can +// snapshot and later restore it via SetExtraTools. +func (a *Agent) GetExtraTools() []fantasy.AgentTool { + if len(a.extraTools) == 0 { + return nil + } + out := make([]fantasy.AgentTool, len(a.extraTools)) + copy(out, a.extraTools) + return out +} + // SetExtraTools replaces the agent's extra tools (e.g. extension-registered // tools) and rebuilds the internal agent with the updated tool list. The // model, system prompt, and all other configuration are preserved. diff --git a/internal/config/config.go b/internal/config/config.go index fabcc021..2fdf17f6 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -6,6 +6,7 @@ import ( "os" "path/filepath" "strings" + "sync" "github.com/spf13/viper" "gopkg.in/yaml.v3" @@ -554,7 +555,7 @@ func FilepathOr[T any](key string, value *T) error { absPath = filepath.Join(home, absPath[2:]) } if !filepath.IsAbs(absPath) { - base := configPath + base := GetConfigPath() if base == "" { fmt.Fprintf(os.Stderr, "unable to build relative path to config.") os.Exit(1) @@ -581,11 +582,24 @@ func FilepathOr[T any](key string, value *T) error { return nil } -var configPath string +var ( + configPathMu sync.RWMutex + configPath string +) // SetConfigPath sets the configuration file path for resolving relative paths // in configuration values. This should be called when the configuration file -// location is known. +// location is known. It is safe for concurrent use. 
func SetConfigPath(path string) { + configPathMu.Lock() + defer configPathMu.Unlock() configPath = path } + +// GetConfigPath returns the configuration file path previously set via +// SetConfigPath. It is safe for concurrent use. +func GetConfigPath() string { + configPathMu.RLock() + defer configPathMu.RUnlock() + return configPath +} diff --git a/internal/config/configpath_test.go b/internal/config/configpath_test.go new file mode 100644 index 00000000..c9899a64 --- /dev/null +++ b/internal/config/configpath_test.go @@ -0,0 +1,33 @@ +package config + +import ( + "sync" + "testing" +) + +// TestConfigPathConcurrentAccess exercises the mutex guarding the package-level +// configPath global. Run with -race to detect the data race that motivated the +// guard (concurrent kit.New() calls discovering a .kit.yml). +func TestConfigPathConcurrentAccess(t *testing.T) { + t.Cleanup(func() { SetConfigPath("") }) + + const goroutines = 32 + var wg sync.WaitGroup + wg.Add(goroutines * 2) + for range goroutines { + go func() { + defer wg.Done() + SetConfigPath("/tmp/kit.yml") + }() + go func() { + defer wg.Done() + _ = GetConfigPath() + }() + } + wg.Wait() + + SetConfigPath("/tmp/final.yml") + if got := GetConfigPath(); got != "/tmp/final.yml" { + t.Fatalf("GetConfigPath() = %q, want /tmp/final.yml", got) + } +} diff --git a/internal/skills/skills.go b/internal/skills/skills.go index a653554a..61a135d9 100644 --- a/internal/skills/skills.go +++ b/internal/skills/skills.go @@ -12,8 +12,11 @@ package skills import ( "bytes" + "errors" "fmt" + "io/fs" "os" + "path" "path/filepath" "strings" @@ -55,7 +58,14 @@ func LoadSkill(path string) (*Skill, error) { abs = path } - skill := &Skill{Path: abs} + return parseSkill(data, path, abs) +} + +// parseSkill parses skill bytes that originated from srcPath (used for error +// messages and name derivation) and records storePath as the skill's Path. +// It is shared by the os-backed and fs.FS-backed loaders. 
+func parseSkill(data []byte, srcPath, storePath string) (*Skill, error) { + skill := &Skill{Path: storePath} content := string(data) @@ -70,7 +80,7 @@ func LoadSkill(path string) (*Skill, error) { body = strings.TrimPrefix(body, "\n") if err := yaml.Unmarshal([]byte(frontmatter), skill); err != nil { - return nil, fmt.Errorf("parsing frontmatter in %s: %w", path, err) + return nil, fmt.Errorf("parsing frontmatter in %s: %w", srcPath, err) } skill.Content = strings.TrimSpace(body) } else { @@ -83,12 +93,12 @@ func LoadSkill(path string) (*Skill, error) { // Fallback: derive name from filename if frontmatter didn't set one. if skill.Name == "" { - base := filepath.Base(path) + base := filepath.Base(srcPath) ext := filepath.Ext(base) skill.Name = strings.TrimSuffix(base, ext) // Convert SKILL → directory name for SKILL.md files. if strings.EqualFold(skill.Name, "SKILL") || strings.EqualFold(skill.Name, "skill") { - skill.Name = filepath.Base(filepath.Dir(path)) + skill.Name = filepath.Base(filepath.Dir(srcPath)) } } @@ -113,7 +123,7 @@ func LoadSkillsFromDir(dir string) ([]*Skill, error) { } var skills []*Skill - var errs []string + var errs []error for _, entry := range entries { full := filepath.Join(dir, entry.Name()) @@ -123,7 +133,7 @@ func LoadSkillsFromDir(dir string) ([]*Skill, error) { if ext == ".md" || ext == ".txt" { s, err := LoadSkill(full) if err != nil { - errs = append(errs, err.Error()) + errs = append(errs, err) continue } skills = append(skills, s) @@ -140,7 +150,7 @@ func LoadSkillsFromDir(dir string) ([]*Skill, error) { if !se.IsDir() && strings.EqualFold(se.Name(), "SKILL.md") { s, err := LoadSkill(filepath.Join(full, se.Name())) if err != nil { - errs = append(errs, err.Error()) + errs = append(errs, err) continue } skills = append(skills, s) @@ -150,7 +160,65 @@ func LoadSkillsFromDir(dir string) ([]*Skill, error) { } if len(errs) > 0 { - return skills, fmt.Errorf("some skills failed to load: %s", strings.Join(errs, "; ")) + return skills, 
fmt.Errorf("some skills failed to load: %w", errors.Join(errs...)) + } + return skills, nil +} + +// LoadSkillsFromFS is the fs.FS-typed counterpart of LoadSkillsFromDir. It +// walks fsys starting at root (which may be "." or a subdirectory), finds +// *.md and *.txt files plus SKILL.md files in subdirectories, parses YAML +// frontmatter + markdown body, and returns the loaded skills. +// +// Because fs.FS has no notion of an absolute on-disk path, each loaded skill's +// Path is set to its slash-separated path within fsys. Files that fail to +// parse are skipped and reported via the returned error. +func LoadSkillsFromFS(fsys fs.FS, root string) ([]*Skill, error) { + if fsys == nil { + return nil, nil + } + if root == "" { + root = "." + } + + var skills []*Skill + var errs []error + + walkErr := fs.WalkDir(fsys, root, func(p string, d fs.DirEntry, err error) error { + if err != nil { + return nil // skip unreadable entries rather than aborting the walk + } + if d.IsDir() { + return nil + } + name := d.Name() + ext := strings.ToLower(path.Ext(name)) + if ext != ".md" && ext != ".txt" { + return nil + } + // Top-level .md/.txt files, or SKILL.md anywhere. 
+ isTopLevel := path.Dir(p) == root + if !isTopLevel && !strings.EqualFold(name, "SKILL.md") { + return nil + } + data, readErr := fs.ReadFile(fsys, p) + if readErr != nil { + errs = append(errs, fmt.Errorf("reading skill %s: %w", p, readErr)) + return nil + } + s, parseErr := parseSkill(data, p, p) + if parseErr != nil { + errs = append(errs, parseErr) + return nil + } + skills = append(skills, s) + return nil + }) + if walkErr != nil { + return skills, fmt.Errorf("walking skills fs at %s: %w", root, walkErr) + } + if len(errs) > 0 { + return skills, fmt.Errorf("some skills failed to load: %w", errors.Join(errs...)) } return skills, nil } diff --git a/internal/skills/skills_fs_test.go b/internal/skills/skills_fs_test.go new file mode 100644 index 00000000..0cedea72 --- /dev/null +++ b/internal/skills/skills_fs_test.go @@ -0,0 +1,70 @@ +package skills + +import ( + "testing" + "testing/fstest" +) + +func TestLoadSkillsFromFS(t *testing.T) { + fsys := fstest.MapFS{ + "top.md": {Data: []byte("---\nname: top-skill\ndescription: a top level skill\n---\nbody here")}, + "notes.txt": {Data: []byte("plain text skill")}, + "deep/SKILL.md": {Data: []byte("---\nname: deep-skill\n---\ndeep body")}, + "deep/other.md": {Data: []byte("ignored non-SKILL nested md")}, + "ignore.json": {Data: []byte("{}")}, + } + + got, err := LoadSkillsFromFS(fsys, ".") + if err != nil { + t.Fatalf("LoadSkillsFromFS error: %v", err) + } + + byName := map[string]*Skill{} + for _, s := range got { + byName[s.Name] = s + } + + if _, ok := byName["top-skill"]; !ok { + t.Errorf("top-skill not loaded; got %v", names(got)) + } + if _, ok := byName["notes"]; !ok { + t.Errorf("notes (txt) not loaded; got %v", names(got)) + } + if _, ok := byName["deep-skill"]; !ok { + t.Errorf("deep SKILL.md not loaded; got %v", names(got)) + } + if _, ok := byName["other"]; ok { + t.Errorf("nested non-SKILL .md should be ignored; got %v", names(got)) + } + if len(got) != 3 { + t.Errorf("expected 3 skills, got %d: %v", 
len(got), names(got)) + } + + // Content/description parsed from frontmatter. + if s := byName["top-skill"]; s != nil { + if s.Description != "a top level skill" { + t.Errorf("description = %q", s.Description) + } + if s.Content != "body here" { + t.Errorf("content = %q", s.Content) + } + if s.Path != "top.md" { + t.Errorf("path = %q, want top.md", s.Path) + } + } +} + +func TestLoadSkillsFromFSNil(t *testing.T) { + got, err := LoadSkillsFromFS(nil, ".") + if err != nil || got != nil { + t.Fatalf("nil fs should yield (nil, nil), got (%v, %v)", got, err) + } +} + +func names(skills []*Skill) []string { + out := make([]string, 0, len(skills)) + for _, s := range skills { + out = append(out, s.Name) + } + return out +} diff --git a/pkg/kit/README.md b/pkg/kit/README.md index fcaafb4a..04459290 100644 --- a/pkg/kit/README.md +++ b/pkg/kit/README.md @@ -364,15 +364,28 @@ msg := kit.ConvertFromLLMMessage(lMsg) // LLMMessage → SDK Message - `Option` - Functional option (`func(*Options)`) for `NewAgent` - `Message` - Conversation message with typed content parts - `Tool` - Agent tool interface -- `TurnResult` - Full result from a prompt including usage stats +- `TurnResult` - Full result from a prompt including usage stats, captured + stream deltas (`Stream`), and any tool-driven halt (`FinalValue` / + `HaltedByTool`) +- `StreamEvent` / `StreamEventKind` - Ordered delta events captured in + `TurnResult.Stream` +- `ToolOutput` - Custom tool return value; set `Halt`/`FinalValue` to end the + agent loop and surface a typed result +- Provider-error sentinels - `ErrContextOverflow`, `ErrRateLimit`, `ErrAuth`, + `ErrProviderUnavailable`, `ErrInvalidRequest`; classify with + `ClassifyProviderError(err)` and match via `errors.Is` ### Key Methods - `New(ctx, opts)` - Create new Kit instance - `NewAgent(ctx, ...Option)` - Create a Kit via functional options (streaming on by default) - `Prompt(ctx, message)` - Send message and get response string -- `PromptResult(ctx, message)` - 
Send message and get full TurnResult +- `PromptResult(ctx, message)` - Send message and get full TurnResult (blocks + until end-of-turn; populates `TurnResult.Stream` in streaming mode) - `PromptWithOptions(ctx, message, opts)` - Prompt with per-call options + (system message, model, thinking level, provider credentials, extra tools) +- `PromptResultWithOptions(ctx, message, opts)` - Per-call options variant that + returns the full TurnResult - `Steer(ctx, instruction)` - System-level steering - `FollowUp(ctx, text)` - Continue without new user input - `SetModel(ctx, model)` - Switch model at runtime @@ -384,7 +397,15 @@ msg := kit.ConvertFromLLMMessage(lMsg) // LLMMessage → SDK Message - `AddSkill(*Skill)` / `LoadAndAddSkill(path)` / `RemoveSkill(name)` / `SetSkills([])` - Manage skills at runtime - `AddContextFile(*ContextFile)` / `AddContextFileContent(path, content)` / `LoadAndAddContextFile(path)` / `RemoveContextFile(path)` / `SetContextFiles([])` - Manage AGENTS.md-style context files at runtime - `RefreshSystemPrompt()` - Re-apply the composed system prompt to the agent +- `NewTool[T]` / `NewParallelTool[T]` - Create a typed custom tool +- `NewRawTool(name, desc, schema, fn)` - Create a schema-driven tool when the + input shape isn't known at compile time (skill/MCP catalogs) +- `LoadSkillsFromFS(fsys, root)` - `fs.FS`-typed skill loader (embed.FS, + fstest.MapFS, per-tenant virtual filesystems) +- `CollapseBranch(fromID, toID, summary)` - Collapse a branch range into a + summary (works with any `SessionManager` via `AppendBranchSummary`) - `Close()` - Clean up resources +- `CloseContext(ctx)` - Clean up resources with a shutdown deadline ### Options diff --git a/pkg/kit/adapter.go b/pkg/kit/adapter.go index 61add1c4..5c92542a 100644 --- a/pkg/kit/adapter.go +++ b/pkg/kit/adapter.go @@ -153,6 +153,11 @@ func (a *treeManagerAdapter) GetContextEntryIDs() []string { return a.inner.GetContextEntryIDs() } +// AppendBranchSummary implements SessionManager. 
+func (a *treeManagerAdapter) AppendBranchSummary(fromID, summary string) (string, error) { + return a.inner.AppendBranchSummary(fromID, summary) +} + // Close implements SessionManager. func (a *treeManagerAdapter) Close() error { return a.inner.Close() diff --git a/pkg/kit/compaction.go b/pkg/kit/compaction.go index 60819f9d..f3672cad 100644 --- a/pkg/kit/compaction.go +++ b/pkg/kit/compaction.go @@ -112,8 +112,20 @@ func (m *Kit) Compact(ctx context.Context, opts *CompactionOptions, customInstru } // compactInternal is the shared compaction implementation. The isAutomatic -// flag distinguishes auto-triggered compaction from manual /compact. +// flag distinguishes user-triggered from auto-compaction for hooks/events. +// On failure it emits a CompactionEvent carrying the error so embedders can +// observe the failure path symmetrically with the success path. func (m *Kit) compactInternal(ctx context.Context, opts *CompactionOptions, customInstructions string, isAutomatic bool) (*CompactionResult, error) { + result, err := m.compactImpl(ctx, opts, customInstructions, isAutomatic) + if err != nil { + m.events.emit(CompactionEvent{Err: err}) + } + return result, err +} + +// compactImpl performs the actual compaction work. On success it emits a +// CompactionEvent via persistAndEmitCompaction. +func (m *Kit) compactImpl(ctx context.Context, opts *CompactionOptions, customInstructions string, isAutomatic bool) (*CompactionResult, error) { if opts == nil { if m.compactionOpts != nil { opts = m.compactionOpts diff --git a/pkg/kit/errors.go b/pkg/kit/errors.go new file mode 100644 index 00000000..5e1ffb42 --- /dev/null +++ b/pkg/kit/errors.go @@ -0,0 +1,113 @@ +package kit + +import ( + "errors" + "strings" +) + +// Provider-error sentinels. Provider and turn execution paths wrap these via +// fmt.Errorf("%w: %s", …) so embedders can classify failures with errors.Is +// instead of brittle string matching. 
Use [ClassifyProviderError] to map an +// arbitrary provider error to one of these sentinels. +var ( + // ErrContextOverflow indicates the request exceeded the model's maximum + // context window. Embedders typically respond by compacting and retrying. + ErrContextOverflow = errors.New("context window exceeded") + + // ErrRateLimit indicates the provider throttled the request. Embedders + // typically respond by backing off and retrying. + ErrRateLimit = errors.New("rate limited by provider") + + // ErrAuth indicates a credential / authorization failure. + ErrAuth = errors.New("provider authentication failed") + + // ErrProviderUnavailable indicates a transient provider/upstream failure + // (5xx, network error, timeout). + ErrProviderUnavailable = errors.New("provider unavailable") + + // ErrInvalidRequest indicates the request was structurally invalid and + // retrying will not help. + ErrInvalidRequest = errors.New("invalid request to provider") +) + +// ClassifyProviderError inspects err and returns it wrapped with the matching +// provider-error sentinel ([ErrContextOverflow], [ErrRateLimit], [ErrAuth], +// [ErrProviderUnavailable], or [ErrInvalidRequest]) when the underlying cause +// can be recognized. The returned error satisfies errors.Is against both the +// sentinel and the original cause, so the full chain stays inspectable. +// +// When err is nil it returns nil. When the cause cannot be classified the +// original err is returned unchanged so callers never lose information. +// +// Classification is heuristic: it first honors any sentinel already present in +// the chain (so double-classification is idempotent), then falls back to +// matching common provider status codes and phrases in the error text. +func ClassifyProviderError(err error) error { + if err == nil { + return nil + } + // Already classified — keep as-is so the call is idempotent. 
+ for _, sentinel := range []error{ + ErrContextOverflow, ErrRateLimit, ErrAuth, + ErrProviderUnavailable, ErrInvalidRequest, + } { + if errors.Is(err, sentinel) { + return err + } + } + + if sentinel := classifyProviderErrorText(err.Error()); sentinel != nil { + return wrapSentinel(sentinel, err) + } + return err +} + +// wrapSentinel returns an error that satisfies errors.Is(_, sentinel) while +// keeping the original cause inspectable via errors.Is. +func wrapSentinel(sentinel, cause error) error { + return &sentinelError{sentinel: sentinel, cause: cause} +} + +type sentinelError struct { + sentinel error + cause error +} + +func (e *sentinelError) Error() string { + return e.sentinel.Error() + ": " + e.cause.Error() +} + +// Unwrap returns both the sentinel and the cause so errors.Is matches the +// sentinel and the underlying error chain stays reachable. +func (e *sentinelError) Unwrap() []error { + return []error{e.sentinel, e.cause} +} + +// classifyProviderErrorText returns the sentinel matching common provider +// error phrasings, or nil if none match. 
+func classifyProviderErrorText(msg string) error { + m := strings.ToLower(msg) + switch { + case containsAny(m, "context_length_exceeded", "context window", "maximum context length", "too many tokens", "prompt is too long"): + return ErrContextOverflow + case containsAny(m, "rate limit", "rate_limit", "too many requests", "status 429", "429"): + return ErrRateLimit + case containsAny(m, "unauthorized", "authentication", "invalid api key", "invalid_api_key", "permission denied", "status 401", "status 403", "401", "403"): + return ErrAuth + case containsAny(m, "status 500", "status 502", "status 503", "status 504", "internal server error", "bad gateway", "service unavailable", "gateway timeout", "timeout", "connection refused", "no such host", "eof"): + return ErrProviderUnavailable + case containsAny(m, "status 400", "invalid request", "bad request", "unprocessable"): + return ErrInvalidRequest + default: + return nil + } +} + +func containsAny(s string, subs ...string) bool { + for _, sub := range subs { + if strings.Contains(s, sub) { + return true + } + } + return false +} diff --git a/pkg/kit/errors_test.go b/pkg/kit/errors_test.go new file mode 100644 index 00000000..38271a1e --- /dev/null +++ b/pkg/kit/errors_test.go @@ -0,0 +1,64 @@ +package kit_test + +import ( + "errors" + "fmt" + "testing" + + "github.com/mark3labs/kit/pkg/kit" +) + +func TestClassifyProviderError(t *testing.T) { + cases := []struct { + name string + in error + want error + }{ + {"nil", nil, nil}, + {"context overflow", errors.New("error: context_length_exceeded for this model"), kit.ErrContextOverflow}, + {"context window phrase", errors.New("the prompt is too long for the context window"), kit.ErrContextOverflow}, + {"rate limit", errors.New("HTTP status 429: rate limit exceeded"), kit.ErrRateLimit}, + {"auth 401", errors.New("status 401 unauthorized"), kit.ErrAuth}, + {"auth invalid key", errors.New("invalid api key provided"), kit.ErrAuth}, + {"unavailable 503", errors.New("status 503 
service unavailable"), kit.ErrProviderUnavailable}, + {"invalid request", errors.New("status 400 bad request: malformed body"), kit.ErrInvalidRequest}, + {"unclassified", errors.New("something totally unexpected"), nil}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := kit.ClassifyProviderError(tc.in) + if tc.in == nil { + if got != nil { + t.Fatalf("expected nil, got %v", got) + } + return + } + if tc.want == nil { + // Unclassified errors are returned unchanged. + if got.Error() != tc.in.Error() { + t.Fatalf("expected unchanged error, got %v", got) + } + return + } + if !errors.Is(got, tc.want) { + t.Fatalf("errors.Is(%v, %v) = false", got, tc.want) + } + // Original cause must remain reachable. + if !errors.Is(got, tc.in) { + t.Fatalf("original cause not preserved in %v", got) + } + }) + } +} + +func TestClassifyProviderErrorIdempotent(t *testing.T) { + wrapped := fmt.Errorf("%w: upstream detail", kit.ErrRateLimit) + got := kit.ClassifyProviderError(wrapped) + if got != wrapped { + t.Fatalf("already-classified error should be returned unchanged") + } + if !errors.Is(got, kit.ErrRateLimit) { + t.Fatalf("expected ErrRateLimit to remain") + } +} diff --git a/pkg/kit/events.go b/pkg/kit/events.go index d662f1df..fdc82e1c 100644 --- a/pkg/kit/events.go +++ b/pkg/kit/events.go @@ -370,7 +370,10 @@ type StepUsageEvent struct { // EventType implements Event. func (e StepUsageEvent) EventType() EventType { return EventStepUsage } -// CompactionEvent fires after a successful compaction. +// CompactionEvent fires after a compaction attempt. On success Err is nil and +// the summary/token/file fields are populated. On failure Err is non-nil and +// the remaining fields are zero-valued, so embedders can wire symmetric +// start/end lifecycle telemetry around both outcomes. 
type CompactionEvent struct { Summary string OriginalTokens int @@ -378,6 +381,10 @@ type CompactionEvent struct { MessagesRemoved int ReadFiles []string ModifiedFiles []string + + // Err is non-nil when compaction failed. On the failure path the other + // fields are zero-valued. + Err error } // EventType implements Event. diff --git a/pkg/kit/kit.go b/pkg/kit/kit.go index 2b4aa9dd..0d29a3c3 100644 --- a/pkg/kit/kit.go +++ b/pkg/kit/kit.go @@ -115,6 +115,11 @@ type Kit struct { steerMu sync.Mutex steerCh chan agent.SteerMessage leftoverSteer []agent.SteerMessage // unconsumed steer messages from the last turn + + // promptOptsMu serializes per-call PromptOptions overrides that mutate + // shared agent state (model, thinking level, provider creds, extra tools) + // so the apply/restore window of one call never races another. + promptOptsMu sync.Mutex } // Subscribe registers an EventListener that will be called for every lifecycle @@ -1824,6 +1829,58 @@ type TurnResult struct { // any tool call/result messages added during the agent loop. // Each message carries role and plain-text content. Messages []LLMMessage + + // FinalValue is set when a tool returned a [ToolOutput] with Halt=true + // during the turn. The dynamic type is whatever the tool handler placed + // in [ToolOutput.FinalValue]. Nil when no tool halted the turn. + FinalValue any + + // HaltedByTool is the name of the tool that returned Halt=true, or empty + // if the turn ended for any other reason. + HaltedByTool string + + // Stream contains every delta event observed during the turn in emit + // order. It is populated regardless of streaming mode (in non-streaming + // mode it carries the coarse-grained events the provider reported). + // PromptResult and the other turn-returning entry points always block + // until end-of-turn, so Stream is complete when they return. + Stream []StreamEvent +} + +// StreamEventKind classifies a [StreamEvent] captured during a turn. 
+type StreamEventKind string + +// Stream event kinds captured in [TurnResult.Stream]. +const ( + StreamEventTextDelta StreamEventKind = "text_delta" + StreamEventReasoningStart StreamEventKind = "reasoning_start" + StreamEventReasoningDelta StreamEventKind = "reasoning_delta" + StreamEventReasoningEnd StreamEventKind = "reasoning_end" + StreamEventToolCallChunk StreamEventKind = "tool_call_chunk" +) + +// StreamEvent is a single delta observed during a turn, captured in +// [TurnResult.Stream]. It lets embedders assert streamed ordering +// deterministically without re-implementing an OnMessageUpdate collector. +type StreamEvent struct { + // Kind classifies the event. + Kind StreamEventKind + + // Text carries the assistant text for StreamEventTextDelta. + Text string + + // Reasoning carries the reasoning text for StreamEventReasoningDelta. + Reasoning string + + // ToolName is the tool name for StreamEventToolCallChunk. + ToolName string + + // ToolID is the tool call ID for StreamEventToolCallChunk. + ToolID string + + // Args carries the (accumulating) tool-call argument JSON for + // StreamEventToolCallChunk. + Args string } // --------------------------------------------------------------------------- @@ -2064,6 +2121,9 @@ func (m *Kit) Subagent(ctx context.Context, cfg SubagentConfig) (*SubagentResult // All prompt modes (Prompt, Steer, FollowUp, PromptWithOptions) share this // single code path so callback wiring is never duplicated. func (m *Kit) generate(ctx context.Context, messages []fantasy.Message) (*agent.GenerateWithLoopResult, error) { + // Capture the per-turn stream collector (set by runTurn) so streamed + // deltas are recorded into TurnResult.Stream in emit order. + collector := streamCollectorFromContext(ctx) // Create a per-turn steer channel and attach it to the context so the // agent's PrepareStep can inject steering messages between steps. 
steerCh := make(chan agent.SteerMessage, 16) @@ -2181,24 +2241,30 @@ func (m *Kit) generate(ctx context.Context, messages []fantasy.Message) (*agent. i := strings.Index(remaining, thinkClose) if i == -1 { m.events.emit(ReasoningDeltaEvent{Delta: remaining}) + collector.add(StreamEvent{Kind: StreamEventReasoningDelta, Reasoning: remaining}) return } if i > 0 { m.events.emit(ReasoningDeltaEvent{Delta: remaining[:i]}) + collector.add(StreamEvent{Kind: StreamEventReasoningDelta, Reasoning: remaining[:i]}) } inThinkTag = false m.events.emit(ReasoningCompleteEvent{}) + collector.add(StreamEvent{Kind: StreamEventReasoningEnd}) remaining = remaining[i+len(thinkClose):] } else { i := strings.Index(remaining, thinkOpen) if i == -1 { m.events.emit(MessageUpdateEvent{Chunk: remaining}) + collector.add(StreamEvent{Kind: StreamEventTextDelta, Text: remaining}) return } if i > 0 { m.events.emit(MessageUpdateEvent{Chunk: remaining[:i]}) + collector.add(StreamEvent{Kind: StreamEventTextDelta, Text: remaining[:i]}) } inThinkTag = true + collector.add(StreamEvent{Kind: StreamEventReasoningStart}) remaining = remaining[i+len(thinkOpen):] } } @@ -2206,9 +2272,11 @@ func (m *Kit) generate(ctx context.Context, messages []fantasy.Message) (*agent. }(), OnReasoningDelta: func(delta string) { m.events.emit(ReasoningDeltaEvent{Delta: delta}) + collector.add(StreamEvent{Kind: StreamEventReasoningDelta, Reasoning: delta}) }, OnReasoningComplete: func() { m.events.emit(ReasoningCompleteEvent{}) + collector.add(StreamEvent{Kind: StreamEventReasoningEnd}) }, OnToolOutput: func(toolCallID, toolName, chunk string, isStderr bool) { m.events.emit(ToolOutputEvent{ @@ -2255,12 +2323,14 @@ func (m *Kit) generate(ctx context.Context, messages []fantasy.Message) (*agent. 
ToolName: toolName, ToolKind: toolKindFor(toolName), }) + collector.add(StreamEvent{Kind: StreamEventToolCallChunk, ToolID: toolCallID, ToolName: toolName}) }, OnToolCallDelta: func(toolCallID, delta string) { m.events.emit(ToolCallDeltaEvent{ ToolCallID: toolCallID, Delta: delta, }) + collector.add(StreamEvent{Kind: StreamEventToolCallChunk, ToolID: toolCallID, Args: delta}) }, OnToolCallEnd: func(toolCallID string) { m.events.emit(ToolCallEndEvent{ @@ -2412,6 +2482,14 @@ func (m *Kit) runTurn(ctx context.Context, promptLabel string, prompt string, pr sentCount := len(messages) + // Attach a per-turn stream collector and halt holder so generate's + // callbacks can capture delta events (TurnResult.Stream) and tools can + // signal loop termination (TurnResult.FinalValue / HaltedByTool). + collector := &streamCollector{} + holder := &haltHolder{} + ctx = context.WithValue(ctx, streamCollectorKey{}, collector) + ctx = context.WithValue(ctx, haltHolderKey{}, holder) + m.events.emit(TurnStartEvent{Prompt: promptLabel}) m.events.emit(MessageStartEvent{}) @@ -2434,7 +2512,7 @@ func (m *Kit) runTurn(ctx context.Context, promptLabel string, prompt string, pr m.events.emit(TurnEndEvent{Error: err}) // Run AfterTurn hooks even on error. m.afterTurn.run(AfterTurnHook{Error: err}) - return nil, err + return nil, ClassifyProviderError(err) } responseText := result.FinalResponse.Content.Text() @@ -2487,6 +2565,13 @@ func (m *Kit) runTurn(ctx context.Context, promptLabel string, prompt string, pr turnResult.FinalUsage = &finalUsage } + // Surface captured stream deltas and any tool-driven halt signal. + turnResult.Stream = collector.drain() + if halted, toolName, value := holder.snapshot(); halted { + turnResult.HaltedByTool = toolName + turnResult.FinalValue = value + } + return turnResult, nil } @@ -2628,28 +2713,158 @@ type PromptOptions struct { // Use it to inject per-call instructions or context without permanently // modifying the agent's system prompt. 
SystemMessage string + + // Model overrides the agent's configured model for this call only. Empty + // string means "use the agent's default". The previous model is restored + // after the call returns. + Model string + + // ThinkingLevel overrides the agent's reasoning level for this call only + // (e.g. "off", "low", "medium", "high"). Empty string means "use the + // agent's default". The previous level is restored after the call. + ThinkingLevel string + + // ExtraTools are added to the effective tool set for this call only and + // removed afterwards. + ExtraTools []Tool + + // ProviderURL overrides the provider base URL for this call only. Useful + // for multi-tenant embedders that resolve endpoints per request. The + // previous value is restored after the call. + ProviderURL string + + // ProviderAPIKey overrides the provider credential for this call only. + // The previous value is restored after the call. + ProviderAPIKey string +} + +// applyPromptOptions applies the per-call overrides in opts to the shared +// agent state and returns a restore function that reverts every change. It +// holds promptOptsMu for the lifetime of the override window (the returned +// restore releases it), so concurrent option-driven prompts are serialized. +// On error nothing is changed and the lock is released. +func (m *Kit) applyPromptOptions(ctx context.Context, opts PromptOptions) (func(), error) { + needsModelRebuild := opts.Model != "" || opts.ThinkingLevel != "" || + opts.ProviderURL != "" || opts.ProviderAPIKey != "" + if !needsModelRebuild && len(opts.ExtraTools) == 0 { + return func() {}, nil + } + + m.promptOptsMu.Lock() + var restores []func() + restore := func() { + for i := len(restores) - 1; i >= 0; i-- { + restores[i]() + } + m.promptOptsMu.Unlock() + } + + // Extra tools (additive) — restored by re-setting the prior slice. 
+ if len(opts.ExtraTools) > 0 { + prev := m.agent.GetExtraTools() + merged := make([]Tool, 0, len(prev)+len(opts.ExtraTools)) + merged = append(merged, prev...) + merged = append(merged, opts.ExtraTools...) + m.agent.SetExtraTools(merged) + restores = append(restores, func() { m.agent.SetExtraTools(prev) }) + } + + if needsModelRebuild { + prevModel := m.modelString + prevThinkingSet := m.v.IsSet("thinking-level") + prevThinking := m.v.GetString("thinking-level") + prevURLSet := m.v.IsSet("provider-url") + prevURL := m.v.GetString("provider-url") + prevKeySet := m.v.IsSet("provider-api-key") + prevKey := m.v.GetString("provider-api-key") + + if opts.ThinkingLevel != "" { + m.v.Set("thinking-level", opts.ThinkingLevel) + } + if opts.ProviderURL != "" { + m.v.Set("provider-url", opts.ProviderURL) + } + if opts.ProviderAPIKey != "" { + m.v.Set("provider-api-key", opts.ProviderAPIKey) + } + + targetModel := opts.Model + if targetModel == "" { + targetModel = prevModel + } + if err := m.SetModel(ctx, targetModel); err != nil { + // Revert config keys we may have set, then unwind prior restores. + restoreViperString(m.v, "thinking-level", prevThinking, prevThinkingSet) + restoreViperString(m.v, "provider-url", prevURL, prevURLSet) + restoreViperString(m.v, "provider-api-key", prevKey, prevKeySet) + restore() + return nil, err + } + restores = append(restores, func() { + restoreViperString(m.v, "thinking-level", prevThinking, prevThinkingSet) + restoreViperString(m.v, "provider-url", prevURL, prevURLSet) + restoreViperString(m.v, "provider-api-key", prevKey, prevKeySet) + // Use a fresh context: the rollback must complete even if the + // caller's ctx was canceled or expired during the call, otherwise + // the per-call model override would leak into subsequent calls. 
+ _ = m.SetModel(context.Background(), prevModel) + }) + } + + return restore, nil +} + +// restoreViperString restores a config key to its prior value. A key that was +// not set before is reset to "" (viper's Set keeps an override, so IsSet may then report true — confirm). +func restoreViperString(v *viper.Viper, key, prev string, wasSet bool) { + if wasSet { + v.Set(key, prev) + return + } + v.Set(key, "") } // PromptWithOptions sends a message with per-call configuration. It behaves -// like Prompt but allows injecting an additional system message before the -// user prompt. Both messages are persisted to the session. +// like Prompt but applies the overrides in opts (system message, model, +// thinking level, provider credentials, extra tools) for this call only, +// restoring the agent's prior state afterwards. func (m *Kit) PromptWithOptions(ctx context.Context, msg string, opts PromptOptions) (string, error) { + result, err := m.PromptResultWithOptions(ctx, msg, opts) + if err != nil { + return "", err + } + return result.Response, nil +} + +// PromptResultWithOptions is the [TurnResult]-returning counterpart of +// PromptWithOptions. Like all turn-returning entry points it blocks until +// end-of-turn, so the returned TurnResult (including TurnResult.Stream) is +// complete when it returns. Per-call overrides in opts are applied for this +// call only and the agent's prior state is restored before returning. 
+func (m *Kit) PromptResultWithOptions(ctx context.Context, msg string, opts PromptOptions) (*TurnResult, error) { + restore, err := m.applyPromptOptions(ctx, opts) + if err != nil { + return nil, err + } + defer restore() + var preMessages []fantasy.Message if opts.SystemMessage != "" { preMessages = append(preMessages, fantasy.NewSystemMessage(opts.SystemMessage)) } preMessages = append(preMessages, fantasy.NewUserMessage(msg)) - result, err := m.runTurn(ctx, msg, msg, preMessages) - if err != nil { - return "", err - } - return result.Response, nil + return m.runTurn(ctx, msg, msg, preMessages) } // PromptResult sends a message and returns the full turn result including // usage statistics and conversation messages. Use this instead of Prompt() // when you need more than just the response text. +// +// PromptResult blocks until end-of-turn regardless of whether streaming is +// enabled. When streaming is enabled, every delta observed during the turn is +// also captured in order in [TurnResult.Stream], so callers can assert +// streamed ordering deterministically without an OnMessageUpdate collector. func (m *Kit) PromptResult(ctx context.Context, message string) (*TurnResult, error) { return m.runTurn(ctx, message, message, []fantasy.Message{ fantasy.NewUserMessage(message), @@ -2787,7 +3002,18 @@ func extractFileParts(msg fantasy.Message) []fantasy.FilePart { // Close cleans up resources including MCP server connections, model resources, // and the tree session file handle. Should be called when the Kit instance is // no longer needed. Returns an error if cleanup fails. +// +// Close is equivalent to CloseContext(context.Background()). Use +// [Kit.CloseContext] when shutdown must be bounded by a deadline. func (m *Kit) Close() error { + return m.CloseContext(context.Background()) +} + +// CloseContext is like [Kit.Close] but accepts a context so graceful shutdown +// can be bounded by a deadline or cancellation. 
The context is honored on a +// best-effort basis: cleanup always runs to completion, and if ctx is done by +// then and cleanup itself returned no error, the context error is returned. +func (m *Kit) CloseContext(ctx context.Context) error { // Emit SessionShutdown for extensions. if m.extRunner != nil && m.extRunner.HasHandlers(extensions.SessionShutdown) { _, _ = m.extRunner.Emit(extensions.SessionShutdownEvent{}) @@ -2799,7 +3025,11 @@ func (m *Kit) Close() error { if closer, ok := m.authHandler.(interface{ Close() error }); ok { _ = closer.Close() } - return m.agent.Close() + err := m.agent.Close() + if ctxErr := ctx.Err(); ctxErr != nil && err == nil { + return ctxErr + } + return err } // Conversion helpers are defined in adapter.go. diff --git a/pkg/kit/rawtool_test.go b/pkg/kit/rawtool_test.go new file mode 100644 index 00000000..8de936a7 --- /dev/null +++ b/pkg/kit/rawtool_test.go @@ -0,0 +1,102 @@ +package kit_test + +import ( + "context" + "testing" + + "charm.land/fantasy" + + "github.com/mark3labs/kit/pkg/kit" +) + +func TestNewRawTool(t *testing.T) { + schema := map[string]any{ + "type": "object", + "properties": map[string]any{ + "city": map[string]any{"type": "string", "description": "City name"}, + }, + "required": []any{"city"}, + } + + var gotArgs map[string]any + tool := kit.NewRawTool("get_weather", "Get weather", schema, + func(ctx context.Context, args map[string]any) (kit.ToolOutput, error) { + gotArgs = args + return kit.TextResult("72F in " + args["city"].(string)), nil + }, + ) + + info := tool.Info() + if info.Name != "get_weather" { + t.Fatalf("name = %q", info.Name) + } + if info.Parameters["type"] != "object" { + t.Fatalf("schema not propagated: %#v", info.Parameters) + } + if len(info.Required) != 1 || info.Required[0] != "city" { + t.Fatalf("required not propagated: %#v", info.Required) + } + + resp, err := tool.Run(context.Background(), fantasy.ToolCall{ + ID: "call_1", + Input: `{"city":"Boston"}`, + }) + if err != nil { + t.Fatalf("Run error: 
%v", err) + } + if resp.IsError { + t.Fatalf("unexpected error response: %q", resp.Content) + } + if resp.Content != "72F in Boston" { + t.Fatalf("content = %q", resp.Content) + } + if gotArgs["city"] != "Boston" { + t.Fatalf("args not decoded: %#v", gotArgs) + } +} + +func TestNewRawToolInvalidArgs(t *testing.T) { + tool := kit.NewRawTool("t", "d", nil, + func(ctx context.Context, args map[string]any) (kit.ToolOutput, error) { + t.Fatal("handler should not be called for invalid args") + return kit.ToolOutput{}, nil + }, + ) + resp, err := tool.Run(context.Background(), fantasy.ToolCall{ID: "x", Input: `not json`}) + if err != nil { + t.Fatalf("Run error: %v", err) + } + if !resp.IsError { + t.Fatalf("expected error response for invalid args") + } +} + +// Contract: null / whitespace-padded-null inputs must hand the handler a +// non-nil empty map (not a nil map), so handlers can read or write keys +// without a nil-map panic. Inputs are normalised before reaching the handler. +func TestNewRawToolNullArgs(t *testing.T) { + for _, input := range []string{"null", " null ", "\tnull\n"} { + called := false + var gotNil bool + tool := kit.NewRawTool("t", "d", nil, + func(ctx context.Context, args map[string]any) (kit.ToolOutput, error) { + called = true + gotNil = args == nil + return kit.TextResult("ok"), nil + }, + ) + resp, err := tool.Run(context.Background(), fantasy.ToolCall{ID: "x", Input: input}) + if err != nil { + t.Fatalf("input %q: Run error: %v", input, err) + } + if resp.IsError { + t.Fatalf("input %q: unexpected error response: %q", input, resp.Content) + } + if !called { + t.Fatalf("input %q: handler not called", input) + } + if gotNil { + t.Fatalf("input %q: args was nil, want non-nil empty map", input) + } + } +} diff --git a/pkg/kit/session.go b/pkg/kit/session.go index 27da64b2..78b5311a 100644 --- a/pkg/kit/session.go +++ b/pkg/kit/session.go @@ -1,9 +1,14 @@ package kit import ( + "errors" "time" ) +// ErrBranchSummaryNotSupported is returned by 
SessionManager implementations +// that do not support collapsing a branch range into a summary entry. +var ErrBranchSummaryNotSupported = errors.New("session manager does not support branch summaries") + // SessionManager defines the contract for conversation storage backends. // Implementations can use files (default), databases, cloud storage, etc. // @@ -89,6 +94,12 @@ type SessionManager interface { // determine which entries to summarize. GetContextEntryIDs() []string + // AppendBranchSummary collapses the range from fromID to the current leaf + // on the active branch into a single summary entry and returns the new + // entry ID. It backs [Kit.CollapseBranch]. Managers that do not track + // branch summaries should return [ErrBranchSummaryNotSupported]. + AppendBranchSummary(fromID, summary string) (entryID string, err error) + // Close releases resources (database connections, file handles, etc.). Close() error } diff --git a/pkg/kit/sessions.go b/pkg/kit/sessions.go index 9abdca25..f4df0af9 100644 --- a/pkg/kit/sessions.go +++ b/pkg/kit/sessions.go @@ -217,19 +217,19 @@ func (m *Kit) SummarizeBranch(fromID, toID string) (string, error) { // CollapseBranch replaces a branch range with a summary entry. // Returns an error if the session is unavailable or the operation fails. +// Custom SessionManagers that do not support branch summaries surface +// [ErrBranchSummaryNotSupported]. +// +// The branch is always collapsed from fromID to the current leaf; the toID +// parameter is currently unused (the underlying AppendBranchSummary primitive +// only supports collapsing to the leaf) and is retained for forward +// compatibility. func (m *Kit) CollapseBranch(fromID, toID, summary string) error { if m.session == nil { return fmt.Errorf("no session available") } - // Note: This operation is not directly supported by SessionManager interface - // as it requires AppendBranchSummary which is TreeManager-specific. 
- // For custom SessionManagers, this would need to be implemented differently. - // For now, we try to use the underlying TreeManager if available. - if adapter, ok := m.session.(*treeManagerAdapter); ok { - _, err := adapter.inner.AppendBranchSummary(fromID, summary) - return err - } - return fmt.Errorf("CollapseBranch not supported by custom session manager") + _, err := m.session.AppendBranchSummary(fromID, summary) + return err } // branchEntryToTreeNode converts a BranchEntry to a TreeNode. diff --git a/pkg/kit/skills.go b/pkg/kit/skills.go index e07d42ed..24aeea3c 100644 --- a/pkg/kit/skills.go +++ b/pkg/kit/skills.go @@ -2,6 +2,7 @@ package kit import ( "fmt" + "io/fs" "os" "github.com/mark3labs/kit/internal/extensions" @@ -36,6 +37,19 @@ func LoadSkillsFromDir(dir string) ([]*Skill, error) { return skills.LoadSkillsFromDir(dir) } +// LoadSkillsFromFS is the [fs.FS]-typed counterpart of [LoadSkillsFromDir]. +// It walks fsys starting at root (which may be "." or a subdirectory), finds +// *.md/*.txt files and SKILL.md files in subdirectories, parses YAML +// frontmatter + markdown body, and returns the loaded skills. Use it when +// skill discovery is wrapped in an fs.FS abstraction (embed.FS distribution, +// fstest.MapFS tests, or per-tenant virtual filesystems). +// +// Each loaded skill's Path is its slash-separated path within fsys, since +// fs.FS has no notion of an absolute on-disk path. 
+func LoadSkillsFromFS(fsys fs.FS, root string) ([]*Skill, error) { + return skills.LoadSkillsFromFS(fsys, root) +} + // LoadSkills auto-discovers skills from standard directories: // - Global: $XDG_CONFIG_HOME/kit/skills/ (default ~/.config/kit/skills/) // - Project-local: /.kit/skills/ diff --git a/pkg/kit/testing.go b/pkg/kit/testing.go new file mode 100644 index 00000000..7bcff9b8 --- /dev/null +++ b/pkg/kit/testing.go @@ -0,0 +1,15 @@ +//go:build testing + +package kit + +import "github.com/mark3labs/kit/internal/config" + +// ResetForTesting clears package-global state that survives across tests in +// the same binary. It is intended for test-binary teardown / between-test +// cleanup. Safe to call concurrently with no in-flight kit.New() calls. +// +// This function is only compiled under the "testing" build tag so it never +// ships in production binaries. +func ResetForTesting() { + config.SetConfigPath("") +} diff --git a/pkg/kit/tools.go b/pkg/kit/tools.go index 07b5996c..3d3a1f23 100644 --- a/pkg/kit/tools.go +++ b/pkg/kit/tools.go @@ -1,8 +1,12 @@ package kit import ( + "bytes" "context" + "encoding/json" + "fmt" "strings" + "sync" "charm.land/fantasy" @@ -40,6 +44,20 @@ type ToolOutput struct { // Metadata is optional opaque metadata attached to the response. // It is not sent to the LLM but may be consumed by hooks or the UI. Metadata any + + // FinalValue, when Halt is true, is propagated to the turn's + // [TurnResult.FinalValue] so the caller can recover a typed result + // produced by the tool (e.g. a structured "finish" tool). The dynamic + // type is whatever the tool handler stored. Ignored when Halt is false. + FinalValue any + + // Halt, when true, signals that the agent loop should terminate after + // this tool call. 
Content is still returned to the model for the current + // step, but [TurnResult.FinalValue] and [TurnResult.HaltedByTool] are + // populated so embedders building structured-result extraction patterns + // (model calls a finish(...) tool, the loop ends, the typed value is + // returned) no longer need a side-channel. + Halt bool } // TextResult creates a successful text [ToolOutput]. @@ -72,6 +90,49 @@ func MediaResult(content string, data []byte, mediaType string) ToolOutput { // toolCallIDKey is the context key for the tool call ID. type toolCallIDKey struct{} +// haltHolderKey is the context key for the per-turn halt holder. It is +// injected by runTurn so tool handlers created with [NewTool], +// [NewParallelTool], and [NewRawTool] can signal loop termination and carry a +// final value out to the [TurnResult] without an embedder-side side-channel. +type haltHolderKey struct{} + +// haltHolder captures a Halt signal raised by a tool handler during a turn. +type haltHolder struct { + mu sync.Mutex + halted bool + toolName string + value any +} + +func (h *haltHolder) set(toolName string, value any) { + h.mu.Lock() + defer h.mu.Unlock() + // First halt wins so the earliest finishing tool determines the result. + if h.halted { + return + } + h.halted = true + h.toolName = toolName + h.value = value +} + +func (h *haltHolder) snapshot() (bool, string, any) { + h.mu.Lock() + defer h.mu.Unlock() + return h.halted, h.toolName, h.value +} + +// recordHalt records a Halt signal from a tool handler onto the per-turn halt +// holder, if one is present in the context. +func recordHalt(ctx context.Context, toolName string, result ToolOutput) { + if !result.Halt { + return + } + if holder, ok := ctx.Value(haltHolderKey{}).(*haltHolder); ok && holder != nil { + holder.set(toolName, result.FinalValue) + } +} + // ToolCallIDFromContext extracts the tool call ID from the context. 
// The call ID is set automatically by [NewTool] and [NewParallelTool] // before invoking the handler. Returns an empty string if no ID is present. @@ -144,6 +205,7 @@ func NewTool[TInput any](name, description string, fn func(ctx context.Context, if err != nil { return fantasy.NewTextErrorResponse(err.Error()), nil } + recordHalt(ctx, name, result) return toolOutputToResponse(result), nil }, ) @@ -160,11 +222,104 @@ func NewParallelTool[TInput any](name, description string, fn func(ctx context.C if err != nil { return fantasy.NewTextErrorResponse(err.Error()), nil } + recordHalt(ctx, name, result) return toolOutputToResponse(result), nil }, ) } +// rawToolInput is the decoded carrier used by [NewRawTool]. Using +// json.RawMessage lets the typed-tool machinery in fantasy generate a +// permissive object schema while we forward the raw arguments to the handler +// as a decoded map. +type rawToolInput = json.RawMessage + +// NewRawTool is the schema-driven counterpart to [NewTool]. Use it when the +// tool's input shape isn't known at compile time — for example tools loaded +// from JSON Schema definitions in skill files or MCP server catalogs. +// +// schema must be a valid JSON Schema describing the tool's input object; it is +// advertised to the model as the tool's parameter schema. fn receives the +// decoded JSON arguments as a map and returns a [ToolOutput]. Like [NewTool], +// the tool call ID is injected into the context and can be retrieved with +// [ToolCallIDFromContext], and [ToolOutput.Halt] is honored. +// +// If the model sends arguments that are not a valid JSON object the call +// short-circuits with an error [ToolResponse] before fn is invoked. 
+func NewRawTool( + name, description string, + schema map[string]any, + fn func(ctx context.Context, args map[string]any) (ToolOutput, error), +) Tool { + tool := fantasy.NewAgentTool(name, description, + func(ctx context.Context, input rawToolInput, call fantasy.ToolCall) (fantasy.ToolResponse, error) { + ctx = context.WithValue(ctx, toolCallIDKey{}, call.ID) + args := map[string]any{} + // Normalise whitespace before the null/empty guard so values like + // " null " or "\tnull\n" take the same skip-unmarshal path as the + // bare "null" and the handler always receives a non-nil empty map. + // (fantasy currently trims via its RawMessage decode, but this keeps + // the guard correct independent of that upstream behaviour.) + trimmed := bytes.TrimSpace(input) + if len(trimmed) > 0 && !bytes.Equal(trimmed, []byte("null")) { + if err := json.Unmarshal(trimmed, &args); err != nil { + return fantasy.NewTextErrorResponse(fmt.Sprintf("invalid arguments for tool %q: %v", name, err)), nil + } + } + result, err := fn(ctx, args) + if err != nil { + return fantasy.NewTextErrorResponse(err.Error()), nil + } + recordHalt(ctx, name, result) + return toolOutputToResponse(result), nil + }, + ) + // Override the auto-generated schema with the caller-supplied one so the + // model sees the real input shape instead of the permissive raw-message + // schema. + if len(schema) > 0 { + info := tool.Info() + info.Parameters = schema + info.Required = requiredFromSchema(schema) + tool = &schemaOverrideTool{AgentTool: tool, info: info} + } + return tool +} + +// schemaOverrideTool wraps an [fantasy.AgentTool] to advertise a +// caller-supplied JSON Schema instead of the auto-generated one. Used by +// [NewRawTool]. +type schemaOverrideTool struct { + fantasy.AgentTool + info fantasy.ToolInfo +} + +// Info returns the tool info carrying the overridden parameter schema. 
+func (t *schemaOverrideTool) Info() fantasy.ToolInfo { return t.info } + +// requiredFromSchema extracts the top-level "required" array from a JSON +// Schema object, tolerating both []string and []any element types. +func requiredFromSchema(schema map[string]any) []string { + raw, ok := schema["required"] + if !ok { + return nil + } + switch v := raw.(type) { + case []string: + return v + case []any: + out := make([]string, 0, len(v)) + for _, e := range v { + if s, ok := e.(string); ok { + out = append(out, s) + } + } + return out + default: + return nil + } +} + // --- Individual tool constructors --- // NewReadTool creates a file-reading tool. diff --git a/pkg/kit/turn_capture.go b/pkg/kit/turn_capture.go new file mode 100644 index 00000000..5063c0e3 --- /dev/null +++ b/pkg/kit/turn_capture.go @@ -0,0 +1,48 @@ +package kit + +import ( + "context" + "sync" +) + +// streamCollectorKey is the context key carrying the per-turn stream collector +// so the agent callbacks in generate can capture delta events into +// TurnResult.Stream without re-implementing an OnMessageUpdate handler. +type streamCollectorKey struct{} + +// streamCollector accumulates StreamEvents observed during a single turn in +// emit order. It is safe for concurrent use because tool-call deltas and text +// deltas may be emitted from different goroutines. +type streamCollector struct { + mu sync.Mutex + events []StreamEvent +} + +func (c *streamCollector) add(e StreamEvent) { + if c == nil { + return + } + c.mu.Lock() + c.events = append(c.events, e) + c.mu.Unlock() +} + +func (c *streamCollector) drain() []StreamEvent { + if c == nil { + return nil + } + c.mu.Lock() + defer c.mu.Unlock() + if len(c.events) == 0 { + return nil + } + out := make([]StreamEvent, len(c.events)) + copy(out, c.events) + return out +} + +// streamCollectorFromContext returns the per-turn stream collector if present. 
+func streamCollectorFromContext(ctx context.Context) *streamCollector { + c, _ := ctx.Value(streamCollectorKey{}).(*streamCollector) + return c +} diff --git a/pkg/kit/turn_capture_test.go b/pkg/kit/turn_capture_test.go new file mode 100644 index 00000000..18efaa0c --- /dev/null +++ b/pkg/kit/turn_capture_test.go @@ -0,0 +1,86 @@ +package kit + +import ( + "context" + "testing" +) + +func TestHaltHolderFirstWins(t *testing.T) { + h := &haltHolder{} + if halted, _, _ := h.snapshot(); halted { + t.Fatal("new holder should not be halted") + } + h.set("finish", 42) + h.set("other", 99) // ignored — first halt wins + halted, name, val := h.snapshot() + if !halted { + t.Fatal("holder should be halted") + } + if name != "finish" { + t.Fatalf("toolName = %q, want finish", name) + } + if v, ok := val.(int); !ok || v != 42 { + t.Fatalf("value = %#v, want 42", val) + } +} + +func TestRecordHalt(t *testing.T) { + holder := &haltHolder{} + ctx := context.WithValue(context.Background(), haltHolderKey{}, holder) + + // Non-halting output records nothing. + recordHalt(ctx, "noop", ToolOutput{Content: "ok"}) + if halted, _, _ := holder.snapshot(); halted { + t.Fatal("non-halting output should not halt") + } + + recordHalt(ctx, "finish", ToolOutput{Halt: true, FinalValue: "done"}) + halted, name, val := holder.snapshot() + if !halted || name != "finish" || val != "done" { + t.Fatalf("halt not recorded: halted=%v name=%q val=%v", halted, name, val) + } + + // Missing holder in context is a safe no-op. 
+ recordHalt(context.Background(), "finish", ToolOutput{Halt: true}) +} + +func TestStreamCollector(t *testing.T) { + c := &streamCollector{} + if c.drain() != nil { + t.Fatal("empty collector should drain to nil") + } + c.add(StreamEvent{Kind: StreamEventTextDelta, Text: "A"}) + c.add(StreamEvent{Kind: StreamEventTextDelta, Text: "B"}) + c.add(StreamEvent{Kind: StreamEventToolCallChunk, ToolName: "x"}) + + out := c.drain() + if len(out) != 3 { + t.Fatalf("len = %d, want 3", len(out)) + } + if out[0].Text != "A" || out[1].Text != "B" { + t.Fatalf("order not preserved: %#v", out) + } + if out[2].Kind != StreamEventToolCallChunk || out[2].ToolName != "x" { + t.Fatalf("tool chunk wrong: %#v", out[2]) + } +} + +// nil receiver collector (no per-turn collector attached) must be safe. +func TestStreamCollectorNil(t *testing.T) { + var c *streamCollector + c.add(StreamEvent{Kind: StreamEventTextDelta, Text: "x"}) // no panic + if c.drain() != nil { + t.Fatal("nil collector should drain to nil") + } +} + +func TestStreamCollectorFromContext(t *testing.T) { + if streamCollectorFromContext(context.Background()) != nil { + t.Fatal("expected nil collector for bare context") + } + c := &streamCollector{} + ctx := context.WithValue(context.Background(), streamCollectorKey{}, c) + if streamCollectorFromContext(ctx) != c { + t.Fatal("collector not retrieved from context") + } +} diff --git a/www/pages/advanced/json-output.md b/www/pages/advanced/json-output.md index f7adf6ab..6caeff5e 100644 --- a/www/pages/advanced/json-output.md +++ b/www/pages/advanced/json-output.md @@ -93,3 +93,21 @@ result, err := host.PromptResult(ctx, "Count files") fmt.Println(result.Response) fmt.Println(result.Usage.TotalTokens) ``` + +`PromptResult` blocks until end-of-turn regardless of streaming mode. 
When +streaming is enabled, every delta observed during the turn is also captured in +order in `result.Stream` (`[]kit.StreamEvent`), so you can assert streamed +ordering deterministically without wiring an `OnMessageUpdate` collector: + +```go +for _, ev := range result.Stream { + switch ev.Kind { + case kit.StreamEventTextDelta: + fmt.Print(ev.Text) + case kit.StreamEventReasoningDelta: + fmt.Print(ev.Reasoning) + case kit.StreamEventToolCallChunk: + fmt.Printf("[%s %s]", ev.ToolName, ev.Args) + } +} +``` diff --git a/www/pages/sdk/callbacks.md b/www/pages/sdk/callbacks.md index c3c50400..ccb01662 100644 --- a/www/pages/sdk/callbacks.md +++ b/www/pages/sdk/callbacks.md @@ -176,12 +176,30 @@ Lower values run first. First non-nil result wins. | `SourceEvent` | `OnSource` | LLM referenced a source (e.g., web search) | | `ErrorEvent` | `OnError` | Agent-level error during streaming | | `RetryEvent` | `OnRetry` | LLM request retried after transient error | -| `CompactionEvent` | `OnCompaction` | Conversation compacted | +| `CompactionEvent` | `OnCompaction` | Conversation compacted (fires on success **and** failure — check `Err`) | | `SteerConsumedEvent` | `OnSteerConsumed` | Steering messages injected into turn | | `PasswordPromptEvent` | — | Sudo command needs password (respond via `ResponseCh`) | > **Note:** `OnStreaming` is a deprecated alias for `OnMessageUpdate` and will be removed in a future release. +### Compaction telemetry + +`CompactionEvent` fires after every compaction attempt. On success `Err` is +`nil` and the summary/token/file fields are populated; on failure `Err` is +non-nil and the rest are zero-valued. 
This lets you wire symmetric +start/end lifecycle telemetry without hand-rolling the failure path: + +```go +host.OnCompaction(func(e kit.CompactionEvent) { + if e.Err != nil { + log.Printf("compaction failed: %v", e.Err) + return + } + log.Printf("compacted %d → %d tokens (%d messages removed)", + e.OriginalTokens, e.CompactedTokens, e.MessagesRemoved) +}) +``` + ## Subagent event monitoring Monitor real-time events from LLM-initiated subagents (when the model uses the `subagent` tool): diff --git a/www/pages/sdk/overview.md b/www/pages/sdk/overview.md index 9214b4d2..4e2fdf51 100644 --- a/www/pages/sdk/overview.md +++ b/www/pages/sdk/overview.md @@ -129,12 +129,36 @@ The SDK provides several prompt variants: | Method | Description | |--------|-------------| | `Prompt(ctx, message)` | Simple prompt, returns response string | -| `PromptWithOptions(ctx, message, opts)` | With per-call options | +| `PromptWithOptions(ctx, message, opts)` | With per-call options (model, tools, thinking level, provider creds) | | `PromptResult(ctx, message)` | Returns full `TurnResult` with usage stats | +| `PromptResultWithOptions(ctx, message, opts)` | Per-call options variant that returns the full `TurnResult` | | `PromptResultWithFiles(ctx, message, files)` | Multimodal with file attachments | | `Steer(ctx, instruction)` | System-level steering without user message | | `FollowUp(ctx, text)` | Continue without new user input | +### Per-call overrides + +`PromptOptions` scopes configuration to a **single call** and restores the +agent's prior state afterwards — no need to rebuild a `*Kit` per request. 
This +suits multi-tenant hosts that resolve the model, credentials, or tool set per +request: + +```go +result, err := host.PromptResultWithOptions(ctx, "Summarise this ticket", kit.PromptOptions{ + SystemMessage: "You are a concise triage assistant.", // prepended for this call + Model: "anthropic/claude-haiku-3-5-20241022", // overrides the default model + ThinkingLevel: "low", // "off" | "low" | "medium" | "high" + ExtraTools: []kit.Tool{lookupTool}, // added on top of the core set + ProviderURL: "https://proxy.tenant-a/v1", // per-tenant endpoint + ProviderAPIKey: tenantKey, // per-tenant credential +}) +``` + +Every field is optional; a zero value means "use the agent's default." The +prior model, thinking level, provider credentials, and tool set are all +restored before the call returns, and concurrent option-driven prompts are +serialized so the apply/restore window of one call never races another. + ## Custom tools Create custom tools with `kit.NewTool`. The JSON schema is auto-generated from the input struct — no external dependencies required: @@ -175,6 +199,64 @@ Binary data (images, audio, etc.) in `ToolOutput.Data` is automatically forwarde Use `kit.NewParallelTool` for tools that are safe to run concurrently. Use `kit.ToolCallIDFromContext(ctx)` to retrieve the LLM-assigned call ID for logging or tracing. +### Schema-driven tools + +When the tool's input shape isn't known at compile time — tools sourced from +JSON Schema definitions in skill files, MCP server catalogs, or user-supplied +definitions — use `kit.NewRawTool`. 
It takes a JSON Schema and a handler that +receives the decoded arguments as a `map[string]any`, so no Go input type is +required: + +```go +schema := map[string]any{ + "type": "object", + "properties": map[string]any{ + "city": map[string]any{"type": "string", "description": "City name"}, + }, + "required": []any{"city"}, +} + +weatherTool := kit.NewRawTool("get_weather", "Get current weather for a city", schema, + func(ctx context.Context, args map[string]any) (kit.ToolOutput, error) { + return kit.TextResult("72°F, sunny in " + args["city"].(string)), nil + }, +) +``` + +The `schema` is advertised to the model as the tool's parameter schema. If the +model sends arguments that aren't a valid JSON object, the call short-circuits +with an error result before your handler runs. + +### Halting the agent loop + +For structured-result patterns — the model calls a `finish(...)` tool with a +typed argument and the loop should terminate, returning that value to the +caller — set `Halt` and `FinalValue` on the returned `ToolOutput` instead of +smuggling the value out through a side-channel: + +```go +finishTool := kit.NewTool("finish", "Return the final structured answer", + func(ctx context.Context, input AnswerInput) (kit.ToolOutput, error) { + return kit.ToolOutput{ + Content: "done", + Halt: true, // terminate the agent loop after this call + FinalValue: input, // surfaced to the caller + }, nil + }, +) + +result, _ := host.PromptResult(ctx, "Extract the order details") +if result.HaltedByTool == "finish" { + answer := result.FinalValue.(AnswerInput) // the typed value your handler stored + _ = answer +} +``` + +`TurnResult.HaltedByTool` names the tool that halted the turn (empty if the +turn ended for any other reason), and `TurnResult.FinalValue` carries whatever +your handler placed in `ToolOutput.FinalValue`. `Halt`/`FinalValue` work with +`NewTool`, `NewParallelTool`, and `NewRawTool` alike. 
+ ## Generation & provider overrides SDK consumers can configure generation parameters and provider endpoints @@ -327,6 +409,19 @@ Key points: `Options.NoSkills`, and `Options.NoContextFiles` continue to control the startup set; the runtime API mutates from whatever state `New()` produced. See [SDK options](/sdk/options#skills--configuration). +- **`fs.FS`-backed discovery.** The package-level loaders `kit.LoadSkill`, + `kit.LoadSkillsFromDir`, and `kit.LoadSkills` are path-string based; + `kit.LoadSkillsFromFS(fsys, root)` is the `fs.FS`-typed counterpart for + `embed.FS` distribution, `fstest.MapFS` tests, or per-tenant virtual + filesystems. Feed the result into `host.SetSkills(...)`: + + ```go + //go:embed skills + var skillsFS embed.FS + + loaded, _ := kit.LoadSkillsFromFS(skillsFS, "skills") + host.SetSkills(loaded) + ``` ## MCP prompts and resources @@ -382,6 +477,55 @@ if host.ShouldCompact() { } ``` +## Provider error classification + +Provider failures are wrapped with exported sentinels so you can branch on the +failure category with `errors.Is` instead of string-matching the underlying +HTTP error. 
`PromptResult` / `Prompt` already return classified errors; you can +also classify any provider error yourself with `kit.ClassifyProviderError`: + +```go +_, err := host.PromptResult(ctx, prompt) +switch { +case errors.Is(err, kit.ErrContextOverflow): + host.Compact(ctx, nil, "") // compact and retry +case errors.Is(err, kit.ErrRateLimit): + backoffAndRetry() +case errors.Is(err, kit.ErrAuth): + rePromptForKey() +case errors.Is(err, kit.ErrProviderUnavailable): + retryLater() +case errors.Is(err, kit.ErrInvalidRequest): + log.Printf("non-retryable: %v", err) +} +``` + +| Sentinel | Meaning | +|----------|---------| +| `kit.ErrContextOverflow` | Request exceeded the model's context window | +| `kit.ErrRateLimit` | Provider throttled the request | +| `kit.ErrAuth` | Credential / authorization failure | +| `kit.ErrProviderUnavailable` | Transient upstream failure (5xx, network, timeout) | +| `kit.ErrInvalidRequest` | Structurally invalid request — retrying won't help | + +The original error stays reachable via `errors.Is`, so you never lose the +provider's detail message. + +## Graceful shutdown + +`Close()` releases MCP connections, model resources, and the session file +handle. When shutdown must be bounded by a deadline, use `CloseContext`: + +```go +shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) +defer cancel() +if err := host.CloseContext(shutdownCtx); err != nil { + log.Printf("shutdown: %v", err) +} +``` + +`Close()` is equivalent to `CloseContext(context.Background())`. Cleanup itself is not interrupted when the context expires; `CloseContext` reports the context error once cleanup has finished without an error of its own. + ## In-process subagents Spawn child Kit instances without subprocess overhead: diff --git a/www/pages/sdk/sessions.md b/www/pages/sdk/sessions.md index 947e3a97..698b6e12 100644 --- a/www/pages/sdk/sessions.md +++ b/www/pages/sdk/sessions.md @@ -99,6 +99,12 @@ host, _ := kit.New(ctx, &kit.Options{ }) ``` -The interface requires methods for message storage, branching, compaction, extension data, and lifecycle management. 
See the [SDK skill reference](https://github.com/mark3labs/kit) for the complete interface definition. +The interface requires methods for message storage, branching, compaction, branch summaries, extension data, and lifecycle management. See the [`SessionManager` interface definition](https://pkg.go.dev/github.com/mark3labs/kit/pkg/kit#SessionManager) for the complete method set. + +The `AppendBranchSummary(fromID, summary)` method backs `host.CollapseBranch`, +which collapses a branch range into a single summary entry. Custom managers +that don't track branch summaries can return `kit.ErrBranchSummaryNotSupported` +from that method; `host.CollapseBranch` then surfaces the same sentinel so +callers can detect it with `errors.Is`. When using a custom `SessionManager`, the `SessionPath`, `Continue`, and `NoSession` options are ignored — your manager handles its own storage and session selection.