Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/agent/decision/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
export type { ModelDecisionEngine, RouteDecision, TurnSignals } from './types.js';
export { LocalFirstModelEngine } from './localFirst.js';
export type { LocalFirstOptions } from './localFirst.js';
87 changes: 87 additions & 0 deletions src/agent/decision/localFirst.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import type { ModelProvider } from '../../providers/types.js';
import { classifyTurn, type TaskWeight } from '../router.js';
import { checkLocalModel } from '../../system/resources.js';
import { estimateCost } from '../../providers/pricing.js';

Check failure on line 4 in src/agent/decision/localFirst.ts

View workflow job for this annotation

GitHub Actions / build

Cannot find module '../../providers/pricing.js' or its corresponding type declarations.

Check failure on line 4 in src/agent/decision/localFirst.ts

View workflow job for this annotation

GitHub Actions / build

Cannot find module '../../providers/pricing.js' or its corresponding type declarations.
import type { ModelDecisionEngine, RouteDecision, TurnSignals } from './types.js';

export interface LocalFirstOptions {
/** The cheap/local model that handles turns by default. */
primary: ModelProvider;
/** The frontier model heavy/stuck turns escalate to. */
escalation: ModelProvider;
/** Task-weight classifier. Defaults to {@link classifyTurn}. */
classify?: (input: string) => TaskWeight;
/** Consecutive tool-error iterations before auto-escalating. Defaults to 3. */
stuckThreshold?: number;
/** RAM-fit check (injectable for tests). Defaults to {@link checkLocalModel}. */
ramCheck?: (model: string) => { warn: boolean };
}

/** Per-MTok probe used to compare relative model cost for the cost-aware note. */
const COST_PROBE = { inputTokens: 1_000_000, outputTokens: 1_000_000 };

/**
* The default local-first policy: handle each turn on the cheap/local model and
* escalate to the frontier model only when capability demands it — the task
* looks heavy up front, the model explicitly hands off via the `escalate` tool,
* it gets stuck on repeated tool errors, or (compute awareness) a local model
* won't fit in available RAM.
*
* Cost awareness is wired in but used defensively: the engine can *report* the
* cost implication of an escalation (see {@link costNote}) but never initiates
* a cost-driven route change — escalation is always a capability decision.
*/
export class LocalFirstModelEngine implements ModelDecisionEngine {
private readonly primary: ModelProvider;
private readonly escalation: ModelProvider;
private readonly classify: (input: string) => TaskWeight;
private readonly stuckThreshold: number;
private readonly ramCheck: (model: string) => { warn: boolean };

constructor(opts: LocalFirstOptions) {
this.primary = opts.primary;
this.escalation = opts.escalation;
this.classify = opts.classify ?? classifyTurn;
this.stuckThreshold = opts.stuckThreshold ?? 3;
this.ramCheck = opts.ramCheck ?? checkLocalModel;
}

selectInitial(userInput: string): RouteDecision {
// Compute awareness: a local primary that won't fit in RAM runs slowly or
// fails outright, so route it to the frontier up front — even for light work.
if (this.primary.name === 'ollama' && this.ramCheck(this.primary.model).warn) {
return { provider: this.escalation, reason: 'compute: local model exceeds RAM' };
}
if (this.classify(userInput) === 'heavy') {
return { provider: this.escalation, reason: 'heavy task' };
}
return { provider: this.primary };
}

considerEscalation(signals: TurnSignals): RouteDecision | undefined {
if (signals.alreadyEscalated) return undefined;
if (signals.escalateRequested) {
return { provider: this.escalation, reason: 'requested by model' };
}
if (signals.consecutiveErrors >= this.stuckThreshold) {
return { provider: this.escalation, reason: 'stuck — repeated tool errors' };
}
return undefined;
}

/**
* A short, human-readable summary of what escalating costs, relative to the
* primary. Pure reporting — does not influence routing. Local models report
* as having no API cost.
*/
costNote(): string {
const from = estimateCost(this.primary.model, COST_PROBE);
const to = estimateCost(this.escalation.model, COST_PROBE);
if (to === null) return 'escalates to a local model (no API cost)';
if (from === null || from === 0) {
return 'escalates from a free/local model to a paid model (adds API cost)';
}
const multiplier = to / from;
return `escalation costs ~${multiplier.toFixed(1)}× the primary per token`;
}
}
47 changes: 47 additions & 0 deletions src/agent/decision/types.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import type { ModelProvider } from '../../providers/types.js';

/**
* A model-selection decision: the provider to use and an optional human-readable
* reason. An empty/absent `reason` means "no route note" — the loop only fires
* {@link AgentUI.onRoute} when a decision both changes the active provider and
* carries a reason.
*/
export interface RouteDecision {
provider: ModelProvider;
reason?: string;
}

