diff --git a/.changeset/local-models-cost-routing.md b/.changeset/local-models-cost-routing.md new file mode 100644 index 0000000..240da17 --- /dev/null +++ b/.changeset/local-models-cost-routing.md @@ -0,0 +1,26 @@ +--- +"@therr/tiny-code": minor +--- + +Add local models and cost-aware, local-first routing. + +- **Local (Ollama) provider.** Talk to a local Ollama server over its + OpenAI-compatible API (`--provider ollama`), with an idle timeout so a hung + model can't freeze the REPL, best-effort token-usage reporting, and configurable + `maxTokens`. +- **Local-first routing.** Set `routing: "local-first"` with an `escalateTo` + target to run a cheap/local model by default and escalate heavy turns (or a + stuck local model, via the new `escalate` tool) to a frontier model — with full + conversation context preserved. Escalation is sticky across follow-up turns. +- **Model-selection policy** is now owned by a pluggable `ModelDecisionEngine` + (`LocalFirstModelEngine`), keeping the agent loop pure mechanism. +- **Compute awareness.** On startup with a local model, tiny-code estimates RAM + need vs. machine capacity and warns when a model likely won't fit or is too + small (≤3B) to tool-call reliably; an over-RAM local model is routed to the + frontier up front. +- **Priority-driven model selection.** `priority` (`performance` / `cost` / + `balanced`, or `TINY_CODE_PRIORITY`) auto-picks a catalog model when none is + pinned. +- The `/costs` view reports session usage, estimated spend, and routing, and the + usage line distinguishes an unpriced *cloud* turn ("cost unknown") from a + *local* turn ("no API cost"). diff --git a/.env.example b/.env.example index fb759a3..ab7215c 100644 --- a/.env.example +++ b/.env.example @@ -1,10 +1,14 @@ -# Provide at least one. If both are present, Anthropic is used by default. +# Provide at least one for cloud providers. If both are present, Anthropic is +# the default. Ollama runs locally and needs no key. 
ANTHROPIC_API_KEY= GEMINI_API_KEY= # Optional overrides (also settable via config file / CLI flags) -# TINY_CODE_PROVIDER=anthropic # anthropic | gemini +# TINY_CODE_PROVIDER=anthropic # anthropic | gemini | ollama # TINY_CODE_MODEL=claude-opus-4-8 +# TINY_CODE_OLLAMA_URL=http://localhost:11434/v1 # Ollama OpenAI-compatible endpoint +# TINY_CODE_PRIORITY=performance # performance | cost | balanced — auto-picks a model when none is pinned +# TINY_CODE_EFFORT=high # low | medium | high | xhigh | max — Anthropic thinking budget # Self-improvement: reflect on sessions and propose markdown-only improvement PRs. # On by default; set to 0 to disable. Requires the `gh` CLI installed + authed. diff --git a/README.md b/README.md index 792fe5c..17b89cb 100644 --- a/README.md +++ b/README.md @@ -2,10 +2,13 @@ A small, extensible CLI coding agent built around one constraint: **keep token usage low**. As coding-agent costs climb, tiny-code automates the savings so -you don't have to. Interactive terminal REPL, interchangeable **Anthropic** and -**Gemini** models, and just the core features you actually use: read/write/edit -files, run shell commands, search code, and a custom commands/skills system. -No business logic baked in. +you don't have to. Interactive terminal REPL, interchangeable **Anthropic**, +**Gemini**, and **local (Ollama)** models, and just the core features you +actually use: read/write/edit files, run shell commands, search code, and a +custom commands/skills system. No business logic baked in. + +Run cheap, open-weight models locally and **escalate heavy work to a frontier +model only when needed** — see [Local models & cost-aware routing](#local-models--cost-aware-routing). > Status: early (v0.x). Published as `@therr/tiny-code`; the binary is > `tiny-code`. Names may change before the first npm publish. @@ -39,18 +42,61 @@ export GEMINI_API_KEY=... 
tiny-code # start the REPL (uses an available key) tiny-code --provider gemini # force a provider tiny-code --model claude-opus-4-8 +tiny-code --provider ollama --model gemma3:12b # run a local model (no API cost) ``` In the REPL: type a request, watch it work. Mutating actions (writes, edits, shell commands) prompt for approval unless pre-approved in config. - `/help` — list commands +- `/costs` — session token usage, estimated $ cost, and cost-saving tips - `/clear` — clear the conversation history and start fresh - `/models` — show known models, pricing, and the active one (see below) - `/improve` — reflect on the session and propose an improvement PR (see below) - `/ [args]` — run a custom command (see below) - `/exit` — quit +## Local models & cost-aware routing + +tiny-code talks to a local [Ollama](https://ollama.com) server over its +OpenAI-compatible API, so any model you've pulled is available — including +**Google Gemma 3** (`gemma3:4b`, `gemma3:12b`, `gemma3:27b`) and +`qwen2.5-coder` (the default, which tool-calls reliably). + +```bash +ollama serve +ollama pull qwen2.5-coder:7b +tiny-code --provider ollama --model qwen2.5-coder:7b +``` + +**Mind the compute cost.** Local models are free of API charges but use your +machine's RAM/VRAM. On startup with an Ollama model, tiny-code prints how much +memory the model needs versus what's free, and warns if it likely won't fit or +if the model is too small (≤3B) to tool-call reliably. 
Rough guide (≈Q4): + +| Model | ~RAM needed | Good for | +| ------------ | ----------- | --------------------------------- | +| `gemma3:1b` | ~1 GB | trivial text (poor at tool calls) | +| `gemma3:4b` | ~3 GB | lightweight edits, search | +| `gemma3:12b` | ~7 GB | most coding tasks | +| `gemma3:27b` | ~16 GB | stronger reasoning | + +**Local-first routing.** Set a `routing` of `local-first` with an `escalateTo` +target: every turn starts on the cheap/local model, and tiny-code escalates to +the frontier model when a turn looks heavy (refactors, debugging, multi-file +work) or when the local model gets stuck and calls the built-in `escalate` tool. +You get local speed and zero cost for the bulk of the work, and frontier power +only for the hard parts. Run `/costs` any time for usage, spend, and tips. + +```json +{ + "provider": "ollama", + "model": "qwen2.5-coder:7b", + "routing": "local-first", + "escalateTo": { "provider": "anthropic", "model": "claude-opus-4-8" } +} +``` + ## Project context On start, the agent walks up from the working directory looking for `AGENTS.md` @@ -86,11 +132,14 @@ CLI flags. { "provider": "anthropic", "model": "claude-opus-4-8", + "ollamaBaseUrl": "http://localhost:11434/v1", "priority": "performance", "maxTokens": 16000, "thinking": true, "effort": "high", "maxIterations": 50, + "routing": "off", + "escalateTo": { "provider": "anthropic", "model": "claude-opus-4-8" }, "allow": { "tools": [], "bash": ["npm test", "git status", "git diff"], @@ -102,6 +151,14 @@ CLI flags. `allow` pre-approves mutating actions so they skip the confirmation prompt: `bash` matches command prefixes, `write` matches path globs for write/edit. +`routing: "local-first"` plus `escalateTo` enables cost-aware routing (see +[above](#local-models--cost-aware-routing)); it defaults to `local-first` +automatically whenever `escalateTo` is present. `ollamaBaseUrl` points at your +Ollama server's OpenAI-compatible endpoint. 
+ +Approximate cloud pricing used for the `/costs` estimate lives in the model +catalog (`src/models/catalog.ts`) — edit it to match current vendor rates. + ## Token efficiency Minimizing token usage is a first-class goal — coding-agent bills grow fast, diff --git a/TODO.md b/TODO.md index d149cce..10484b0 100644 --- a/TODO.md +++ b/TODO.md @@ -23,6 +23,19 @@ Explore/Plan agent). **Approach:** a `spawn_agent` tool whose `execute` construc a child `AgentLoop` with its own message history and a read-only tool subset, returning the child's final text. Keep depth at 1 to start. +> Note: the cheap/expensive model split is now handled by **local-first +> routing** (`routing: "local-first"` + `escalateTo`): turns start on the +> local/cheap model and escalate to a frontier model when heavy or stuck (see +> `src/agent/router.ts`, `src/tools/escalate.ts`, and the loop's escalation +> logic). Sub-agents remain useful for *parallel* isolated runs. + +## More local-model interoperability +Ollama is wired in via its OpenAI-compatible endpoint (`src/providers/ollama.ts`), +which already covers LM Studio and vLLM (same wire format) by pointing +`ollamaBaseUrl`/`TINY_CODE_OLLAMA_URL` at them. **Next:** an optional +`/api/tags` probe to list locally-installed models and surface tokens/sec in the +usage line; per-model context-window awareness for the RAM advisory. + ## Web search / fetch Let the agent look up docs during a task. **Approach:** add `web_search` and `web_fetch` tools. 
For Anthropic, optionally delegate to the server-side diff --git a/package-lock.json b/package-lock.json index eb5c3b2..1685228 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "@therr/tiny-code", - "version": "0.1.0", + "version": "0.2.0", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "@therr/tiny-code", - "version": "0.1.0", + "version": "0.2.0", "license": "SEE LICENSE IN LICENSE", "dependencies": { "@anthropic-ai/sdk": "^0.69.0", diff --git a/package.json b/package.json index d3482f7..b035254 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@therr/tiny-code", - "version": "0.1.0", + "version": "0.2.0", "description": "A small, extensible CLI coding agent with interchangeable Anthropic and Gemini models.", "type": "module", "bin": { diff --git a/src/agent/decision/index.ts b/src/agent/decision/index.ts new file mode 100644 index 0000000..ea77556 --- /dev/null +++ b/src/agent/decision/index.ts @@ -0,0 +1,3 @@ +export type { ModelDecisionEngine, RouteDecision, TurnSignals } from './types.js'; +export { LocalFirstModelEngine } from './localFirst.js'; +export type { LocalFirstOptions } from './localFirst.js'; diff --git a/src/agent/decision/localFirst.ts b/src/agent/decision/localFirst.ts new file mode 100644 index 0000000..87c5d92 --- /dev/null +++ b/src/agent/decision/localFirst.ts @@ -0,0 +1,87 @@ +import type { ModelProvider } from '../../providers/types.js'; +import { classifyTurn, type TaskWeight } from '../router.js'; +import { checkLocalModel } from '../../system/resources.js'; +import { estimateCost } from '../../models/catalog.js'; +import type { ModelDecisionEngine, RouteDecision, TurnSignals } from './types.js'; + +export interface LocalFirstOptions { + /** The cheap/local model that handles turns by default. */ + primary: ModelProvider; + /** The frontier model heavy/stuck turns escalate to. */ + escalation: ModelProvider; + /** Task-weight classifier. 
Defaults to {@link classifyTurn}. */ + classify?: (input: string) => TaskWeight; + /** Consecutive tool-error iterations before auto-escalating. Defaults to 3. */ + stuckThreshold?: number; + /** RAM-fit check (injectable for tests). Defaults to {@link checkLocalModel}. */ + ramCheck?: (model: string) => { warn: boolean }; +} + +/** Per-MTok probe used to compare relative model cost for the cost-aware note. */ +const COST_PROBE = { inputTokens: 1_000_000, outputTokens: 1_000_000 }; + +/** + * The default local-first policy: handle each turn on the cheap/local model and + * escalate to the frontier model only when capability demands it — the task + * looks heavy up front, the model explicitly hands off via the `escalate` tool, + * it gets stuck on repeated tool errors, or (compute awareness) a local model + * won't fit in available RAM. + * + * Cost awareness is wired in but used defensively: the engine can *report* the + * cost implication of an escalation (see {@link costNote}) but never initiates + * a cost-driven route change — escalation is always a capability decision. + */ +export class LocalFirstModelEngine implements ModelDecisionEngine { + private readonly primary: ModelProvider; + private readonly escalation: ModelProvider; + private readonly classify: (input: string) => TaskWeight; + private readonly stuckThreshold: number; + private readonly ramCheck: (model: string) => { warn: boolean }; + + constructor(opts: LocalFirstOptions) { + this.primary = opts.primary; + this.escalation = opts.escalation; + this.classify = opts.classify ?? classifyTurn; + this.stuckThreshold = opts.stuckThreshold ?? 3; + this.ramCheck = opts.ramCheck ?? checkLocalModel; + } + + selectInitial(userInput: string): RouteDecision { + // Compute awareness: a local primary that won't fit in RAM runs slowly or + // fails outright, so route it to the frontier up front — even for light work. 
+ if (this.primary.name === 'ollama' && this.ramCheck(this.primary.model).warn) { + return { provider: this.escalation, reason: 'compute: local model exceeds RAM' }; + } + if (this.classify(userInput) === 'heavy') { + return { provider: this.escalation, reason: 'heavy task' }; + } + return { provider: this.primary }; + } + + considerEscalation(signals: TurnSignals): RouteDecision | undefined { + if (signals.alreadyEscalated) return undefined; + if (signals.escalateRequested) { + return { provider: this.escalation, reason: 'requested by model' }; + } + if (signals.consecutiveErrors >= this.stuckThreshold) { + return { provider: this.escalation, reason: 'stuck — repeated tool errors' }; + } + return undefined; + } + + /** + * A short, human-readable summary of what escalating costs, relative to the + * primary. Pure reporting — does not influence routing. Local models report + * as having no API cost. + */ + costNote(): string { + const from = estimateCost(this.primary.model, COST_PROBE); + const to = estimateCost(this.escalation.model, COST_PROBE); + if (to === null) return 'escalates to a local model (no API cost)'; + if (from === null || from === 0) { + return 'escalates from a free/local model to a paid model (adds API cost)'; + } + const multiplier = to / from; + return `escalation costs ~${multiplier.toFixed(1)}× the primary per token`; + } +} diff --git a/src/agent/decision/types.ts b/src/agent/decision/types.ts new file mode 100644 index 0000000..c0cd976 --- /dev/null +++ b/src/agent/decision/types.ts @@ -0,0 +1,47 @@ +import type { ModelProvider } from '../../providers/types.js'; + +/** + * A model-selection decision: the provider to use and an optional human-readable + * reason. An empty/absent `reason` means "no route note" — the loop only fires + * {@link AgentUI.onRoute} when a decision both changes the active provider and + * carries a reason. 
+ */ +export interface RouteDecision { + provider: ModelProvider; + reason?: string; +} + +/** + * Runtime signals the loop feeds the engine once per iteration so it can decide + * whether to switch providers mid-turn. These are mechanical bookkeeping the + * loop already tracks; the engine owns the policy that interprets them. + */ +export interface TurnSignals { + /** The model called the `escalate` tool during this iteration. */ + escalateRequested: boolean; + /** Consecutive iterations that ended with at least one tool error. */ + consecutiveErrors: number; + /** Whether the turn has already been escalated (keeps the engine stateless). */ + alreadyEscalated: boolean; + /** The provider currently handling the turn. */ + current: ModelProvider; + /** Iteration index (0-based). */ + iteration: number; +} + +/** + * Owns all model-selection policy. {@link AgentLoop} depends on this interface + * instead of inlining classification + escalation rules, so the loop keeps only + * mechanism (send → stream → run tools → repeat) and the policy lives in one + * cohesive, testable place. Implementations may reason about task weight, cost, + * and compute (RAM) fit; the loop never sees that reasoning. + */ +export interface ModelDecisionEngine { + /** Pick the provider that starts a turn, given the user's input. */ + selectInitial(userInput: string): RouteDecision; + /** + * Decide whether to switch providers mid-turn. Returns `undefined` to stay on + * the current provider. Called once per iteration after tool results. 
+ */ + considerEscalation(signals: TurnSignals): RouteDecision | undefined; +} diff --git a/src/agent/loop.ts b/src/agent/loop.ts index 9610314..d57570b 100644 --- a/src/agent/loop.ts +++ b/src/agent/loop.ts @@ -3,6 +3,7 @@ import type { ToolRegistry } from '../tools/registry.js'; import type { PermissionGate } from '../permissions/gate.js'; import type { ToolResult } from '../tools/types.js'; import type { Message, ToolResultBlock, ToolUseBlock } from './types.js'; +import type { ModelDecisionEngine, RouteDecision } from './decision/index.js'; export type { Usage }; @@ -12,7 +13,19 @@ export interface AgentUI { onToolStart(name: string, input: unknown): void; onToolResult(name: string, result: ToolResult): void; onToolDenied(name: string): void; - onUsage(usage: Usage): void; + /** + * `model` and `provider` identify what produced the usage, so the UI can + * price it accurately — and tell an unpriced *cloud* turn (cost unknown) apart + * from a *local* turn (no API cost). + */ + onUsage(usage: Usage, model?: string, provider?: string): void; + /** + * Fired when local-first routing sends a turn to the frontier model. `initial` + * is true when the turn *started* there (up-front classification), false when + * it was escalated mid-turn — so the UI doesn't claim "escalated" for a turn + * that never ran locally. + */ + onRoute(provider: string, model: string, reason: string, initial?: boolean): void; onAssistantEnd(): void; onMaxIterations(): void; } @@ -25,6 +38,11 @@ export interface AgentLoopOptions { ui: AgentUI; cwd: string; maxIterations?: number; + /** + * Owns model-selection policy (which model starts a turn, when to escalate). + * When omitted, the loop runs `provider` as a single provider with no routing. 
+ */ + engine?: ModelDecisionEngine | undefined; } /** @@ -40,8 +58,17 @@ export class AgentLoop { private readonly ui: AgentUI; private readonly cwd: string; private readonly maxIterations: number; + private readonly engine: ModelDecisionEngine | undefined; private readonly messages: Message[] = []; private sessionUsage: Usage = { inputTokens: 0, outputTokens: 0 }; + /** + * Set once a turn escalates mid-flight (model request or stuck). Subsequent + * turns then start on the frontier model so a multi-turn hard task doesn't + * ping-pong back to the local model on each follow-up. Reset by clearHistory(). + */ + private escalatedSession = false; + /** The provider escalated to; reused to start follow-up turns once sticky. */ + private escalatedProvider: ModelProvider | undefined; constructor(opts: AgentLoopOptions) { this.provider = opts.provider; @@ -51,6 +78,7 @@ export class AgentLoop { this.ui = opts.ui; this.cwd = opts.cwd; this.maxIterations = opts.maxIterations ?? 50; + this.engine = opts.engine; } /** Conversation history (for inspection / persistence). */ @@ -59,9 +87,12 @@ export class AgentLoop { } /** Drop the conversation history so the next turn starts fresh. Cumulative - * token usage is preserved, since it reflects the whole session's cost. */ + * token usage is preserved, since it reflects the whole session's cost. + * Also clears sticky escalation: a fresh conversation re-routes from scratch. */ clearHistory(): void { this.messages.length = 0; + this.escalatedSession = false; + this.escalatedProvider = undefined; } /** Cumulative token usage across all turns in this session. 
*/ @@ -74,11 +105,28 @@ export class AgentLoop { this.messages.push({ role: 'user', content: [{ type: 'text', text: userInput }] }); const tools = this.registry.toSchemas(); + let active = this.provider; + let escalated = false; + let consecutiveErrors = 0; + + if (this.engine) { + if (this.escalatedSession && this.escalatedProvider) { + // A prior turn escalated; stay on the frontier model for follow-ups so a + // multi-turn hard task doesn't ping-pong back to the local model. + active = this.escalatedProvider; + escalated = true; + } else { + const initial = this.engine.selectInitial(userInput); + active = this.applyRoute(active, initial, true); + escalated = active !== this.provider; + } + } + for (let iteration = 0; iteration < this.maxIterations; iteration += 1) { let text = ''; const toolCalls: ToolUseBlock[] = []; - for await (const event of this.provider.send({ + for await (const event of active.send({ system: this.system, messages: [...this.messages], tools, @@ -91,7 +139,7 @@ export class AgentLoop { } else { this.sessionUsage.inputTokens += event.usage.inputTokens; this.sessionUsage.outputTokens += event.usage.outputTokens; - this.ui.onUsage(event.usage); + this.ui.onUsage(event.usage, active.model, active.name); } } @@ -104,16 +152,61 @@ export class AgentLoop { if (toolCalls.length === 0) return; + const escalateRequested = toolCalls.some((c) => c.name === 'escalate'); + const results: ToolResultBlock[] = []; + let anyError = false; for (const call of toolCalls) { - results.push(await this.executeToolCall(call)); + const result = await this.executeToolCall(call); + if (result.isError) anyError = true; + results.push(result); } this.messages.push({ role: 'user', content: results }); + + // Hand the turn's runtime signals to the engine, which owns the policy + // for whether to switch providers mid-turn (explicit escalate, stuck, …). + consecutiveErrors = anyError ? 
consecutiveErrors + 1 : 0; + if (this.engine) { + const next = this.engine.considerEscalation({ + escalateRequested, + consecutiveErrors, + alreadyEscalated: escalated, + current: active, + iteration, + }); + if (next) { + active = this.applyRoute(active, next); + escalated = active !== this.provider; + if (escalated) { + // Remember it so follow-up turns start here (sticky escalation). + this.escalatedSession = true; + this.escalatedProvider = active; + } + } + } } this.ui.onMaxIterations(); } + /** + * Apply a routing decision: switch to its provider and surface an `onRoute` + * event when it actually changes the active provider and carries a reason. + * `initial` distinguishes up-front routing ("routed to") from a mid-turn + * hand-off ("escalated to") so the UI never claims a turn was escalated when + * it started on the frontier model. + */ + private applyRoute( + active: ModelProvider, + decision: RouteDecision, + initial = false, + ): ModelProvider { + if (decision.provider !== active && decision.reason) { + this.ui.onRoute(decision.provider.name, decision.provider.model, decision.reason, initial); + } + return decision.provider; + } + private async executeToolCall(call: ToolUseBlock): Promise<ToolResultBlock> { const tool = this.registry.get(call.name); if (!tool) { diff --git a/src/agent/router.ts b/src/agent/router.ts new file mode 100644 index 0000000..b3b2871 --- /dev/null +++ b/src/agent/router.ts @@ -0,0 +1,55 @@ +/** + * Lightweight, dependency-free task classification for local-first routing. + * + * The cheap/local model handles each turn by default; this heuristic flags the + * turns that are better started on the frontier model. It is intentionally + * conservative: when in doubt it returns 'light' and lets the local model + * escalate explicitly (via the `escalate` tool) if it gets stuck. + */ +export type TaskWeight = 'light' | 'heavy'; + +/** + * Strong, unambiguous signals that a turn genuinely needs the frontier model. 
+ * These rarely show up in routine one-line requests. + */ +const HEAVY_PATTERNS: RegExp[] = [ + /\brefactor(?:ing|ed)?\b/i, + /\barchitect(?:ure|ural)?\b/i, + /\bmigrat(?:e|ion|ing)\b/i, + /\bredesign\b/i, + /\broot[- ]?cause\b/i, + /\bthink (?:hard|carefully|through|deeply)\b/i, + /\bacross (?:the |multiple |several )?(?:files|modules|codebase)\b/i, + /\bend[- ]to[- ]end\b/i, +]; + +/** + * Verbs that signal a heavy task only when paired with a scope/complexity cue. + * On their own — "implement a getter", "debug this typo", "optimize the loop" — + * they're everyday coding and stay local; eagerly escalating them would blunt + * the local-first cost savings. The local model can still escalate itself via + * the `escalate` tool when it actually struggles. + */ +const AMBIGUOUS_VERBS = /\b(?:implement(?:s|ing|ed)?|debug(?:ging|ged)?|optimi[sz]e|design)\b/i; +const SCOPE_CUES = + /\b(?:entire|whole|complete(?:ly)?|across|multiple|several|system|subsystem|pipeline|codebase|module|from scratch)\b/i; + +/** Number of file-path-looking tokens above which a turn is considered heavy. */ +const MULTI_FILE_THRESHOLD = 3; +/** Character length above which a turn is considered heavy. */ +const LONG_INPUT_CHARS = 600; + +/** Classify a user turn as 'light' (local) or 'heavy' (escalate to frontier). */ +export function classifyTurn(input: string): TaskWeight { + const text = input.trim(); + if (text.length >= LONG_INPUT_CHARS) return 'heavy'; + if (HEAVY_PATTERNS.some((re) => re.test(text))) return 'heavy'; + + const fileMentions = text.match(/[\w./-]+\.[a-z]{1,5}\b/gi) ?? []; + if (fileMentions.length >= MULTI_FILE_THRESHOLD) return 'heavy'; + + // Ambiguous verbs escalate only alongside a scope/complexity cue. 
+ if (AMBIGUOUS_VERBS.test(text) && SCOPE_CUES.test(text)) return 'heavy'; + + return 'light'; +} diff --git a/src/agent/systemPrompt.ts b/src/agent/systemPrompt.ts index 2713379..bfd9afd 100644 --- a/src/agent/systemPrompt.ts +++ b/src/agent/systemPrompt.ts @@ -4,8 +4,12 @@ export interface SystemPromptParams { cwd: string; projectContext: string; tools: ToolSchema[]; + /** When true, this model is the cheap/local model in a local-first setup. */ + escalation?: boolean; } +const ESCALATION_GUIDANCE = `Cost-aware routing: you are running as a fast, low-cost model. Handle routine work yourself — reading, searching, listing, and small, well-scoped edits. If a task needs deep reasoning, a large or multi-file refactor, tricky debugging, or you find yourself stuck or uncertain, call the \`escalate\` tool with a brief reason to hand off to a more capable model. Prefer escalating early over guessing.`; + const BASE_PERSONA = `You are a precise, autonomous coding agent operating in a terminal. Guidelines: @@ -27,6 +31,10 @@ export function buildSystemPrompt(params: SystemPromptParams): string { `Available tools:\n${toolList}`, ]; + if (params.escalation) { + sections.push(ESCALATION_GUIDANCE); + } + if (params.projectContext.trim().length > 0) { sections.push( `Project-specific instructions (from the project's context file):\n\n${params.projectContext.trim()}`, diff --git a/src/cli.ts b/src/cli.ts index 2ed0237..19112bb 100644 --- a/src/cli.ts +++ b/src/cli.ts @@ -12,8 +12,8 @@ Usage: tiny-code [options] Options: - --provider anthropic | gemini (default: inferred from API keys) - --model Model id override + --provider anthropic | gemini | ollama (default: inferred from API keys) + --model Model id override (e.g. 
claude-opus-4-8, gemma3:12b) --config Path to a config JSON file -v, --version Print version -h, --help Show this help @@ -21,8 +21,13 @@ Options: Environment: ANTHROPIC_API_KEY Required for the Anthropic provider GEMINI_API_KEY Required for the Gemini provider + TINY_CODE_OLLAMA_URL Ollama OpenAI-compatible base URL (default http://localhost:11434/v1) TINY_CODE_PRIORITY performance | cost | balanced — auto-picks a model when none is pinned (default: performance) + +Cost-saving: set "routing": "local-first" with an "escalateTo" target in your +config to run cheap/local models by default and escalate heavy tasks. Run /costs +in the session for usage and tips. `; function main(): void { diff --git a/src/config/load.ts b/src/config/load.ts index f1f1ed0..4d25096 100644 --- a/src/config/load.ts +++ b/src/config/load.ts @@ -5,10 +5,18 @@ import { z } from 'zod'; import type { Priority } from '../models/catalog.js'; import { recommendModel } from '../models/catalog.js'; -export type Provider = 'anthropic' | 'gemini'; +export type Provider = 'anthropic' | 'gemini' | 'ollama'; export type Effort = 'low' | 'medium' | 'high' | 'xhigh' | 'max'; +export type Routing = 'local-first' | 'off'; export type { Priority } from '../models/catalog.js'; +/** A frontier model to escalate heavy tasks to under local-first routing. */ +export interface EscalateTarget { + provider: Provider; + model: string; + ollamaBaseUrl?: string | undefined; +} + /** Auto-approval rules that bypass the interactive permission prompt. */ export interface AllowRules { /** Tool names that never prompt (in addition to read-only tools). */ @@ -26,10 +34,16 @@ export interface ResolvedConfig { priority: Priority; anthropicApiKey: string | undefined; geminiApiKey: string | undefined; + /** OpenAI-compatible base URL for the Ollama provider. 
*/ + ollamaBaseUrl: string; maxTokens: number; thinking: boolean; effort: Effort; maxIterations: number; + /** 'local-first' starts turns on the cheap model and escalates heavy ones. */ + routing: Routing; + /** Frontier model heavy tasks escalate to (only used when routing is 'local-first'). */ + escalateTo: EscalateTarget | undefined; commandDirs: string[]; allow: AllowRules; improve: ImproveConfig; @@ -54,17 +68,29 @@ export interface CliOverrides { const DEFAULT_MODELS: Record<Provider, string> = { anthropic: 'claude-opus-4-8', gemini: 'gemini-2.5-pro', + ollama: 'qwen2.5-coder:7b', }; +const DEFAULT_OLLAMA_URL = 'http://localhost:11434/v1'; + +const EscalateTargetSchema = z.object({ + provider: z.enum(['anthropic', 'gemini', 'ollama']), + model: z.string(), + ollamaBaseUrl: z.string().url().optional(), +}); + const FileConfigSchema = z .object({ - provider: z.enum(['anthropic', 'gemini']).optional(), + provider: z.enum(['anthropic', 'gemini', 'ollama']).optional(), model: z.string().optional(), + ollamaBaseUrl: z.string().url().optional(), priority: z.enum(['performance', 'cost', 'balanced']).optional(), maxTokens: z.number().int().positive().optional(), thinking: z.boolean().optional(), effort: z.enum(['low', 'medium', 'high', 'xhigh', 'max']).optional(), maxIterations: z.number().int().positive().optional(), + routing: z.enum(['local-first', 'off']).optional(), + escalateTo: EscalateTargetSchema.optional(), commandDirs: z.array(z.string()).optional(), allow: z .object({ @@ -131,6 +157,12 @@ export function loadConfig(overrides: CliOverrides = {}, cwd: string = process.c const effort = (env.TINY_CODE_EFFORT as Effort | undefined) ?? file.effort ?? 'high'; + const ollamaBaseUrl = env.TINY_CODE_OLLAMA_URL ?? file.ollamaBaseUrl ?? DEFAULT_OLLAMA_URL; + + const escalateTo = file.escalateTo; + // Default to local-first whenever an escalation target is configured. + const routing: Routing = file.routing ?? (escalateTo ? 
'local-first' : 'off'); + const defaultCommandDirs = [ join(cwd, '.agent', 'commands'), join(home, '.config', 'tiny-code', 'commands'), @@ -142,10 +174,13 @@ export function loadConfig(overrides: CliOverrides = {}, cwd: string = process.c priority, anthropicApiKey, geminiApiKey, + ollamaBaseUrl, maxTokens, thinking: file.thinking ?? true, effort, maxIterations: file.maxIterations ?? 50, + routing, + escalateTo, commandDirs: file.commandDirs ?? defaultCommandDirs, allow: { tools: file.allow?.tools ?? [], diff --git a/src/index.ts b/src/index.ts index df727ba..d10bb65 100644 --- a/src/index.ts +++ b/src/index.ts @@ -8,25 +8,35 @@ export type { AgentUI, AgentLoopOptions } from './agent/loop.js'; export { buildSystemPrompt } from './agent/systemPrompt.js'; export type { SystemPromptParams } from './agent/systemPrompt.js'; -export { createProvider, AnthropicProvider, GeminiProvider } from './providers/index.js'; +export { createProvider, AnthropicProvider, GeminiProvider, OllamaProvider } from './providers/index.js'; export type { ModelProvider, ProviderEvent, SendRequest, ToolSchema, Usage } from './providers/types.js'; +export { toOpenAiMessages, toOpenAiTools } from './providers/ollama.js'; + +export { classifyTurn } from './agent/router.js'; +export type { TaskWeight } from './agent/router.js'; +export { LocalFirstModelEngine } from './agent/decision/index.js'; +export type { ModelDecisionEngine, RouteDecision, TurnSignals, LocalFirstOptions } from './agent/decision/index.js'; +export { checkLocalModel, estimateModelRamGb, MODEL_RAM_GB } from './system/resources.js'; +export type { LocalModelCheck } from './system/resources.js'; export { ALL_TOOLS, createRegistry, toJsonSchema } from './tools/registry.js'; export type { ToolRegistry } from './tools/registry.js'; export { defineTool } from './tools/types.js'; +export { escalateTool } from './tools/escalate.js'; export type { Tool, ToolContext, ToolResult } from './tools/types.js'; export { PermissionGate } from 
'./permissions/gate.js'; export type { PermissionPrompt, PermissionRequest, PermissionChoice } from './permissions/gate.js'; export { loadConfig } from './config/load.js'; -export type { ResolvedConfig, CliOverrides, Provider, Effort, Priority, AllowRules } from './config/load.js'; +export type { ResolvedConfig, CliOverrides, Provider, Effort, Priority, AllowRules, Routing, EscalateTarget } from './config/load.js'; export { MODEL_CATALOG, CATALOG_AS_OF, getModelInfo, estimateCostUsd, + estimateCost, formatUsd, blendedCostPerMTok, recommendModel, diff --git a/src/models/catalog.ts b/src/models/catalog.ts index a5a9c98..b7f1a94 100644 --- a/src/models/catalog.ts +++ b/src/models/catalog.ts @@ -66,6 +66,16 @@ export function estimateCostUsd(usage: Usage, info: ModelInfo): number { ); } +/** + * Estimate the USD cost of a token usage for a model id, or `null` when the + * model isn't in the catalog — e.g. a local/Ollama model that has no API price. + * A `null` means "no known price", not "free"; callers decide how to present it. + */ +export function estimateCost(modelId: string, usage: Usage): number | null { + const info = getModelInfo(modelId); + return info ? estimateCostUsd(usage, info) : null; +} + /** Format a USD amount with precision that stays readable for tiny costs. */ export function formatUsd(amount: number): string { return `$${amount.toFixed(amount < 1 ? 
4 : 2)}`; diff --git a/src/providers/index.ts b/src/providers/index.ts index 3ac08ec..89c6b3f 100644 --- a/src/providers/index.ts +++ b/src/providers/index.ts @@ -2,10 +2,12 @@ import type { ModelProvider } from './types.js'; import type { ResolvedConfig } from '../config/load.js'; import { AnthropicProvider } from './anthropic.js'; import { GeminiProvider } from './gemini.js'; +import { OllamaProvider } from './ollama.js'; export type { ModelProvider, ProviderEvent, SendRequest, ToolSchema, Usage } from './types.js'; export { AnthropicProvider } from './anthropic.js'; export { GeminiProvider } from './gemini.js'; +export { OllamaProvider } from './ollama.js'; /** Construct the configured provider, validating that its API key is present. */ export function createProvider(config: ResolvedConfig): ModelProvider { @@ -22,6 +24,15 @@ export function createProvider(config: ResolvedConfig): ModelProvider { }); } + if (config.provider === 'ollama') { + // No API key required — Ollama runs locally. + return new OllamaProvider({ + baseUrl: config.ollamaBaseUrl, + model: config.model, + maxTokens: config.maxTokens, + }); + } + if (!config.geminiApiKey) { throw new Error('GEMINI_API_KEY is not set. Export it or switch providers with --provider anthropic.'); } diff --git a/src/providers/ollama.ts b/src/providers/ollama.ts new file mode 100644 index 0000000..79f0ca8 --- /dev/null +++ b/src/providers/ollama.ts @@ -0,0 +1,272 @@ +import type { Message } from '../agent/types.js'; +import type { ModelProvider, ProviderEvent, SendRequest, ToolSchema } from './types.js'; + +export interface OllamaProviderOptions { + /** OpenAI-compatible base URL, e.g. "http://localhost:11434/v1". */ + baseUrl: string; + model: string; + /** Ignored by Ollama but required by the OpenAI wire format; defaults to "ollama". */ + apiKey?: string; + /** Cap on tokens to generate per response. Omitted from the request if unset. 
*/ + maxTokens?: number; + /** + * Abort the request if no bytes arrive for this long (ms). This is an *idle* + * timeout, reset on every received chunk — a slow-but-progressing model keeps + * going; a hung one (common when the machine is RAM-starved) is cut loose. + * Defaults to 120_000. + */ + timeoutMs?: number; +} + +interface OpenAiMessage { + role: 'system' | 'user' | 'assistant' | 'tool'; + content: string; + tool_calls?: { id: string; type: 'function'; function: { name: string; arguments: string } }[]; + tool_call_id?: string; +} + +/** + * Translate internal messages into OpenAI chat messages (the shape Ollama's + * `/v1/chat/completions` endpoint accepts). Unlike Gemini, OpenAI correlates + * tool results to calls by `tool_call_id`, and our Anthropic-style ids survive + * the round trip — so no id synthesis is needed. + * + * Assumes the loop never mixes plain text and tool results in one user turn in a + * way that would interleave them: we emit all `tool` messages first, then any + * text as a trailing user message. OpenAI requires each `tool` message to follow + * the assistant `tool_calls` that produced it; today's loop builds messages so + * that holds. If a future change interleaves them, revisit this ordering. + */ +export function toOpenAiMessages(messages: Message[]): OpenAiMessage[] { + const out: OpenAiMessage[] = []; + for (const m of messages) { + if (m.role === 'user') { + // A user turn may carry plain text and/or tool results; emit each result + // as its own `tool` message and gather any text into one user message. 
+ let text = ''; + for (const b of m.content) { + if (b.type === 'text') text += b.text; + else if (b.type === 'tool_result') { + out.push({ role: 'tool', tool_call_id: b.toolUseId, content: b.content }); + } + } + if (text.length > 0) out.push({ role: 'user', content: text }); + continue; + } + + // assistant: merge text + tool_use into a single message + let text = ''; + const toolCalls: NonNullable<OpenAiMessage['tool_calls']> = []; + for (const b of m.content) { + if (b.type === 'text') text += b.text; + else if (b.type === 'tool_use') { + toolCalls.push({ + id: b.id, + type: 'function', + function: { name: b.name, arguments: JSON.stringify(b.input ?? {}) }, + }); + } + } + const msg: OpenAiMessage = { role: 'assistant', content: text }; + if (toolCalls.length > 0) msg.tool_calls = toolCalls; + out.push(msg); + } + return out; +} + +/** Translate normalized tool schemas into OpenAI's `tools` array. */ +export function toOpenAiTools(tools: ToolSchema[]): unknown[] { + return tools.map((t) => ({ + type: 'function', + function: { name: t.name, description: t.description, parameters: t.jsonSchema }, + })); +} + +interface StreamChoice { + delta?: { + content?: string | null; + tool_calls?: { + index: number; + id?: string; + function?: { name?: string; arguments?: string }; + }[]; + }; + finish_reason?: string | null; +} + +interface StreamChunk { + choices?: StreamChoice[]; + usage?: { prompt_tokens?: number; completion_tokens?: number } | null; +} + +export class OllamaProvider implements ModelProvider { + readonly name = 'ollama' as const; + readonly model: string; + private readonly baseUrl: string; + private readonly apiKey: string; + private readonly maxTokens: number | undefined; + private readonly timeoutMs: number; + + constructor(opts: OllamaProviderOptions) { + this.baseUrl = opts.baseUrl.replace(/\/$/, ''); + this.model = opts.model; + this.apiKey = opts.apiKey ?? 'ollama'; + this.maxTokens = opts.maxTokens; + this.timeoutMs = opts.timeoutMs ??
120_000; + } + + async *send(req: SendRequest): AsyncIterable<ProviderEvent> { + const messages: OpenAiMessage[] = [ + { role: 'system', content: req.system }, + ...toOpenAiMessages(req.messages), + ]; + + const body = { + model: this.model, + messages, + tools: req.tools.length > 0 ? toOpenAiTools(req.tools) : undefined, + stream: true, + max_tokens: this.maxTokens, + }; + + // Idle-timeout guard: abort if the server goes silent for `timeoutMs`. The + // raw fetch (unlike the cloud SDKs) has no built-in timeout, so without this + // a stuck local model would freeze the REPL with no way to recover. + const controller = new AbortController(); + let timer: ReturnType<typeof setTimeout>; + const armTimer = (): void => { + clearTimeout(timer); + timer = setTimeout(() => controller.abort(), this.timeoutMs); + }; + armTimer(); + + try { + let res: Response; + try { + // `stream_options.include_usage` is best-effort: it gives us token counts, + // but older Ollama builds reject unknown body fields with a 400. Rather than + // breaking every local turn over a reporting nicety, retry once without it. + res = await this.post({ ...body, stream_options: { include_usage: true } }, controller.signal); + if (res.status === 400) res = await this.post(body, controller.signal); + } catch (err) { + if (controller.signal.aborted) throw this.timeoutError(); + throw new Error( + `Cannot reach Ollama at ${this.baseUrl}. Is 'ollama serve' running? (${(err as Error).message})`, + ); + } + + if (!res.ok || !res.body) { + const detail = await res.text().catch(() => ''); + throw new Error(`Ollama request failed (${res.status}): ${detail.slice(0, 200)}`); + } + + // Accumulate tool calls by their streamed index; arguments arrive in fragments.
+ const calls = new Map<number, { id: string; name: string; args: string }>(); + let usage = { inputTokens: 0, outputTokens: 0 }; + let finish = 'stop'; + + try { + for await (const chunk of parseSse(res.body)) { + armTimer(); // progress: reset the idle clock + const choice = chunk.choices?.[0]; + if (choice?.delta?.content) yield { type: 'text', delta: choice.delta.content }; + + for (const tc of choice?.delta?.tool_calls ?? []) { + const acc = calls.get(tc.index) ?? { id: '', name: '', args: '' }; + if (tc.id) acc.id = tc.id; + if (tc.function?.name) acc.name = tc.function.name; + if (tc.function?.arguments) acc.args += tc.function.arguments; + calls.set(tc.index, acc); + } + + if (choice?.finish_reason) finish = choice.finish_reason; + if (chunk.usage) { + usage = { + inputTokens: chunk.usage.prompt_tokens ?? 0, + outputTokens: chunk.usage.completion_tokens ?? 0, + }; + } + } + } catch (err) { + if (controller.signal.aborted) throw this.timeoutError(); + throw err; + } + + for (const [index, c] of [...calls.entries()].sort((a, b) => a[0] - b[0])) { + let input: unknown = {}; + try { + input = c.args.trim() ? JSON.parse(c.args) : {}; + } catch { + // Small models occasionally emit malformed JSON; degrade gracefully. + input = {}; + } + yield { type: 'tool_call', id: c.id || `ollama-call-${index}`, name: c.name, input }; + } + + yield { + type: 'done', + usage, + stopReason: calls.size > 0 ? 'tool_use' : finish, + }; + } finally { + clearTimeout(timer!); + } + } + + private timeoutError(): Error { + return new Error( + `Ollama at ${this.baseUrl} went silent for ${Math.round(this.timeoutMs / 1000)}s and was aborted. ` + + `The model '${this.model}' may be too large for this machine.`, + ); + } + + /** POST a chat-completions request body to the Ollama server.
*/ + private post(body: unknown, signal: AbortSignal): Promise<Response> { + return fetch(`${this.baseUrl}/chat/completions`, { + method: 'POST', + headers: { 'Content-Type': 'application/json', Authorization: `Bearer ${this.apiKey}` }, + body: JSON.stringify(body), + signal, + }); + } +} + +/** Decode a single SSE line into a chunk, or `undefined` for non-data/keep-alive lines. */ +function parseSseLine(raw: string): StreamChunk | undefined { + const line = raw.trim(); + if (!line.startsWith('data:')) return undefined; + const payload = line.slice(5).trim(); + if (payload === '[DONE]' || payload.length === 0) return undefined; + try { + return JSON.parse(payload) as StreamChunk; + } catch { + // Ignore partial/non-JSON keep-alive lines. + return undefined; + } +} + +/** Parse an SSE byte stream into decoded JSON chunks, skipping the `[DONE]` sentinel. */ +async function* parseSse(body: ReadableStream<Uint8Array>): AsyncIterable<StreamChunk> { + const decoder = new TextDecoder(); + let buffer = ''; + const reader = body.getReader(); + try { + for (;;) { + const { done, value } = await reader.read(); + if (done) break; + buffer += decoder.decode(value, { stream: true }); + let nl: number; + while ((nl = buffer.indexOf('\n')) !== -1) { + const chunk = parseSseLine(buffer.slice(0, nl)); + buffer = buffer.slice(nl + 1); + if (chunk) yield chunk; + } + } + // Emit a final line that arrived without a trailing newline (e.g. a closing + // usage frame); otherwise the last chunk's token counts would be dropped. + const tail = parseSseLine(buffer); + if (tail) yield tail; + } finally { + reader.releaseLock(); + } +} diff --git a/src/providers/types.ts b/src/providers/types.ts index 45262f3..c18443e 100644 --- a/src/providers/types.ts +++ b/src/providers/types.ts @@ -34,7 +34,7 @@ export interface SendRequest { * {@link ProviderEvent}.
*/ export interface ModelProvider { - readonly name: 'anthropic' | 'gemini'; + readonly name: 'anthropic' | 'gemini' | 'ollama'; readonly model: string; send(req: SendRequest): AsyncIterable; } diff --git a/src/repl.ts b/src/repl.ts index bbb4500..41f9b71 100644 --- a/src/repl.ts +++ b/src/repl.ts @@ -1,13 +1,18 @@ import * as readline from 'node:readline'; import pc from 'picocolors'; import { createTerminalUI } from './ui/render.js'; +import type { TerminalUI } from './ui/render.js'; import { AgentLoop } from './agent/loop.js'; import { PermissionGate } from './permissions/gate.js'; import type { PermissionPrompt } from './permissions/gate.js'; -import { createRegistry } from './tools/registry.js'; +import { ALL_TOOLS, createRegistry } from './tools/registry.js'; +import { escalateTool } from './tools/escalate.js'; import { createProvider } from './providers/index.js'; +import { LocalFirstModelEngine } from './agent/decision/index.js'; +import type { ModelDecisionEngine } from './agent/decision/index.js'; +import { checkLocalModel } from './system/resources.js'; import { loadConfig } from './config/load.js'; -import type { CliOverrides } from './config/load.js'; +import type { CliOverrides, ResolvedConfig } from './config/load.js'; import { loadProjectContext } from './config/context.js'; import { buildSystemPrompt } from './agent/systemPrompt.js'; import { loadCommands, renderCommand } from './commands/loader.js'; @@ -23,9 +28,32 @@ import { } from './models/catalog.js'; import type { Usage } from './providers/types.js'; +const COST_TIPS = [ + 'Let the local model handle searches, listing, and small edits; save the frontier model for heavy lifting.', + 'Keep requests focused — narrow context means fewer input tokens.', + 'For big refactors or tricky bugs, let routing escalate rather than forcing the local model.', + 'Use smaller models (e.g. 
gemma3:4b, qwen2.5-coder:7b) for boilerplate; reserve 12B+ for reasoning.', + 'Lower the Anthropic `effort` setting for simple tasks to cut output tokens.', +]; + +function printCosts(ui: TerminalUI, config: ResolvedConfig): void { + const t = ui.getTotals(); + console.log(pc.bold('\nSession usage:')); + console.log(` Tokens ${t.inputTokens} in / ${t.outputTokens} out`); + console.log(` Est cost ${formatUsd(t.cost)} (cloud turns only; local models are free)`); + const routing = + config.routing === 'local-first' && config.escalateTo + ? `local-first · ${config.provider}:${config.model} → ${config.escalateTo.provider}:${config.escalateTo.model}` + : `${config.provider}:${config.model}`; + console.log(` Routing ${routing}`); + console.log(pc.bold('\nTips to cut cost:')); + for (const tip of COST_TIPS) console.log(` • ${pc.dim(tip)}`); +} + function printHelp(commands: Map): void { console.log(pc.bold('\nBuilt-in:')); console.log(' /help Show this help'); + console.log(' /costs Show token usage, est. cost, and cost-saving tips'); console.log(' /clear Clear the conversation history and start fresh'); console.log(' /models Show known models, pricing, and the active one'); console.log(' /improve Reflect on this session and propose an improvement PR'); @@ -72,9 +100,32 @@ export async function startRepl(overrides: CliOverrides): Promise { const cwd = process.cwd(); const config = loadConfig(overrides, cwd); const provider = createProvider(config); // throws with a clear message if the API key is missing - const registry = createRegistry(); + + // Local-first routing: build the frontier provider and expose the `escalate` tool. + const localFirst = config.routing === 'local-first' && config.escalateTo !== undefined; + const escalationProvider = localFirst + ? createProvider({ + ...config, + provider: config.escalateTo!.provider, + model: config.escalateTo!.model, + ollamaBaseUrl: config.escalateTo!.ollamaBaseUrl ?? 
config.ollamaBaseUrl, + }) + : undefined; + + // The decision engine owns all model-selection policy; the loop only runs it. + const engine: ModelDecisionEngine | undefined = + localFirst && escalationProvider + ? new LocalFirstModelEngine({ primary: provider, escalation: escalationProvider }) + : undefined; + + const registry = createRegistry(localFirst ? [...ALL_TOOLS, escalateTool] : undefined); const projectContext = loadProjectContext(cwd); - const system = buildSystemPrompt({ cwd, projectContext, tools: registry.toSchemas() }); + const system = buildSystemPrompt({ + cwd, + projectContext, + tools: registry.toSchemas(), + escalation: localFirst, + }); const commands = loadCommands(config.commandDirs); const rl = readline.createInterface({ input: process.stdin, output: process.stdout }); @@ -92,7 +143,7 @@ export async function startRepl(overrides: CliOverrides): Promise { const gate = new PermissionGate(config.allow, prompt); const modelInfo = getModelInfo(config.model); - const ui = createTerminalUI(modelInfo); + const ui = createTerminalUI({ model: provider.model, provider: provider.name }); const agent = new AgentLoop({ provider, registry, @@ -101,6 +152,7 @@ export async function startRepl(overrides: CliOverrides): Promise { ui, cwd, maxIterations: config.maxIterations, + engine, }); // Tracks the transcript length at the last reflection, so the auto-trigger on @@ -141,17 +193,40 @@ export async function startRepl(overrides: CliOverrides): Promise { } }; + const routeNote = localFirst + ? pc.dim(` → escalates to ${config.escalateTo!.provider}:${config.escalateTo!.model}`) + : ''; const priceTag = modelInfo ? 
` · $${modelInfo.inputPricePerMTok}/$${modelInfo.outputPricePerMTok} per 1M in/out` : ''; console.log( pc.bold('tiny-code') + - pc.dim(` · ${provider.name}:${provider.model}${priceTag} · ${cwd}`), + pc.dim(` · ${provider.name}:${provider.model}${priceTag} · ${cwd}`) + + routeNote, ); + + // Compute-cost advisory for local models: does this machine have the RAM? + if (provider.name === 'ollama') { + const check = checkLocalModel(provider.model); + const ramLine = `~${check.needGb}GB needed · ${check.freeGb}GB free / ${check.totalGb}GB total`; + if (check.warn) { + console.log( + pc.yellow(`⚠ ${provider.model} may exceed available memory (${ramLine}). Expect slow or failed runs.`), + ); + } else { + console.log(pc.dim(`Local model: ${ramLine}. No API cost.`)); + } + if (check.toolCallRisk) { + console.log( + pc.yellow('⚠ Small models (≤3B) often tool-call unreliably; prefer gemma3:4b+ or qwen2.5-coder:7b for agentic work.'), + ); + } + } + if (projectContext.trim().length > 0) { console.log(pc.dim('Loaded project context.')); } - console.log(pc.dim('Type a request, /help for commands, /exit to quit.')); + console.log(pc.dim('Type a request, /help for commands, /costs for usage, /exit to quit.')); const handle = async (line: string): Promise => { const input = line.trim(); @@ -169,6 +244,11 @@ export async function startRepl(overrides: CliOverrides): Promise { ask(); return; } + if (input === '/costs') { + printCosts(ui, config); + ask(); + return; + } if (input === '/clear') { agent.clearHistory(); console.log(pc.dim('Conversation history cleared.')); diff --git a/src/system/resources.ts b/src/system/resources.ts new file mode 100644 index 0000000..f9522da --- /dev/null +++ b/src/system/resources.ts @@ -0,0 +1,87 @@ +import { totalmem, freemem } from 'node:os'; + +/** + * Approximate memory needed to run common local models at ~Q4 quantization, + * in GB (weights + a modest KV cache / runtime overhead). 
These are guidelines + * for the startup advisory, not exact figures — long contexts need more. + */ +export const MODEL_RAM_GB: Record<string, number> = { + 'gemma3:1b': 1, + 'gemma3:4b': 3, + 'gemma3:12b': 7, + 'gemma3:27b': 16, + 'qwen2.5-coder:1.5b': 2, + 'qwen2.5-coder:7b': 5, + 'qwen2.5-coder:14b': 9, + 'qwen2.5-coder:32b': 18, + 'llama3.2:3b': 3, + 'llama3.1:8b': 6, +}; + +const GB = 1024 ** 3; + +/** Parse a parameter count (in billions) out of a model tag like "gemma3:12b". */ +export function parseParamsB(model: string): number | undefined { + const match = model.match(/(\d+(?:\.\d+)?)\s*b\b/i); + return match ? Number(match[1]) : undefined; +} + +/** Estimate RAM (GB) for a model: explicit table first, else a size-based guess. */ +export function estimateModelRamGb(model: string): number { + const known = MODEL_RAM_GB[model.toLowerCase()]; + if (known !== undefined) return known; + const params = parseParamsB(model); + // ~0.6 GB per billion params at Q4, plus ~1.5 GB runtime/KV-cache overhead. + return params !== undefined ? Math.round(params * 0.6 + 1.5) : 4; +} + +/** + * Fraction of total RAM a model may need before we warn. Leaves headroom for the + * OS and other apps; a model that wants nearly all of physical RAM will thrash. + */ +const CAPACITY_HEADROOM = 0.8; + +export interface LocalModelCheck { + needGb: number; + totalGb: number; + freeGb: number; + /** True when the model likely won't fit in this machine's RAM (capacity-based). */ + warn: boolean; + /** + * Soft hint: the model exceeds *currently free* memory. On Linux `free` is + * misleadingly low (most RAM is reclaimable cache), so this is advisory only — + * never the basis for the hard {@link warn}. + */ + freeTight: boolean; + /** True for small models (≤3B) that tool-call unreliably. */ + toolCallRisk: boolean; +} + +/** + * Compare a local model's memory footprint against the host's RAM.
The hard + * warning is capacity-based (`totalmem`), since that is what actually determines + * feasibility — Linux reports little "free" memory because it caches aggressively, + * so a free-memory test would spuriously warn on machines that run the model fine. + * `mem` defaults to the live host readings but can be injected for testing. + */ +export function checkLocalModel( + model: string, + mem: { total: number; free: number } = { total: totalmem(), free: freemem() }, +): LocalModelCheck { + const needGb = estimateModelRamGb(model); + const totalGb = mem.total / GB; + const freeGb = mem.free / GB; + const params = parseParamsB(model); + return { + needGb, + totalGb: round1(totalGb), + freeGb: round1(freeGb), + warn: needGb > totalGb * CAPACITY_HEADROOM, + freeTight: needGb > freeGb, + toolCallRisk: params !== undefined && params <= 3, + }; +} + +function round1(n: number): number { + return Math.round(n * 10) / 10; +} diff --git a/src/tools/escalate.ts b/src/tools/escalate.ts new file mode 100644 index 0000000..2214364 --- /dev/null +++ b/src/tools/escalate.ts @@ -0,0 +1,26 @@ +import { z } from 'zod'; +import { defineTool } from './types.js'; + +/** + * A signal tool, not a worker. When local-first routing is active and the + * current (cheap/local) model decides a task needs deeper reasoning — a large + * or multi-file refactor, tricky debugging, or it is simply stuck — it calls + * this tool. The agent loop watches for it and swaps in the configured frontier + * model for the rest of the turn, with full conversation context preserved. + * The tool itself just acknowledges; the loop performs the handoff. + */ +export const escalateTool = defineTool({ + name: 'escalate', + description: + 'Hand off the current task to a more capable model when it needs deep reasoning, a large or multi-file change, tricky debugging, or you are stuck. Prefer escalating early over guessing. 
Provide a brief reason.', + mutating: false, + schema: z.object({ + reason: z.string().describe('A brief reason the task needs a more capable model.'), + }), + async execute(input) { + return { + output: `Escalation acknowledged (${input.reason}). A more capable model will continue this task.`, + summary: 'escalating', + }; + }, +}); diff --git a/src/ui/render.ts b/src/ui/render.ts index 0f613bf..ea87e44 100644 --- a/src/ui/render.ts +++ b/src/ui/render.ts @@ -1,8 +1,8 @@ import pc from 'picocolors'; import type { AgentUI } from '../agent/loop.js'; import type { ToolResult } from '../tools/types.js'; -import type { ModelInfo } from '../models/catalog.js'; -import { estimateCostUsd, formatUsd } from '../models/catalog.js'; +import type { Usage } from '../providers/types.js'; +import { getModelInfo, estimateCostUsd, formatUsd } from '../models/catalog.js'; function preview(name: string, input: unknown): string { const obj = (input ?? {}) as Record; @@ -12,21 +12,46 @@ function preview(name: string, input: unknown): string { return JSON.stringify(obj); } -function fmtN(n: number): string { - return n.toLocaleString('en-US'); -} - function truncate(s: string, n: number): string { const oneLine = s.replace(/\s*\n\s*/g, ' ').trim(); return oneLine.length > n ? `${oneLine.slice(0, n)}…` : oneLine; } -/** - * Minimal streaming UI: assistant text inline, compact colored tool summaries. - * Pass the active model's catalog info to also show a per-turn cost estimate. - */ -export function createTerminalUI(modelInfo?: ModelInfo): AgentUI { +/** Compact token count, e.g. 1234 -> "1.2k". */ +function fmtTokens(n: number): string { + return n >= 1000 ? `${(n / 1000).toFixed(1)}k` : String(n); +} + +/** Paid (non-local) providers, where missing pricing means "unknown" not "free". 
*/ +function isCloud(provider?: string): boolean { + return provider === 'anthropic' || provider === 'gemini'; +} + +export interface SessionTotals { + inputTokens: number; + outputTokens: number; + /** Accumulated USD across priced (cloud) turns. */ + cost: number; +} + +export interface TerminalUI extends AgentUI { + /** Cumulative token + cost totals for the session (used by /costs). */ + getTotals(): SessionTotals; +} + +export interface TerminalUIOptions { + /** Default model id, used to price usage when the loop doesn't supply one. */ + model?: string; + provider?: string; + /** Print the per-turn usage line. Default true; set false to stay silent. */ + showUsage?: boolean; +} + +/** Minimal streaming UI: assistant text inline, compact colored tool summaries. */ +export function createTerminalUI(opts: TerminalUIOptions = {}): TerminalUI { + const showUsage = opts.showUsage ?? true; let atLineStart = true; + const totals: SessionTotals = { inputTokens: 0, outputTokens: 0, cost: 0 }; const write = (s: string): void => { if (s.length === 0) return; @@ -55,11 +80,31 @@ export function createTerminalUI(modelInfo?: ModelInfo): AgentUI { ensureNewline(); write(pc.yellow(` ⊘ ${name} denied\n`)); }, - onUsage(usage) { - const cost = modelInfo ? ` ${formatUsd(estimateCostUsd(usage, modelInfo))}` : ''; - write( - pc.dim(` ↑ ${fmtN(usage.inputTokens)} ↓ ${fmtN(usage.outputTokens)} tokens${cost}\n`), - ); + onUsage(usage: Usage, model?: string, provider?: string) { + totals.inputTokens += usage.inputTokens; + totals.outputTokens += usage.outputTokens; + const info = getModelInfo(model ?? opts.model ?? ''); + const cost = info ? 
estimateCostUsd(usage, info) : null; + if (cost !== null) totals.cost += cost; + + if (!showUsage) return; + ensureNewline(); + const tokens = `${fmtTokens(usage.inputTokens)} in / ${fmtTokens(usage.outputTokens)} out`; + let money: string; + if (cost !== null) { + money = `${formatUsd(cost)} turn · ${formatUsd(totals.cost)} session`; + } else if (isCloud(provider ?? opts.provider)) { + // A paid cloud model we don't have pricing for — don't imply it was free. + money = 'cost unknown'; + } else { + money = 'local (no API cost)'; + } + write(pc.dim(`· ${tokens} · ${money}\n`)); + }, + onRoute(provider, model, reason, initial) { + ensureNewline(); + const verb = initial ? '▸ routed to' : '↑ escalated to'; + write(pc.yellow(`${verb} ${provider}:${model} (${reason})\n`)); }, onAssistantEnd() { ensureNewline(); @@ -68,5 +113,8 @@ export function createTerminalUI(modelInfo?: ModelInfo): AgentUI { ensureNewline(); write(pc.yellow('[Reached max iterations — stopping]\n')); }, + getTotals() { + return { ...totals }; + }, }; } diff --git a/tests/agent/decision/localFirst.test.ts b/tests/agent/decision/localFirst.test.ts new file mode 100644 index 0000000..056bb29 --- /dev/null +++ b/tests/agent/decision/localFirst.test.ts @@ -0,0 +1,150 @@ +import { describe, it, expect } from 'vitest'; +import { LocalFirstModelEngine } from '../../../src/agent/decision/localFirst.js'; +import type { TurnSignals } from '../../../src/agent/decision/types.js'; +import type { ModelProvider, ProviderEvent, SendRequest } from '../../../src/providers/types.js'; + +/** A provider stub with just the identity fields the engine reads. 
*/ +function fakeProvider(model: string, name: ModelProvider['name']): ModelProvider { + return { + name, + model, + // eslint-disable-next-line require-yield + async *send(_req: SendRequest): AsyncIterable<ProviderEvent> { + return; + }, + }; +} + +const local = fakeProvider('qwen2.5-coder:7b', 'ollama'); +const frontier = fakeProvider('claude-opus-4-8', 'anthropic'); + +function signals(overrides: Partial<TurnSignals> = {}): TurnSignals { + return { + escalateRequested: false, + consecutiveErrors: 0, + alreadyEscalated: false, + current: local, + iteration: 0, + ...overrides, + }; +} + +describe('LocalFirstModelEngine', () => { + const fits = { warn: false }; + + it('routes a heavy turn to the frontier up front', () => { + const engine = new LocalFirstModelEngine({ + primary: local, + escalation: frontier, + classify: () => 'heavy', + ramCheck: () => fits, + }); + expect(engine.selectInitial('refactor everything')).toEqual({ + provider: frontier, + reason: 'heavy task', + }); + }); + + it('keeps a light turn on the primary (no reason)', () => { + const engine = new LocalFirstModelEngine({ + primary: local, + escalation: frontier, + classify: () => 'light', + ramCheck: () => fits, + }); + expect(engine.selectInitial('list files')).toEqual({ provider: local }); + }); + + it('escalates a light turn when the local model exceeds RAM (compute awareness)', () => { + const engine = new LocalFirstModelEngine({ + primary: local, + escalation: frontier, + classify: () => 'light', + ramCheck: () => ({ warn: true }), + }); + expect(engine.selectInitial('list files')).toEqual({ + provider: frontier, + reason: 'compute: local model exceeds RAM', + }); + }); + + it('does not apply the RAM check to a non-local primary', () => { + const cloudPrimary = fakeProvider('gemini-2.5-flash', 'gemini'); + let called = false; + const engine = new LocalFirstModelEngine({ + primary: cloudPrimary, + escalation: frontier, + classify: () => 'light', + ramCheck: () => { + called = true; + return { warn: true }; + 
}, + }); +
expect(engine.selectInitial('list files')).toEqual({ provider: cloudPrimary }); + expect(called).toBe(false); + }); + + describe('considerEscalation', () => { + const engine = new LocalFirstModelEngine({ + primary: local, + escalation: frontier, + ramCheck: () => fits, + }); + + it('escalates when the model requests it', () => { + expect(engine.considerEscalation(signals({ escalateRequested: true }))).toEqual({ + provider: frontier, + reason: 'requested by model', + }); + }); + + it('escalates once consecutive errors hit the threshold', () => { + expect(engine.considerEscalation(signals({ consecutiveErrors: 3 }))).toEqual({ + provider: frontier, + reason: 'stuck — repeated tool errors', + }); + }); + + it('stays put below the threshold and without a request', () => { + expect(engine.considerEscalation(signals({ consecutiveErrors: 2 }))).toBeUndefined(); + }); + + it('never re-routes once already escalated', () => { + expect( + engine.considerEscalation( + signals({ alreadyEscalated: true, escalateRequested: true, consecutiveErrors: 9 }), + ), + ).toBeUndefined(); + }); + + it('respects a custom stuck threshold', () => { + const eager = new LocalFirstModelEngine({ + primary: local, + escalation: frontier, + stuckThreshold: 1, + ramCheck: () => fits, + }); + expect(eager.considerEscalation(signals({ consecutiveErrors: 1 }))?.reason).toBe( + 'stuck — repeated tool errors', + ); + }); + }); + + describe('costNote (cost awareness)', () => { + it('reports added API cost when escalating from a free local model', () => { + const engine = new LocalFirstModelEngine({ primary: local, escalation: frontier }); + expect(engine.costNote()).toContain('adds API cost'); + }); + + it('reports a cost multiplier between two priced cloud models', () => { + const cheap = fakeProvider('gemini-2.5-flash', 'gemini'); + const engine = new LocalFirstModelEngine({ primary: cheap, escalation: frontier }); + expect(engine.costNote()).toMatch(/escalation costs ~\d+(\.\d+)?× the primary per 
token/); + }); + + it('reports no API cost when escalating to a local model', () => { + const engine = new LocalFirstModelEngine({ primary: frontier, escalation: local }); + expect(engine.costNote()).toContain('no API cost'); + }); + }); +}); diff --git a/tests/agent/loop.test.ts b/tests/agent/loop.test.ts index 026ca4c..cec1e20 100644 --- a/tests/agent/loop.test.ts +++ b/tests/agent/loop.test.ts @@ -2,8 +2,10 @@ import { describe, it, expect } from 'vitest'; import { z } from 'zod'; import { AgentLoop } from '../../src/agent/loop.js'; import type { AgentUI } from '../../src/agent/loop.js'; +import { LocalFirstModelEngine } from '../../src/agent/decision/index.js'; import { createRegistry } from '../../src/tools/registry.js'; import { defineTool } from '../../src/tools/types.js'; +import { escalateTool } from '../../src/tools/escalate.js'; import { PermissionGate } from '../../src/permissions/gate.js'; import type { PermissionChoice } from '../../src/permissions/gate.js'; import type { ModelProvider, ProviderEvent, SendRequest } from '../../src/providers/types.js'; @@ -15,11 +17,18 @@ const DONE: ProviderEvent = { }; class ScriptedProvider implements ModelProvider { - readonly name = 'anthropic' as const; - readonly model = 'fake'; + readonly name: 'anthropic' | 'gemini' | 'ollama'; + readonly model: string; readonly sent: SendRequest[] = []; - constructor(private readonly turns: ProviderEvent[][]) {} + constructor( + private readonly turns: ProviderEvent[][], + model = 'fake', + name: 'anthropic' | 'gemini' | 'ollama' = 'anthropic', + ) { + this.model = model; + this.name = name; + } async *send(req: SendRequest): AsyncIterable<ProviderEvent> { this.sent.push(req); @@ -36,6 +45,7 @@ function recordingUI(): { ui: AgentUI; events: string[] } { onToolResult: (n, r) => events.push(`result:${n}:${r.output}:${r.isError ??
false}`), onToolDenied: (n) => events.push(`denied:${n}`), onUsage: () => events.push('usage'), + onRoute: (p, m, r, initial) => events.push(`route:${p}:${m}:${r}:${initial ? 'initial' : 'escalated'}`), onAssistantEnd: () => events.push('assistantEnd'), onMaxIterations: () => events.push('maxIter'), }; @@ -165,6 +175,152 @@ describe('AgentLoop', () => { } }); + it('routes a heavy turn to the escalation provider up front', async () => { + const local = new ScriptedProvider([[{ type: 'text', delta: 'local' }, DONE]], 'local'); + const frontier = new ScriptedProvider( + [[{ type: 'text', delta: 'frontier' }, DONE]], + 'big', + 'anthropic', + ); + const { ui, events } = recordingUI(); + const loop = new AgentLoop({ + provider: local, + registry, + gate: gateWith('yes'), + ui, + system: 'sys', + cwd: process.cwd(), + engine: new LocalFirstModelEngine({ primary: local, escalation: frontier, classify: () => 'heavy' }), + }); + await loop.run('refactor everything'); + + expect(frontier.sent).toHaveLength(1); + expect(local.sent).toHaveLength(0); + expect(events).toContain('route:anthropic:big:heavy task:initial'); + }); + + it('keeps a light turn on the local provider', async () => { + const local = new ScriptedProvider([[{ type: 'text', delta: 'local' }, DONE]], 'local'); + const frontier = new ScriptedProvider([[DONE]], 'big'); + const { ui, events } = recordingUI(); + const loop = new AgentLoop({ + provider: local, + registry, + gate: gateWith('yes'), + ui, + system: 'sys', + cwd: process.cwd(), + engine: new LocalFirstModelEngine({ primary: local, escalation: frontier, classify: () => 'light' }), + }); + await loop.run('list files'); + + expect(local.sent).toHaveLength(1); + expect(frontier.sent).toHaveLength(0); + expect(events).not.toContain('route:anthropic:big:heavy task'); + }); + + it('escalates mid-turn when the local model calls the escalate tool', async () => { + const escalateRegistry = createRegistry([echoTool, escalateTool]); + const local = new 
ScriptedProvider( + [[{ type: 'tool_call', id: 'e1', name: 'escalate', input: { reason: 'too hard' } }, DONE]], + 'local', + ); + const frontier = new ScriptedProvider( + [[{ type: 'text', delta: 'handled' }, DONE]], + 'big', + 'anthropic', + ); + const { ui, events } = recordingUI(); + const loop = new AgentLoop({ + provider: local, + registry: escalateRegistry, + gate: gateWith('yes'), + ui, + system: 'sys', + cwd: process.cwd(), + engine: new LocalFirstModelEngine({ primary: local, escalation: frontier, classify: () => 'light' }), + }); + await loop.run('start small then get stuck'); + + // First send on local, second (post-escalation) on frontier. + expect(local.sent).toHaveLength(1); + expect(frontier.sent).toHaveLength(1); + expect(events).toContain('route:anthropic:big:requested by model:escalated'); + expect(events).toContain('text:handled'); + }); + + it('auto-escalates after the engine sees repeated tool errors (stuck)', async () => { + // Three iterations of a failing (unknown) tool trip the default stuck threshold. 
+ const failing: ProviderEvent[][] = []; + for (let i = 0; i < 3; i += 1) { + failing.push([{ type: 'tool_call', id: `g${i}`, name: 'ghost', input: {} }, DONE]); + } + const local = new ScriptedProvider(failing, 'local'); + const frontier = new ScriptedProvider([[{ type: 'text', delta: 'rescued' }, DONE]], 'big', 'anthropic'); + const { ui, events } = recordingUI(); + const loop = new AgentLoop({ + provider: local, + registry, + gate: gateWith('yes'), + ui, + system: 'sys', + cwd: process.cwd(), + engine: new LocalFirstModelEngine({ primary: local, escalation: frontier, classify: () => 'light' }), + }); + await loop.run('keep failing'); + + expect(local.sent).toHaveLength(3); + expect(frontier.sent).toHaveLength(1); + expect(events).toContain('route:anthropic:big:stuck — repeated tool errors:escalated'); + expect(events).toContain('text:rescued'); + }); + + it('stays on the frontier model for follow-up turns once escalated', async () => { + const escalateRegistry = createRegistry([echoTool, escalateTool]); + const local = new ScriptedProvider( + [[{ type: 'tool_call', id: 'e1', name: 'escalate', input: { reason: 'too hard' } }, DONE]], + 'local', + ); + const frontier = new ScriptedProvider( + [ + [{ type: 'text', delta: 'handled' }, DONE], // finishes the escalated turn + [{ type: 'text', delta: 'follow-up' }, DONE], // next turn should land here too + ], + 'big', + 'anthropic', + ); + const { ui } = recordingUI(); + const loop = new AgentLoop({ + provider: local, + registry: escalateRegistry, + gate: gateWith('yes'), + ui, + system: 'sys', + cwd: process.cwd(), + engine: new LocalFirstModelEngine({ primary: local, escalation: frontier, classify: () => 'light' }), + }); + + await loop.run('start small then get stuck'); + await loop.run('a routine follow-up'); + + // The follow-up turn never touched the local provider. 
+ expect(local.sent).toHaveLength(1); + expect(frontier.sent).toHaveLength(2); + + // clearHistory resets stickiness: the next light turn goes back to local. + loop.clearHistory(); + await loop.run('another light request'); + expect(local.sent).toHaveLength(2); + }); + + it('behaves as a single provider when no escalation is configured', async () => { + const provider = new ScriptedProvider([[{ type: 'text', delta: 'hi' }, DONE]]); + const { ui, events } = recordingUI(); + await makeLoop(provider, ui, gateWith('yes')).run('refactor the whole codebase'); + expect(provider.sent).toHaveLength(1); + expect(events).not.toContain('route:anthropic:big:heavy task'); + }); + it('accumulates token usage across a single turn', async () => { const provider = new ScriptedProvider([ [ diff --git a/tests/agent/router.test.ts b/tests/agent/router.test.ts new file mode 100644 index 0000000..dbd6e77 --- /dev/null +++ b/tests/agent/router.test.ts @@ -0,0 +1,33 @@ +import { describe, it, expect } from 'vitest'; +import { classifyTurn } from '../../src/agent/router.js'; + +describe('classifyTurn', () => { + it('treats simple lookups and small edits as light', () => { + expect(classifyTurn('list the files in src')).toBe('light'); + expect(classifyTurn('what does this function return?')).toBe('light'); + expect(classifyTurn('rename foo to bar in utils.ts')).toBe('light'); + }); + + it('flags strong reasoning-heavy keywords as heavy', () => { + expect(classifyTurn('refactor the provider layer')).toBe('heavy'); + expect(classifyTurn('migrate the build to esbuild')).toBe('heavy'); + expect(classifyTurn('design a caching architecture')).toBe('heavy'); + expect(classifyTurn('find the root cause of the hang')).toBe('heavy'); + }); + + it('keeps routine uses of ambiguous verbs light', () => { + expect(classifyTurn('implement a getter for name')).toBe('light'); + expect(classifyTurn('debug this typo')).toBe('light'); + expect(classifyTurn('optimize the inner loop')).toBe('light'); + }); + + 
it('escalates ambiguous verbs only when paired with a scope cue', () => { + expect(classifyTurn('implement the auth system from scratch')).toBe('heavy'); + expect(classifyTurn('optimize rendering across the whole pipeline')).toBe('heavy'); + }); + + it('flags multi-file and long requests as heavy', () => { + expect(classifyTurn('update a.ts, b.ts, and c.ts to match')).toBe('heavy'); + expect(classifyTurn('x'.repeat(700))).toBe('heavy'); + }); +}); diff --git a/tests/config/load.test.ts b/tests/config/load.test.ts index 3d4ee5e..2da1595 100644 --- a/tests/config/load.test.ts +++ b/tests/config/load.test.ts @@ -12,6 +12,7 @@ const ENV_KEYS = [ 'TINY_CODE_PRIORITY', 'TINY_CODE_MAX_TOKENS', 'TINY_CODE_EFFORT', + 'TINY_CODE_OLLAMA_URL', 'TINY_CODE_IMPROVE', 'HOME', ]; @@ -146,4 +147,38 @@ describe('loadConfig', () => { const cfg = loadConfig({}, cwd); expect(cfg.model).toBe('from-env'); }); + + it('supports the ollama provider with its default model and base URL', () => { + const cfg = loadConfig({ provider: 'ollama' }, cwd); + expect(cfg.provider).toBe('ollama'); + expect(cfg.model).toBe('qwen2.5-coder:7b'); + expect(cfg.ollamaBaseUrl).toBe('http://localhost:11434/v1'); + }); + + it('honors TINY_CODE_OLLAMA_URL over the default', () => { + process.env.TINY_CODE_OLLAMA_URL = 'http://gpu-box:11434/v1'; + const cfg = loadConfig({ provider: 'ollama' }, cwd); + expect(cfg.ollamaBaseUrl).toBe('http://gpu-box:11434/v1'); + }); + + it('defaults routing to local-first when an escalateTo target is configured', async () => { + await writeFile( + join(cwd, 'tiny-code.config.json'), + JSON.stringify({ + provider: 'ollama', + model: 'gemma3:12b', + escalateTo: { provider: 'anthropic', model: 'claude-opus-4-8' }, + }), + ); + const cfg = loadConfig({}, cwd); + expect(cfg.routing).toBe('local-first'); + expect(cfg.escalateTo).toEqual({ provider: 'anthropic', model: 'claude-opus-4-8' }); + }); + + it('defaults routing to off with no escalateTo target', () => { + 
process.env.ANTHROPIC_API_KEY = 'sk-test'; + const cfg = loadConfig({}, cwd); + expect(cfg.routing).toBe('off'); + expect(cfg.escalateTo).toBeUndefined(); + }); }); diff --git a/tests/providers/ollamaSend.test.ts b/tests/providers/ollamaSend.test.ts new file mode 100644 index 0000000..1fda23d --- /dev/null +++ b/tests/providers/ollamaSend.test.ts @@ -0,0 +1,155 @@ +import { describe, it, expect, vi, afterEach } from 'vitest'; +import { OllamaProvider } from '../../src/providers/ollama.js'; +import type { ProviderEvent } from '../../src/providers/types.js'; + +/** Build a fake SSE Response body from a list of OpenAI-style chunks. */ +function sseResponse(chunks: unknown[]): Response { + const lines = chunks.map((c) => `data: ${JSON.stringify(c)}\n\n`).concat('data: [DONE]\n\n'); + const stream = new ReadableStream({ + start(controller) { + const enc = new TextEncoder(); + for (const line of lines) controller.enqueue(enc.encode(line)); + controller.close(); + }, + }); + return new Response(stream, { status: 200, headers: { 'Content-Type': 'text/event-stream' } }); +} + +afterEach(() => vi.restoreAllMocks()); + +async function collect(provider: OllamaProvider): Promise<ProviderEvent[]> { + const events: ProviderEvent[] = []; + for await (const e of provider.send({ + system: 's', + messages: [{ role: 'user', content: [{ type: 'text', text: 'go' }] }], + tools: [{ name: 'ls', description: 'list', jsonSchema: { type: 'object' } }], + })) { + events.push(e); + } + return events; +} + +describe('OllamaProvider.send', () => { + it('maps streamed deltas into text, tool_call, and done events', async () => { + vi.spyOn(globalThis, 'fetch').mockResolvedValue( + sseResponse([ + { choices: [{ delta: { content: 'Hel' } }] }, + { choices: [{ delta: { content: 'lo' } }] }, + { + choices: [ + { + delta: { tool_calls: [{ index: 0, id: 'c1', function: { name: 'ls', arguments: '{"path":' } }] }, + }, + ], + }, + { + choices: [{ delta: { tool_calls: [{ index: 0, function: { arguments: '"."}' } }] },
finish_reason: 'tool_calls' }], + }, + { choices: [], usage: { prompt_tokens: 11, completion_tokens: 7 } }, + ]), + ); + + const provider = new OllamaProvider({ baseUrl: 'http://localhost:11434/v1', model: 'qwen2.5-coder:7b' }); + const events = await collect(provider); + + const text = events.filter((e) => e.type === 'text').map((e) => (e as { delta: string }).delta); + expect(text.join('')).toBe('Hello'); + + const call = events.find((e) => e.type === 'tool_call'); + expect(call).toMatchObject({ type: 'tool_call', id: 'c1', name: 'ls', input: { path: '.' } }); + + const done = events.find((e) => e.type === 'done'); + expect(done).toMatchObject({ + type: 'done', + stopReason: 'tool_use', + usage: { inputTokens: 11, outputTokens: 7 }, + }); + }); + + it('degrades to empty input on malformed tool-call JSON', async () => { + vi.spyOn(globalThis, 'fetch').mockResolvedValue( + sseResponse([ + { + choices: [ + { delta: { tool_calls: [{ index: 0, id: 'c1', function: { name: 'ls', arguments: '{bad' } }] } }, + ], + }, + ]), + ); + const provider = new OllamaProvider({ baseUrl: 'http://localhost:11434/v1', model: 'm' }); + const events = await collect(provider); + const call = events.find((e) => e.type === 'tool_call'); + expect(call).toMatchObject({ name: 'ls', input: {} }); + }); + + it('retries without stream_options when the server rejects it with a 400', async () => { + const fetchMock = vi + .spyOn(globalThis, 'fetch') + .mockResolvedValueOnce(new Response('unknown field "stream_options"', { status: 400 })) + .mockResolvedValueOnce(sseResponse([{ choices: [{ delta: { content: 'ok' } }] }])); + + const provider = new OllamaProvider({ baseUrl: 'http://localhost:11434/v1', model: 'm' }); + const events = await collect(provider); + + expect(fetchMock).toHaveBeenCalledTimes(2); + const firstBody = JSON.parse((fetchMock.mock.calls[0]![1] as RequestInit).body as string); + const retryBody = JSON.parse((fetchMock.mock.calls[1]![1] as RequestInit).body as string); + 
expect(firstBody.stream_options).toEqual({ include_usage: true }); + expect(retryBody.stream_options).toBeUndefined(); + expect(events.filter((e) => e.type === 'text').map((e) => (e as { delta: string }).delta).join('')).toBe('ok'); + }); + + it('forwards maxTokens as max_tokens, and omits it when unset', async () => { + const fetchMock = vi + .spyOn(globalThis, 'fetch') + .mockResolvedValue(sseResponse([{ choices: [{ delta: { content: 'ok' } }] }])); + + await collect(new OllamaProvider({ baseUrl: 'http://localhost:11434/v1', model: 'm', maxTokens: 256 })); + const capped = JSON.parse((fetchMock.mock.calls[0]![1] as RequestInit).body as string); + expect(capped.max_tokens).toBe(256); + + fetchMock.mockClear(); + await collect(new OllamaProvider({ baseUrl: 'http://localhost:11434/v1', model: 'm' })); + const uncapped = JSON.parse((fetchMock.mock.calls[0]![1] as RequestInit).body as string); + expect(uncapped).not.toHaveProperty('max_tokens'); + }); + + it('still parses a final usage frame that lacks a trailing newline', async () => { + const raw = + 'data: {"choices":[{"delta":{"content":"hi"}}]}\n\n' + + 'data: {"choices":[],"usage":{"prompt_tokens":3,"completion_tokens":4}}'; // no trailing \n + const stream = new ReadableStream({ + start(controller) { + controller.enqueue(new TextEncoder().encode(raw)); + controller.close(); + }, + }); + vi.spyOn(globalThis, 'fetch').mockResolvedValue( + new Response(stream, { status: 200, headers: { 'Content-Type': 'text/event-stream' } }), + ); + + const provider = new OllamaProvider({ baseUrl: 'http://localhost:11434/v1', model: 'm' }); + const done = (await collect(provider)).find((e) => e.type === 'done'); + expect(done).toMatchObject({ usage: { inputTokens: 3, outputTokens: 4 } }); + }); + + it('throws a helpful error when Ollama is unreachable', async () => { + vi.spyOn(globalThis, 'fetch').mockRejectedValue(new Error('ECONNREFUSED')); + const provider = new OllamaProvider({ baseUrl: 'http://localhost:11434/v1', model: 
'm' }); + await expect(collect(provider)).rejects.toThrow(/Cannot reach Ollama/); + }); + + it('aborts and reports a timeout when the server goes silent', async () => { + // Never resolves on its own — only the idle-timeout abort can end it. + vi.spyOn(globalThis, 'fetch').mockImplementation( + (_url, init) => + new Promise((_resolve, reject) => { + (init as RequestInit).signal?.addEventListener('abort', () => + reject(new DOMException('aborted', 'AbortError')), + ); + }), + ); + const provider = new OllamaProvider({ baseUrl: 'http://localhost:11434/v1', model: 'm', timeoutMs: 20 }); + await expect(collect(provider)).rejects.toThrow(/went silent.*aborted/); + }); +}); diff --git a/tests/providers/translate.test.ts b/tests/providers/translate.test.ts index 58fd9fc..0a55560 100644 --- a/tests/providers/translate.test.ts +++ b/tests/providers/translate.test.ts @@ -1,6 +1,7 @@ import { describe, it, expect } from 'vitest'; import { toAnthropicMessages } from '../../src/providers/anthropic.js'; import { toGeminiContents } from '../../src/providers/gemini.js'; +import { toOpenAiMessages, toOpenAiTools } from '../../src/providers/ollama.js'; import type { Message } from '../../src/agent/types.js'; const conversation: Message[] = [ @@ -68,3 +69,25 @@ describe('toGeminiContents', () => { expect(out[0]!.parts).toHaveLength(0); }); }); + +describe('toOpenAiMessages', () => { + it('maps roles and correlates tool results by tool_call_id', () => { + const out = toOpenAiMessages(conversation); + + const assistant = out.find((m) => m.role === 'assistant')!; + expect(assistant.tool_calls?.[0]).toMatchObject({ + id: 'call-1', + type: 'function', + function: { name: 'ls' }, + }); + expect(assistant.tool_calls?.[0]!.function.arguments).toBe(JSON.stringify({ path: '.' 
})); + + const toolMsg = out.find((m) => m.role === 'tool')!; + expect(toolMsg).toMatchObject({ role: 'tool', tool_call_id: 'call-1', content: 'a.txt' }); + }); + + it('produces a function tool array', () => { + const tools = toOpenAiTools([{ name: 'ls', description: 'list', jsonSchema: { type: 'object' } }]); + expect(tools[0]).toMatchObject({ type: 'function', function: { name: 'ls', description: 'list' } }); + }); +}); diff --git a/tests/system/resources.test.ts b/tests/system/resources.test.ts new file mode 100644 index 0000000..e3b7169 --- /dev/null +++ b/tests/system/resources.test.ts @@ -0,0 +1,39 @@ +import { describe, it, expect } from 'vitest'; +import { checkLocalModel, estimateModelRamGb, parseParamsB } from '../../src/system/resources.js'; + +const GB = 1024 ** 3; + +describe('parseParamsB / estimateModelRamGb', () => { + it('extracts billions of params from a tag', () => { + expect(parseParamsB('gemma3:12b')).toBe(12); + expect(parseParamsB('qwen2.5-coder:1.5b')).toBe(1.5); + expect(parseParamsB('mystery-model')).toBeUndefined(); + }); + + it('uses the explicit table when available, else a size-based estimate', () => { + expect(estimateModelRamGb('gemma3:12b')).toBe(7); + // unknown 20b model -> 20*0.6 + 1.5 ~= 14 (rounded) + expect(estimateModelRamGb('something:20b')).toBe(Math.round(20 * 0.6 + 1.5)); + }); +}); + +describe('checkLocalModel', () => { + it('warns when the model needs more than the machine can hold', () => { + const check = checkLocalModel('gemma3:27b', { total: 8 * GB, free: 4 * GB }); // ~16GB + expect(check.warn).toBe(true); + expect(check.needGb).toBe(16); + }); + + it('does not warn when total capacity is ample and flags small-model tool risk', () => { + const check = checkLocalModel('gemma3:1b', { total: 64 * GB, free: 48 * GB }); + expect(check.warn).toBe(false); + expect(check.toolCallRisk).toBe(true); + }); + + it('does not warn on low free memory when total capacity is sufficient (Linux cache case)', () => { + // 32GB box 
with only 2GB nominally free — gemma3:4b (~3GB) fits in capacity. + const check = checkLocalModel('gemma3:4b', { total: 32 * GB, free: 2 * GB }); + expect(check.warn).toBe(false); + expect(check.freeTight).toBe(true); // soft hint still set + }); +}); diff --git a/tests/ui/render.test.ts b/tests/ui/render.test.ts index f443bb1..89acde2 100644 --- a/tests/ui/render.test.ts +++ b/tests/ui/render.test.ts @@ -44,8 +44,7 @@ describe('createTerminalUI', () => { const ui = createTerminalUI(); ui.onUsage({ inputTokens: 1234, outputTokens: 567 }); }); - expect(out).toContain('1,234'); - expect(out).toContain('567'); + expect(out).toContain('1.2k in / 567 out'); }); it('previews path- and pattern-based tools', () => { @@ -57,4 +56,62 @@ describe('createTerminalUI', () => { expect(out).toContain('src/x.ts'); expect(out).toContain('**/*.ts'); }); + + it('shows a cost line for cloud models and accumulates session totals', () => { + const out = capture(() => { + const ui = createTerminalUI({ model: 'claude-opus-4-8' }); + ui.onUsage({ inputTokens: 1000, outputTokens: 1000 }); + expect(ui.getTotals().inputTokens).toBe(1000); + expect(ui.getTotals().cost).toBeGreaterThan(0); + }); + expect(out).toContain('1.0k in / 1.0k out'); + expect(out).toContain('session'); + }); + + it('labels local models as having no API cost', () => { + const out = capture(() => { + const ui = createTerminalUI({ model: 'qwen2.5-coder:7b', provider: 'ollama' }); + ui.onUsage({ inputTokens: 500, outputTokens: 200 }); + expect(ui.getTotals().cost).toBe(0); + }); + expect(out).toContain('local (no API cost)'); + }); + + it('shows "cost unknown" for an unpriced cloud model rather than implying it is free', () => { + const out = capture(() => { + const ui = createTerminalUI({ provider: 'anthropic' }); + // A future/untracked cloud model id that the catalog has no pricing for. 
+ ui.onUsage({ inputTokens: 100, outputTokens: 100 }, 'claude-opus-5'); + expect(ui.getTotals().cost).toBe(0); + }); + expect(out).toContain('cost unknown'); + expect(out).not.toContain('no API cost'); + }); + + it('stays silent when showUsage is false but still tracks totals', () => { + const out = capture(() => { + const ui = createTerminalUI({ model: 'claude-opus-4-8', showUsage: false }); + ui.onUsage({ inputTokens: 100, outputTokens: 100 }); + expect(ui.getTotals().inputTokens).toBe(100); + }); + expect(out).toBe(''); + }); + + it('renders a mid-turn escalation route line', () => { + const out = capture(() => { + const ui = createTerminalUI(); + ui.onRoute('anthropic', 'claude-opus-4-8', 'requested by model'); + }); + expect(out).toContain('escalated to anthropic:claude-opus-4-8'); + expect(out).toContain('requested by model'); + }); + + it('renders up-front routing as "routed to", not "escalated"', () => { + const out = capture(() => { + const ui = createTerminalUI(); + ui.onRoute('anthropic', 'claude-opus-4-8', 'heavy task', true); + }); + expect(out).toContain('routed to anthropic:claude-opus-4-8'); + expect(out).not.toContain('escalated'); + }); });