From 79d2536a5bfc026e05ccc1554fdb2fee4fb6b40b Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 9 Jun 2026 14:05:58 +0000 Subject: [PATCH] Abstract model-selection policy into a ModelDecisionEngine Pull the scattered routing logic (classifyTurn, escalate-tool handoff, stuck-detection) out of AgentLoop into a cohesive, injectable ModelDecisionEngine interface. AgentLoop now keeps only mechanism and delegates all "which model" decisions to the engine. - Add src/agent/decision/{types,localFirst,index}.ts: the engine interface plus LocalFirstModelEngine, the default policy. It reproduces prior behavior and adds compute awareness (routes turns to the frontier when a local model won't fit in RAM) and cost awareness (costNote reporting). - AgentLoop takes an optional `engine` instead of `escalationProvider` + `router`; routing: 'off' injects no engine (single-provider path unchanged). - Wire the engine in repl.ts; export the new types from the public API. - Adapt loop tests and add unit tests for the engine. 
https://claude.ai/code/session_01T4UTQD35m11g4ChB8Cjd1w --- src/agent/decision/index.ts | 3 + src/agent/decision/localFirst.ts | 87 ++++++++++++++ src/agent/decision/types.ts | 47 ++++++++ src/agent/loop.ts | 76 ++++++------ src/index.ts | 2 + src/repl.ts | 12 +- tests/agent/decision/localFirst.test.ts | 150 ++++++++++++++++++++++++ tests/agent/loop.test.ts | 36 +++++- 8 files changed, 367 insertions(+), 46 deletions(-) create mode 100644 src/agent/decision/index.ts create mode 100644 src/agent/decision/localFirst.ts create mode 100644 src/agent/decision/types.ts create mode 100644 tests/agent/decision/localFirst.test.ts diff --git a/src/agent/decision/index.ts b/src/agent/decision/index.ts new file mode 100644 index 0000000..ea77556 --- /dev/null +++ b/src/agent/decision/index.ts @@ -0,0 +1,3 @@ +export type { ModelDecisionEngine, RouteDecision, TurnSignals } from './types.js'; +export { LocalFirstModelEngine } from './localFirst.js'; +export type { LocalFirstOptions } from './localFirst.js'; diff --git a/src/agent/decision/localFirst.ts b/src/agent/decision/localFirst.ts new file mode 100644 index 0000000..4d60c2b --- /dev/null +++ b/src/agent/decision/localFirst.ts @@ -0,0 +1,87 @@ +import type { ModelProvider } from '../../providers/types.js'; +import { classifyTurn, type TaskWeight } from '../router.js'; +import { checkLocalModel } from '../../system/resources.js'; +import { estimateCost } from '../../providers/pricing.js'; +import type { ModelDecisionEngine, RouteDecision, TurnSignals } from './types.js'; + +export interface LocalFirstOptions { + /** The cheap/local model that handles turns by default. */ + primary: ModelProvider; + /** The frontier model heavy/stuck turns escalate to. */ + escalation: ModelProvider; + /** Task-weight classifier. Defaults to {@link classifyTurn}. */ + classify?: (input: string) => TaskWeight; + /** Consecutive tool-error iterations before auto-escalating. Defaults to 3. 
*/ + stuckThreshold?: number; + /** RAM-fit check (injectable for tests). Defaults to {@link checkLocalModel}. */ + ramCheck?: (model: string) => { warn: boolean }; +} + +/** Per-MTok probe used to compare relative model cost for the cost-aware note. */ +const COST_PROBE = { inputTokens: 1_000_000, outputTokens: 1_000_000 }; + +/** + * The default local-first policy: handle each turn on the cheap/local model and + * escalate to the frontier model only when capability demands it — the task + * looks heavy up front, the model explicitly hands off via the `escalate` tool, + * it gets stuck on repeated tool errors, or (compute awareness) a local model + * won't fit in available RAM. + * + * Cost awareness is wired in but used defensively: the engine can *report* the + * cost implication of an escalation (see {@link costNote}) but never initiates + * a cost-driven route change — escalation is always a capability decision. + */ +export class LocalFirstModelEngine implements ModelDecisionEngine { + private readonly primary: ModelProvider; + private readonly escalation: ModelProvider; + private readonly classify: (input: string) => TaskWeight; + private readonly stuckThreshold: number; + private readonly ramCheck: (model: string) => { warn: boolean }; + + constructor(opts: LocalFirstOptions) { + this.primary = opts.primary; + this.escalation = opts.escalation; + this.classify = opts.classify ?? classifyTurn; + this.stuckThreshold = opts.stuckThreshold ?? 3; + this.ramCheck = opts.ramCheck ?? checkLocalModel; + } + + selectInitial(userInput: string): RouteDecision { + // Compute awareness: a local primary that won't fit in RAM runs slowly or + // fails outright, so route it to the frontier up front — even for light work. 
+ if (this.primary.name === 'ollama' && this.ramCheck(this.primary.model).warn) { + return { provider: this.escalation, reason: 'compute: local model exceeds RAM' }; + } + if (this.classify(userInput) === 'heavy') { + return { provider: this.escalation, reason: 'heavy task' }; + } + return { provider: this.primary }; + } + + considerEscalation(signals: TurnSignals): RouteDecision | undefined { + if (signals.alreadyEscalated) return undefined; + if (signals.escalateRequested) { + return { provider: this.escalation, reason: 'requested by model' }; + } + if (signals.consecutiveErrors >= this.stuckThreshold) { + return { provider: this.escalation, reason: 'stuck — repeated tool errors' }; + } + return undefined; + } + + /** + * A short, human-readable summary of what escalating costs, relative to the + * primary. Pure reporting — does not influence routing. Local models report + * as having no API cost. + */ + costNote(): string { + const from = estimateCost(this.primary.model, COST_PROBE); + const to = estimateCost(this.escalation.model, COST_PROBE); + if (to === null) return 'escalates to a local model (no API cost)'; + if (from === null || from === 0) { + return 'escalates from a free/local model to a paid model (adds API cost)'; + } + const multiplier = to / from; + return `escalation costs ~${multiplier.toFixed(1)}× the primary per token`; + } +} diff --git a/src/agent/decision/types.ts b/src/agent/decision/types.ts new file mode 100644 index 0000000..c0cd976 --- /dev/null +++ b/src/agent/decision/types.ts @@ -0,0 +1,47 @@ +import type { ModelProvider } from '../../providers/types.js'; + +/** + * A model-selection decision: the provider to use and an optional human-readable + * reason. An empty/absent `reason` means "no route note" — the loop only fires + * {@link AgentUI.onRoute} when a decision both changes the active provider and + * carries a reason. 
+ */ +export interface RouteDecision { + provider: ModelProvider; + reason?: string; +} + +/** + * Runtime signals the loop feeds the engine once per iteration so it can decide + * whether to switch providers mid-turn. These are mechanical bookkeeping the + * loop already tracks; the engine owns the policy that interprets them. + */ +export interface TurnSignals { + /** The model called the `escalate` tool during this iteration. */ + escalateRequested: boolean; + /** Consecutive iterations that ended with at least one tool error. */ + consecutiveErrors: number; + /** Whether the turn has already been escalated (keeps the engine stateless). */ + alreadyEscalated: boolean; + /** The provider currently handling the turn. */ + current: ModelProvider; + /** Iteration index (0-based). */ + iteration: number; +} + +/** + * Owns all model-selection policy. {@link AgentLoop} depends on this interface + * instead of inlining classification + escalation rules, so the loop keeps only + * mechanism (send → stream → run tools → repeat) and the policy lives in one + * cohesive, testable place. Implementations may reason about task weight, cost, + * and compute (RAM) fit; the loop never sees that reasoning. + */ +export interface ModelDecisionEngine { + /** Pick the provider that starts a turn, given the user's input. */ + selectInitial(userInput: string): RouteDecision; + /** + * Decide whether to switch providers mid-turn. Returns `undefined` to stay on + * the current provider. Called once per iteration after tool results. 
+ */ + considerEscalation(signals: TurnSignals): RouteDecision | undefined; +} diff --git a/src/agent/loop.ts b/src/agent/loop.ts index fe10bdd..d41bbd2 100644 --- a/src/agent/loop.ts +++ b/src/agent/loop.ts @@ -3,6 +3,7 @@ import type { ToolRegistry } from '../tools/registry.js'; import type { PermissionGate } from '../permissions/gate.js'; import type { ToolResult } from '../tools/types.js'; import type { Message, ToolResultBlock, ToolUseBlock } from './types.js'; +import type { ModelDecisionEngine, RouteDecision } from './decision/index.js'; /** Sink for everything the loop wants to surface. The REPL provides the real one. */ export interface AgentUI { @@ -26,15 +27,13 @@ export interface AgentLoopOptions { ui: AgentUI; cwd: string; maxIterations?: number; - /** Frontier model to escalate heavy/stuck turns to (enables local-first routing). */ - escalationProvider?: ModelProvider | undefined; - /** Classifies a turn up front so heavy tasks start on the frontier model. */ - router?: ((input: string) => 'light' | 'heavy') | undefined; + /** + * Owns model-selection policy (which model starts a turn, when to escalate). + * When omitted, the loop runs `provider` as a single provider with no routing. + */ + engine?: ModelDecisionEngine | undefined; } -/** Consecutive tool-error iterations before auto-escalating a stuck local model. 
*/ -const STUCK_THRESHOLD = 3; - /** * The provider-agnostic, UI-agnostic agentic loop: send → stream → run tools → * feed results back → repeat until the model stops requesting tools (or the @@ -48,8 +47,7 @@ export class AgentLoop { private readonly ui: AgentUI; private readonly cwd: string; private readonly maxIterations: number; - private readonly escalationProvider: ModelProvider | undefined; - private readonly router: ((input: string) => 'light' | 'heavy') | undefined; + private readonly engine: ModelDecisionEngine | undefined; private readonly messages: Message[] = []; constructor(opts: AgentLoopOptions) { @@ -60,8 +58,7 @@ export class AgentLoop { this.ui = opts.ui; this.cwd = opts.cwd; this.maxIterations = opts.maxIterations ?? 50; - this.escalationProvider = opts.escalationProvider; - this.router = opts.router; + this.engine = opts.engine; } /** Conversation history (for inspection / persistence). */ @@ -74,10 +71,16 @@ export class AgentLoop { this.messages.push({ role: 'user', content: [{ type: 'text', text: userInput }] }); const tools = this.registry.toSchemas(); - let active = this.selectInitialProvider(userInput); - let escalated = active === this.escalationProvider; + let active = this.provider; + let escalated = false; let consecutiveErrors = 0; + if (this.engine) { + const initial = this.engine.selectInitial(userInput); + active = this.applyRoute(active, initial); + escalated = active !== this.provider; + } + for (let iteration = 0; iteration < this.maxIterations; iteration += 1) { let text = ''; const toolCalls: ToolUseBlock[] = []; @@ -106,11 +109,7 @@ export class AgentLoop { if (toolCalls.length === 0) return; - // The local model can explicitly hand off via the `escalate` tool. 
- if (!escalated && toolCalls.some((c) => c.name === 'escalate')) { - active = this.escalate('requested by model'); - escalated = true; - } + const escalateRequested = toolCalls.some((c) => c.name === 'escalate'); const results: ToolResultBlock[] = []; let anyError = false; @@ -121,12 +120,20 @@ export class AgentLoop { } this.messages.push({ role: 'user', content: results }); - // Auto-escalate a local model that appears stuck (repeated tool errors). - if (!escalated) { - consecutiveErrors = anyError ? consecutiveErrors + 1 : 0; - if (consecutiveErrors >= STUCK_THRESHOLD) { - active = this.escalate('stuck — repeated tool errors'); - escalated = true; + // Hand the turn's runtime signals to the engine, which owns the policy + // for whether to switch providers mid-turn (explicit escalate, stuck, …). + consecutiveErrors = anyError ? consecutiveErrors + 1 : 0; + if (this.engine) { + const next = this.engine.considerEscalation({ + escalateRequested, + consecutiveErrors, + alreadyEscalated: escalated, + current: active, + iteration, + }); + if (next) { + active = this.applyRoute(active, next); + escalated = active !== this.provider; } } } @@ -134,20 +141,15 @@ export class AgentLoop { this.ui.onMaxIterations(); } - /** Pick the provider for a turn: heavy tasks start on the frontier model. */ - private selectInitialProvider(input: string): ModelProvider { - if (this.escalationProvider && this.router && this.router(input) === 'heavy') { - this.ui.onRoute(this.escalationProvider.name, this.escalationProvider.model, 'heavy task'); - return this.escalationProvider; + /** + * Apply a routing decision: switch to its provider and surface an `onRoute` + * event when it actually changes the active provider and carries a reason. 
+ */ + private applyRoute(active: ModelProvider, decision: RouteDecision): ModelProvider { + if (decision.provider !== active && decision.reason) { + this.ui.onRoute(decision.provider.name, decision.provider.model, decision.reason); } - return this.provider; - } - - /** Switch to the frontier provider mid-turn. Falls back to the primary if unset. */ - private escalate(reason: string): ModelProvider { - if (!this.escalationProvider) return this.provider; - this.ui.onRoute(this.escalationProvider.name, this.escalationProvider.model, reason); - return this.escalationProvider; + return decision.provider; } private async executeToolCall(call: ToolUseBlock): Promise { diff --git a/src/index.ts b/src/index.ts index 1a8ce7f..dd4558e 100644 --- a/src/index.ts +++ b/src/index.ts @@ -16,6 +16,8 @@ export type { ModelPricing } from './providers/pricing.js'; export { classifyTurn } from './agent/router.js'; export type { TaskWeight } from './agent/router.js'; +export { LocalFirstModelEngine } from './agent/decision/index.js'; +export type { ModelDecisionEngine, RouteDecision, TurnSignals, LocalFirstOptions } from './agent/decision/index.js'; export { checkLocalModel, estimateModelRamGb, MODEL_RAM_GB } from './system/resources.js'; export type { LocalModelCheck } from './system/resources.js'; diff --git a/src/repl.ts b/src/repl.ts index 1672fdc..f1eda27 100644 --- a/src/repl.ts +++ b/src/repl.ts @@ -8,7 +8,8 @@ import type { PermissionPrompt } from './permissions/gate.js'; import { ALL_TOOLS, createRegistry } from './tools/registry.js'; import { escalateTool } from './tools/escalate.js'; import { createProvider } from './providers/index.js'; -import { classifyTurn } from './agent/router.js'; +import { LocalFirstModelEngine } from './agent/decision/index.js'; +import type { ModelDecisionEngine } from './agent/decision/index.js'; import { formatUsd } from './providers/pricing.js'; import { checkLocalModel } from './system/resources.js'; import { loadConfig } from 
'./config/load.js'; @@ -70,6 +71,12 @@ export async function startRepl(overrides: CliOverrides): Promise { }) : undefined; + // The decision engine owns all model-selection policy; the loop only runs it. + const engine: ModelDecisionEngine | undefined = + localFirst && escalationProvider + ? new LocalFirstModelEngine({ primary: provider, escalation: escalationProvider }) + : undefined; + const registry = createRegistry(localFirst ? [...ALL_TOOLS, escalateTool] : undefined); const projectContext = loadProjectContext(cwd); const system = buildSystemPrompt({ @@ -103,8 +110,7 @@ export async function startRepl(overrides: CliOverrides): Promise { ui, cwd, maxIterations: config.maxIterations, - escalationProvider, - router: localFirst ? classifyTurn : undefined, + engine, }); const routeNote = localFirst diff --git a/tests/agent/decision/localFirst.test.ts b/tests/agent/decision/localFirst.test.ts new file mode 100644 index 0000000..056bb29 --- /dev/null +++ b/tests/agent/decision/localFirst.test.ts @@ -0,0 +1,150 @@ +import { describe, it, expect } from 'vitest'; +import { LocalFirstModelEngine } from '../../../src/agent/decision/localFirst.js'; +import type { TurnSignals } from '../../../src/agent/decision/types.js'; +import type { ModelProvider, ProviderEvent, SendRequest } from '../../../src/providers/types.js'; + +/** A provider stub with just the identity fields the engine reads. 
*/ +function fakeProvider(model: string, name: ModelProvider['name']): ModelProvider { + return { + name, + model, + // eslint-disable-next-line require-yield + async *send(_req: SendRequest): AsyncIterable { + return; + }, + }; +} + +const local = fakeProvider('qwen2.5-coder:7b', 'ollama'); +const frontier = fakeProvider('claude-opus-4-8', 'anthropic'); + +function signals(overrides: Partial = {}): TurnSignals { + return { + escalateRequested: false, + consecutiveErrors: 0, + alreadyEscalated: false, + current: local, + iteration: 0, + ...overrides, + }; +} + +describe('LocalFirstModelEngine', () => { + const fits = { warn: false }; + + it('routes a heavy turn to the frontier up front', () => { + const engine = new LocalFirstModelEngine({ + primary: local, + escalation: frontier, + classify: () => 'heavy', + ramCheck: () => fits, + }); + expect(engine.selectInitial('refactor everything')).toEqual({ + provider: frontier, + reason: 'heavy task', + }); + }); + + it('keeps a light turn on the primary (no reason)', () => { + const engine = new LocalFirstModelEngine({ + primary: local, + escalation: frontier, + classify: () => 'light', + ramCheck: () => fits, + }); + expect(engine.selectInitial('list files')).toEqual({ provider: local }); + }); + + it('escalates a light turn when the local model exceeds RAM (compute awareness)', () => { + const engine = new LocalFirstModelEngine({ + primary: local, + escalation: frontier, + classify: () => 'light', + ramCheck: () => ({ warn: true }), + }); + expect(engine.selectInitial('list files')).toEqual({ + provider: frontier, + reason: 'compute: local model exceeds RAM', + }); + }); + + it('does not apply the RAM check to a non-local primary', () => { + const cloudPrimary = fakeProvider('gemini-2.5-flash', 'gemini'); + let called = false; + const engine = new LocalFirstModelEngine({ + primary: cloudPrimary, + escalation: frontier, + classify: () => 'light', + ramCheck: () => { + called = true; + return { warn: true }; + }, + }); + 
expect(engine.selectInitial('list files')).toEqual({ provider: cloudPrimary }); + expect(called).toBe(false); + }); + + describe('considerEscalation', () => { + const engine = new LocalFirstModelEngine({ + primary: local, + escalation: frontier, + ramCheck: () => fits, + }); + + it('escalates when the model requests it', () => { + expect(engine.considerEscalation(signals({ escalateRequested: true }))).toEqual({ + provider: frontier, + reason: 'requested by model', + }); + }); + + it('escalates once consecutive errors hit the threshold', () => { + expect(engine.considerEscalation(signals({ consecutiveErrors: 3 }))).toEqual({ + provider: frontier, + reason: 'stuck — repeated tool errors', + }); + }); + + it('stays put below the threshold and without a request', () => { + expect(engine.considerEscalation(signals({ consecutiveErrors: 2 }))).toBeUndefined(); + }); + + it('never re-routes once already escalated', () => { + expect( + engine.considerEscalation( + signals({ alreadyEscalated: true, escalateRequested: true, consecutiveErrors: 9 }), + ), + ).toBeUndefined(); + }); + + it('respects a custom stuck threshold', () => { + const eager = new LocalFirstModelEngine({ + primary: local, + escalation: frontier, + stuckThreshold: 1, + ramCheck: () => fits, + }); + expect(eager.considerEscalation(signals({ consecutiveErrors: 1 }))?.reason).toBe( + 'stuck — repeated tool errors', + ); + }); + }); + + describe('costNote (cost awareness)', () => { + it('reports added API cost when escalating from a free local model', () => { + const engine = new LocalFirstModelEngine({ primary: local, escalation: frontier }); + expect(engine.costNote()).toContain('adds API cost'); + }); + + it('reports a cost multiplier between two priced cloud models', () => { + const cheap = fakeProvider('gemini-2.5-flash', 'gemini'); + const engine = new LocalFirstModelEngine({ primary: cheap, escalation: frontier }); + expect(engine.costNote()).toMatch(/escalation costs ~\d+(\.\d+)?× the primary per 
token/); + }); + + it('reports no API cost when escalating to a local model', () => { + const engine = new LocalFirstModelEngine({ primary: frontier, escalation: local }); + expect(engine.costNote()).toContain('no API cost'); + }); + }); +}); diff --git a/tests/agent/loop.test.ts b/tests/agent/loop.test.ts index 63aa615..304176f 100644 --- a/tests/agent/loop.test.ts +++ b/tests/agent/loop.test.ts @@ -2,6 +2,7 @@ import { describe, it, expect } from 'vitest'; import { z } from 'zod'; import { AgentLoop } from '../../src/agent/loop.js'; import type { AgentUI } from '../../src/agent/loop.js'; +import { LocalFirstModelEngine } from '../../src/agent/decision/index.js'; import { createRegistry } from '../../src/tools/registry.js'; import { defineTool } from '../../src/tools/types.js'; import { escalateTool } from '../../src/tools/escalate.js'; @@ -189,8 +190,7 @@ describe('AgentLoop', () => { ui, system: 'sys', cwd: process.cwd(), - escalationProvider: frontier, - router: () => 'heavy', + engine: new LocalFirstModelEngine({ primary: local, escalation: frontier, classify: () => 'heavy' }), }); await loop.run('refactor everything'); @@ -210,8 +210,7 @@ describe('AgentLoop', () => { ui, system: 'sys', cwd: process.cwd(), - escalationProvider: frontier, - router: () => 'light', + engine: new LocalFirstModelEngine({ primary: local, escalation: frontier, classify: () => 'light' }), }); await loop.run('list files'); @@ -239,8 +238,7 @@ describe('AgentLoop', () => { ui, system: 'sys', cwd: process.cwd(), - escalationProvider: frontier, - router: () => 'light', + engine: new LocalFirstModelEngine({ primary: local, escalation: frontier, classify: () => 'light' }), }); await loop.run('start small then get stuck'); @@ -251,6 +249,32 @@ describe('AgentLoop', () => { expect(events).toContain('text:handled'); }); + it('auto-escalates after the engine sees repeated tool errors (stuck)', async () => { + // Three iterations of a failing (unknown) tool trip the default stuck threshold. 
+ const failing: ProviderEvent[][] = []; + for (let i = 0; i < 3; i += 1) { + failing.push([{ type: 'tool_call', id: `g${i}`, name: 'ghost', input: {} }, DONE]); + } + const local = new ScriptedProvider(failing, 'local'); + const frontier = new ScriptedProvider([[{ type: 'text', delta: 'rescued' }, DONE]], 'big', 'anthropic'); + const { ui, events } = recordingUI(); + const loop = new AgentLoop({ + provider: local, + registry, + gate: gateWith('yes'), + ui, + system: 'sys', + cwd: process.cwd(), + engine: new LocalFirstModelEngine({ primary: local, escalation: frontier, classify: () => 'light' }), + }); + await loop.run('keep failing'); + + expect(local.sent).toHaveLength(3); + expect(frontier.sent).toHaveLength(1); + expect(events).toContain('route:anthropic:big:stuck — repeated tool errors'); + expect(events).toContain('text:rescued'); + }); + it('behaves as a single provider when no escalation is configured', async () => { const provider = new ScriptedProvider([[{ type: 'text', delta: 'hi' }, DONE]]); const { ui, events } = recordingUI();