/**
* Runtime signals the loop feeds the engine once per iteration so it can decide
* whether to switch providers mid-turn. These are mechanical bookkeeping the
* loop already tracks; the engine owns the policy that interprets them.
*/
export interface TurnSignals {
/** The model called the `escalate` tool during this iteration. */
escalateRequested: boolean;
/** Consecutive iterations that ended with at least one tool error. */
consecutiveErrors: number;
/** Whether the turn has already been escalated (keeps the engine stateless). */
alreadyEscalated: boolean;
/** The provider currently handling the turn. */
current: ModelProvider;
/** Iteration index (0-based). */
iteration: number;
}

/**
* Owns all model-selection policy. {@link AgentLoop} depends on this interface
* instead of inlining classification + escalation rules, so the loop keeps only
* mechanism (send → stream → run tools → repeat) and the policy lives in one
* cohesive, testable place. Implementations may reason about task weight, cost,
* and compute (RAM) fit; the loop never sees that reasoning.
*/
export interface ModelDecisionEngine {
/** Pick the provider that starts a turn, given the user's input. */
selectInitial(userInput: string): RouteDecision;
/**
* Decide whether to switch providers mid-turn. Returns `undefined` to stay on
* the current provider. Called once per iteration after tool results.
*/
considerEscalation(signals: TurnSignals): RouteDecision | undefined;
}
101 changes: 54 additions & 47 deletions src/agent/loop.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import type { ToolRegistry } from '../tools/registry.js';
import type { PermissionGate } from '../permissions/gate.js';
import type { ToolResult } from '../tools/types.js';
import type { Message, ToolResultBlock, ToolUseBlock } from './types.js';
import type { ModelDecisionEngine, RouteDecision } from './decision/index.js';

export type { Usage };

Expand Down Expand Up @@ -37,15 +38,13 @@ export interface AgentLoopOptions {
ui: AgentUI;
cwd: string;
maxIterations?: number;
/** Frontier model to escalate heavy/stuck turns to (enables local-first routing). */
escalationProvider?: ModelProvider | undefined;
/** Classifies a turn up front so heavy tasks start on the frontier model. */
router?: ((input: string) => 'light' | 'heavy') | undefined;
/**
* Owns model-selection policy (which model starts a turn, when to escalate).
* When omitted, the loop runs `provider` as a single provider with no routing.
*/
engine?: ModelDecisionEngine | undefined;
}

/** Consecutive tool-error iterations before auto-escalating a stuck local model. */
const STUCK_THRESHOLD = 3;

/**
* The provider-agnostic, UI-agnostic agentic loop: send → stream → run tools →
* feed results back → repeat until the model stops requesting tools (or the
Expand All @@ -59,8 +58,7 @@ export class AgentLoop {
private readonly ui: AgentUI;
private readonly cwd: string;
private readonly maxIterations: number;
private readonly escalationProvider: ModelProvider | undefined;
private readonly router: ((input: string) => 'light' | 'heavy') | undefined;
private readonly engine: ModelDecisionEngine | undefined;
private readonly messages: Message[] = [];
private sessionUsage: Usage = { inputTokens: 0, outputTokens: 0 };
/**
Expand All @@ -69,6 +67,8 @@ export class AgentLoop {
* ping-pong back to the local model on each follow-up. Reset by clearHistory().
*/
private escalatedSession = false;
/** The provider escalated to; reused to start follow-up turns once sticky. */
private escalatedProvider: ModelProvider | undefined;

constructor(opts: AgentLoopOptions) {
this.provider = opts.provider;
Expand All @@ -78,8 +78,7 @@ export class AgentLoop {
this.ui = opts.ui;
this.cwd = opts.cwd;
this.maxIterations = opts.maxIterations ?? 50;
this.escalationProvider = opts.escalationProvider;
this.router = opts.router;
this.engine = opts.engine;
}

/** Conversation history (for inspection / persistence). */
Expand All @@ -93,6 +92,7 @@ export class AgentLoop {
clearHistory(): void {
this.messages.length = 0;
this.escalatedSession = false;
this.escalatedProvider = undefined;
}

/** Cumulative token usage across all turns in this session. */
Expand All @@ -105,18 +105,23 @@ export class AgentLoop {
this.messages.push({ role: 'user', content: [{ type: 'text', text: userInput }] });
const tools = this.registry.toSchemas();

let active: ModelProvider;
let escalated: boolean;
if (this.escalatedSession && this.escalationProvider) {
// A prior turn escalated; stay on the frontier model for follow-ups.
active = this.escalationProvider;
escalated = true;
} else {
active = this.selectInitialProvider(userInput);
escalated = active === this.escalationProvider;
}
let active = this.provider;
let escalated = false;
let consecutiveErrors = 0;

if (this.engine) {
if (this.escalatedSession && this.escalatedProvider) {
// A prior turn escalated; stay on the frontier model for follow-ups so a
// multi-turn hard task doesn't ping-pong back to the local model.
active = this.escalatedProvider;
escalated = true;
} else {
const initial = this.engine.selectInitial(userInput);
active = this.applyRoute(active, initial);
escalated = active !== this.provider;
}
}

for (let iteration = 0; iteration < this.maxIterations; iteration += 1) {
let text = '';
const toolCalls: ToolUseBlock[] = [];
Expand Down Expand Up @@ -147,11 +152,7 @@ export class AgentLoop {

if (toolCalls.length === 0) return;

// The local model can explicitly hand off via the `escalate` tool.
if (!escalated && toolCalls.some((c) => c.name === 'escalate')) {
active = this.escalate('requested by model');
escalated = true;
}
const escalateRequested = toolCalls.some((c) => c.name === 'escalate');

const results: ToolResultBlock[] = [];
let anyError = false;
Expand All @@ -162,35 +163,41 @@ export class AgentLoop {
}
this.messages.push({ role: 'user', content: results });

// Auto-escalate a local model that appears stuck (repeated tool errors).
if (!escalated) {
consecutiveErrors = anyError ? consecutiveErrors + 1 : 0;
if (consecutiveErrors >= STUCK_THRESHOLD) {
active = this.escalate('stuck — repeated tool errors');
escalated = true;
// Hand the turn's runtime signals to the engine, which owns the policy
// for whether to switch providers mid-turn (explicit escalate, stuck, …).
consecutiveErrors = anyError ? consecutiveErrors + 1 : 0;
if (this.engine) {
const next = this.engine.considerEscalation({
escalateRequested,
consecutiveErrors,
alreadyEscalated: escalated,
current: active,
iteration,
});
if (next) {
active = this.applyRoute(active, next);
escalated = active !== this.provider;
if (escalated) {
// Remember it so follow-up turns start here (sticky escalation).
this.escalatedSession = true;
this.escalatedProvider = active;
}
}
}
}

this.ui.onMaxIterations();
}

/** Pick the provider for a turn: heavy tasks start on the frontier model. */
private selectInitialProvider(input: string): ModelProvider {
if (this.escalationProvider && this.router && this.router(input) === 'heavy') {
this.ui.onRoute(this.escalationProvider.name, this.escalationProvider.model, 'heavy task', true);
return this.escalationProvider;
/**
* Apply a routing decision: switch to its provider and surface an `onRoute`
* event when it actually changes the active provider and carries a reason.
*/
private applyRoute(active: ModelProvider, decision: RouteDecision): ModelProvider {
if (decision.provider !== active && decision.reason) {
this.ui.onRoute(decision.provider.name, decision.provider.model, decision.reason);
}
return this.provider;
}

/** Switch to the frontier provider mid-turn. Falls back to the primary if unset.
* Marks the session as escalated so follow-up turns stay on the frontier model. */
private escalate(reason: string): ModelProvider {
if (!this.escalationProvider) return this.provider;
this.escalatedSession = true;
this.ui.onRoute(this.escalationProvider.name, this.escalationProvider.model, reason);
return this.escalationProvider;
return decision.provider;
}

private async executeToolCall(call: ToolUseBlock): Promise<ToolResultBlock> {
Expand Down
2 changes: 2 additions & 0 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ export { toOpenAiMessages, toOpenAiTools } from './providers/ollama.js';

export { classifyTurn } from './agent/router.js';
export type { TaskWeight } from './agent/router.js';
export { LocalFirstModelEngine } from './agent/decision/index.js';
export type { ModelDecisionEngine, RouteDecision, TurnSignals, LocalFirstOptions } from './agent/decision/index.js';
export { checkLocalModel, estimateModelRamGb, MODEL_RAM_GB } from './system/resources.js';
export type { LocalModelCheck } from './system/resources.js';

Expand Down
12 changes: 9 additions & 3 deletions src/repl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ import type { PermissionPrompt } from './permissions/gate.js';
import { ALL_TOOLS, createRegistry } from './tools/registry.js';
import { escalateTool } from './tools/escalate.js';
import { createProvider } from './providers/index.js';
import { classifyTurn } from './agent/router.js';
import { LocalFirstModelEngine } from './agent/decision/index.js';
import type { ModelDecisionEngine } from './agent/decision/index.js';
import { checkLocalModel } from './system/resources.js';
import { loadConfig } from './config/load.js';
import type { CliOverrides, ResolvedConfig } from './config/load.js';
Expand Down Expand Up @@ -111,6 +112,12 @@ export async function startRepl(overrides: CliOverrides): Promise<void> {
})
: undefined;

// The decision engine owns all model-selection policy; the loop only runs it.
const engine: ModelDecisionEngine | undefined =
localFirst && escalationProvider
? new LocalFirstModelEngine({ primary: provider, escalation: escalationProvider })
: undefined;

const registry = createRegistry(localFirst ? [...ALL_TOOLS, escalateTool] : undefined);
const projectContext = loadProjectContext(cwd);
const system = buildSystemPrompt({
Expand Down Expand Up @@ -145,8 +152,7 @@ export async function startRepl(overrides: CliOverrides): Promise<void> {
ui,
cwd,
maxIterations: config.maxIterations,
escalationProvider,
router: localFirst ? classifyTurn : undefined,
engine,
});

// Tracks the transcript length at the last reflection, so the auto-trigger on
Expand Down
Loading
Loading