Patch summary (text before the first "diff --git" header is commentary only).

- agent.ts: AGENT_TIMEOUT_MS is raised to 180_000 and armed once around the
  whole retry loop; when the controller's signal is aborted the error is
  rethrown instead of retried.
- streaming.ts: every provider in the fallback chain gets its own
  AbortController + timer (PER_PROVIDER_TIMEOUT_MS, overridable through
  perProviderTimeoutMs), combined with the caller's signal via
  AbortSignal.any; the chain throws an Error named 'AbortError' as soon as
  the caller's signal is aborted.
- agent.test.ts / streaming.test.ts: tests covering the behaviours above.

NOTE(review): line breaks, blank lines and leading indentation were
reconstructed from the hunk line counts and a 2-space layout; the type
arguments on Promise / Array / InstanceType were restored because the bare
names do not compile. Confirm against the working tree before applying.

diff --git a/src/services/ai/agent.test.ts b/src/services/ai/agent.test.ts
index 10ea317..09cccac 100644
--- a/src/services/ai/agent.test.ts
+++ b/src/services/ai/agent.test.ts
@@ -35,6 +35,12 @@ mock.module('./response-validator', () => ({
   validateResponse: mock(async () => ({ approved: true })),
 }));
 
+// Mock tool execution so multi-round tests don't hit the real DB.
+const mockExecuteTool = mock(async () => ({ success: true, output: 'tool ok' }));
+mock.module('./tool-executor', () => ({
+  executeTool: mockExecuteTool,
+}));
+
 import { ExpenseBotAgent } from './agent';
 import type { AgentContext } from './types';
 
@@ -63,6 +69,23 @@ function makeTextResult(text: string): StreamRoundResult {
   };
 }
 
+/** Build a StreamRoundResult that requests a single tool call (no text yet) */
+function makeToolCallResult(toolName: string): StreamRoundResult {
+  return {
+    text: '',
+    toolCalls: [{ id: 'call_1', name: toolName, arguments: '{}' }],
+    finishReason: 'tool_calls',
+    assistantMessage: {
+      role: 'assistant',
+      content: null,
+      tool_calls: [
+        { id: 'call_1', type: 'function', function: { name: toolName, arguments: '{}' } },
+      ],
+    },
+    providerUsed: 'mock',
+  };
+}
+
 /**
  * Mock aiStreamRound to return a text response, also calling onTextDelta for each chunk.
  * Uses mockImplementationOnce.
@@ -623,4 +646,62 @@ describe('ExpenseBotAgent', () => {
       ).rejects.toBeInstanceOf(TypeError);
     });
   });
+
+  // -- run() -- multi-round completion & overall timeout ---------------------
+
+  describe('run() -- multi-round completion', () => {
+    beforeEach(() => {
+      mockExecuteTool.mockClear();
+      spyOn(
+        agent as unknown as { sleep: (ms: number) => Promise<void> },
+        'sleep',
+      ).mockResolvedValue(undefined);
+    });
+
+    it('completes a 3-round query (tool, tool, final text) without an overall abort', async () => {
+      // Problem B regression: a legitimate multi-round query must finish. Each round
+      // is a separate aiStreamRound call — the first two request tool calls, the third
+      // returns the final answer. With mocks there is no real wall-clock delay, so this
+      // guards that the tool loop completes across rounds and the agent returns the text.
+      mockAiStreamRound
+        .mockImplementationOnce(async () => makeToolCallResult('get_expenses'))
+        .mockImplementationOnce(async () => makeToolCallResult('calculate'))
+        .mockImplementationOnce(async (_opts, callbacks) => {
+          const text = 'Final answer after 3 rounds';
+          callbacks.onTextDelta?.(text);
+          return makeTextResult(text);
+        });
+
+      const result = await agent.run(
+        'How much did I spend last month?',
+        [],
+        mockBot as unknown as import('gramio').Bot,
+      );
+
+      expect(result).toContain('Final answer after 3 rounds');
+      expect(mockAiStreamRound).toHaveBeenCalledTimes(3);
+      expect(mockExecuteTool).toHaveBeenCalledTimes(2);
+    });
+
+    it('overall-deadline abort yields the timeout message, not the generic AI error', async () => {
+      // Acceptance criterion 3: a truly-stuck run is bounded by the overall cap and
+      // surfaces "Время ожидания истекло", NOT "Ошибка AI". The streaming layer throws
+      // a plain Error with name='AbortError' when the overall signal fires.
+      const abortError = new Error('AI overall deadline exceeded');
+      abortError.name = 'AbortError';
+      mockIsRetryableError.mockReturnValue(true);
+      mockGetBackoffDelay.mockReturnValue(0);
+      mockAiStreamRound.mockRejectedValue(abortError);
+
+      const { AgentError } = await import('../../errors');
+      try {
+        await agent.run('question', [], mockBot as unknown as import('gramio').Bot);
+        expect.unreachable('should have thrown');
+      } catch (err) {
+        expect(err).toBeInstanceOf(AgentError);
+        expect((err as InstanceType<typeof AgentError>).userMessage).toContain('ожидания');
+        expect((err as InstanceType<typeof AgentError>).userMessage).not.toContain('Ошибка AI');
+      }
+    });
+  });
 });
diff --git a/src/services/ai/agent.ts b/src/services/ai/agent.ts
index 8dea2a6..7f6d36f 100644
--- a/src/services/ai/agent.ts
+++ b/src/services/ai/agent.ts
@@ -34,7 +34,13 @@ const logger = createLogger('agent');
 export const aiDebugLogger = new AiDebugLogger(env.AI_DEBUG_LOGS, 'logs');
 
 const MAX_TOOL_ROUNDS = 10;
-const AGENT_TIMEOUT_MS = 60_000;
+/**
+ * Overall wall-clock cap for one agent run (all rounds + all retries combined).
+ * A legitimate multi-round query with a slow model needs well over 60s, so this
+ * is generous; it only bounds a truly-stuck run. Per-provider timeouts (in
+ * streaming.ts) catch individual hung providers much sooner.
+ */
+const AGENT_TIMEOUT_MS = 180_000;
 const MAX_API_RETRIES = 2; // 3 attempts total (1 initial + 2 retries)
 
 export class ExpenseBotAgent {
@@ -201,35 +207,44 @@ export class ExpenseBotAgent {
     debugCtx: AiDebugRunContext | null,
     toolCallNames: string[],
   ): Promise<{ text: string; toolCount: number }> {
-    for (let attempt = 0; attempt <= MAX_API_RETRIES; attempt++) {
-      const controller = new AbortController();
-      const timeout = setTimeout(() => controller.abort(), AGENT_TIMEOUT_MS);
+    // One overall deadline spanning all retries — a truly-stuck run is bounded
+    // by AGENT_TIMEOUT_MS, not AGENT_TIMEOUT_MS × attempts. Once it fires, the
+    // signal stays aborted, so any retry attempt aborts immediately.
+    const controller = new AbortController();
+    const timeout = setTimeout(() => controller.abort(), AGENT_TIMEOUT_MS);
 
-      try {
-        return await this.runAgentLoop(
-          messages,
-          writer,
-          debugCtx,
-          controller.signal,
-          toolCallNames,
-        );
-      } catch (error) {
-        if (attempt < MAX_API_RETRIES && isRetryableError(error)) {
-          const delay = getBackoffDelay(attempt, error);
-          const errName = error instanceof Error ? error.message : String(error);
-          logger.warn(
-            `[AGENT] Attempt ${attempt + 1}/${MAX_API_RETRIES + 1} failed (${errName}), retrying in ${delay}ms`,
+    try {
+      for (let attempt = 0; attempt <= MAX_API_RETRIES; attempt++) {
+        try {
+          return await this.runAgentLoop(
+            messages,
+            writer,
+            debugCtx,
+            controller.signal,
+            toolCallNames,
           );
-          writer.reset();
-          toolCallNames.length = 0;
-          await this.sleep(delay);
-          continue;
-        }
+        } catch (error) {
+          // Overall deadline fired — stop retrying, surface the timeout.
+          if (controller.signal.aborted) {
+            throw error;
+          }
+          if (attempt < MAX_API_RETRIES && isRetryableError(error)) {
+            const delay = getBackoffDelay(attempt, error);
+            const errName = error instanceof Error ? error.message : String(error);
+            logger.warn(
+              `[AGENT] Attempt ${attempt + 1}/${MAX_API_RETRIES + 1} failed (${errName}), retrying in ${delay}ms`,
+            );
+            writer.reset();
+            toolCallNames.length = 0;
+            await this.sleep(delay);
+            continue;
+          }
 
-        throw error;
-      } finally {
-        clearTimeout(timeout);
+          throw error;
+        }
       }
+    } finally {
+      clearTimeout(timeout);
     }
     throw new Error('[AGENT] Retry loop exhausted');
   }
diff --git a/src/services/ai/streaming.test.ts b/src/services/ai/streaming.test.ts
index 639c997..ce0ed7b 100644
--- a/src/services/ai/streaming.test.ts
+++ b/src/services/ai/streaming.test.ts
@@ -383,19 +383,16 @@ describe('aiStreamRound fallback chain', () => {
     expect(result.providerUsed).toContain('Gemini');
   });
 
-  it('aborts cleanly when signal triggers before stream start', async () => {
+  it('stops the chain immediately when the overall signal is pre-aborted', async () => {
+    // When the caller's overall deadline has already passed, trying any provider
+    // is pointless — the chain must abort at once with an AbortError-classified error,
+    // NOT loop through all three providers.
     const clientsMod = await import('./clients');
-    const createMock = mock(async (_params: unknown, opts?: { signal?: AbortSignal }) => {
-      // Simulate SDK v6: a pre-aborted signal yields an APIError with undefined status
-      if (opts?.signal?.aborted) {
-        throw new OpenAI.APIError(undefined as unknown as number, {}, 'aborted', new Headers());
-      }
-      return {
-        [Symbol.asyncIterator]: async function* () {
-          yield { choices: [{ delta: { content: 'nope' }, finish_reason: 'stop' }] };
-        },
-      };
-    });
+    const createMock = mock(async () => ({
+      [Symbol.asyncIterator]: async function* () {
+        yield { choices: [{ delta: { content: 'nope' }, finish_reason: 'stop' }] };
+      },
+    }));
     spyOn(clientsMod, 'zaiClient').mockReturnValue({
       chat: { completions: { create: createMock } },
     } as unknown as OpenAI);
@@ -409,15 +406,112 @@ describe('aiStreamRound fallback chain', () => {
     const controller = new AbortController();
     controller.abort();
 
-    // All providers abort → aggregated error with 3 provider failures
-    await expect(
-      streamingModule.aiStreamRound({
+    try {
+      await streamingModule.aiStreamRound({
         messages: [{ role: 'user', content: 'hi' }],
         maxTokens: 50,
         chain: 'smart',
         signal: controller.signal,
-      }),
-    ).rejects.toThrow(/All 3 providers/);
+      });
+      throw new Error('should have thrown');
+    } catch (err) {
+      expect((err as Error).name).toBe('AbortError');
+    }
+    // No provider should have been called — the overall signal was already aborted.
+    expect(createMock).not.toHaveBeenCalled();
+  });
+
+  it('per-provider timeout fires → falls back to next provider with a FRESH (non-aborted) signal', async () => {
+    // Problem A regression: the first provider hangs until its OWN per-provider
+    // timeout fires. The fallback must reach the second provider with a signal that
+    // is NOT aborted (the shared-signal bug would hand it the already-aborted signal).
+    // On the OLD code the first provider's hang never resolves → the test times out.
+    const clientsMod = await import('./clients');
+    const calls: string[] = [];
+    const secondSignalStates: Array<boolean | undefined> = [];
+
+    const createMock = mock(
+      async (
+        params: OpenAI.ChatCompletionCreateParamsStreaming,
+        opts?: { signal?: AbortSignal },
+      ) => {
+        calls.push(params.model);
+        if (params.model === 'test-model') {
+          // Hang until our injected per-provider timeout aborts this provider's signal.
+          await new Promise((_resolve, reject) => {
+            opts?.signal?.addEventListener('abort', () => {
+              const err = new Error('aborted by per-provider timeout');
+              err.name = 'AbortError';
+              reject(err);
+            });
+          });
+          throw new Error('unreachable');
+        }
+        // Second provider: record whether its signal is fresh (not aborted).
+        secondSignalStates.push(opts?.signal?.aborted);
+        return {
+          [Symbol.asyncIterator]: async function* () {
+            yield { choices: [{ delta: { content: 'fresh fallback' }, finish_reason: 'stop' }] };
+          },
+        };
+      },
+    );
+    spyOn(clientsMod, 'zaiClient').mockReturnValue({
+      chat: { completions: { create: createMock } },
+    } as unknown as OpenAI);
+    spyOn(clientsMod, 'geminiClient').mockReturnValue({
+      chat: { completions: { create: createMock } },
+    } as unknown as OpenAI);
+
+    const result = await streamingModule.aiStreamRound({
+      messages: [{ role: 'user', content: 'hi' }],
+      maxTokens: 50,
+      chain: 'smart',
+      perProviderTimeoutMs: 20,
+    });
+
+    expect(result.text).toBe('fresh fallback');
+    expect(result.providerUsed).toContain('Gemini');
+    expect(calls).toEqual(['test-model', 'gemini-test']);
+    // The second provider received a fresh, non-aborted signal.
+    expect(secondSignalStates).toEqual([false]);
+  });
+
+  it('per-provider timeout fires but overall signal also aborted → stops the chain', async () => {
+    // If the caller's overall deadline passed while the first provider was running,
+    // the chain must NOT try the next provider — it throws an AbortError instead.
+    const clientsMod = await import('./clients');
+    const calls: string[] = [];
+    const overallController = new AbortController();
+
+    const createMock = mock(async (params: OpenAI.ChatCompletionCreateParamsStreaming) => {
+      calls.push(params.model);
+      // First provider: abort the OVERALL signal, then reject as if its own timeout fired.
+      overallController.abort();
+      const err = new Error('aborted');
+      err.name = 'AbortError';
+      throw err;
+    });
+    spyOn(clientsMod, 'zaiClient').mockReturnValue({
+      chat: { completions: { create: createMock } },
+    } as unknown as OpenAI);
+    spyOn(clientsMod, 'geminiClient').mockReturnValue({
+      chat: { completions: { create: createMock } },
+    } as unknown as OpenAI);
+
+    try {
+      await streamingModule.aiStreamRound({
+        messages: [{ role: 'user', content: 'hi' }],
+        maxTokens: 50,
+        chain: 'smart',
+        signal: overallController.signal,
+      });
+      throw new Error('should have thrown');
+    } catch (err) {
+      expect((err as Error).name).toBe('AbortError');
+    }
+    // Only the first provider was tried — overall deadline stops the chain.
+    expect(calls).toEqual(['test-model']);
   });
 
   it('tool-call resolution: handles missing tc.index (HF Router quirk)', async () => {
diff --git a/src/services/ai/streaming.ts b/src/services/ai/streaming.ts
index 0f43144..2a72bdb 100644
--- a/src/services/ai/streaming.ts
+++ b/src/services/ai/streaming.ts
@@ -26,6 +26,14 @@ const logger = createLogger('ai-streaming');
 
 const DEFAULT_TEMPERATURE = 0.3;
 
+/**
+ * Per-provider wall-clock budget. Each provider in the fallback chain gets its
+ * own fresh timeout — a slow/hung provider aborts after this and the loop tries
+ * the next one with a clean signal. NOT shared across providers, so one stuck
+ * provider does not poison the fallback chain.
+ */
+const PER_PROVIDER_TIMEOUT_MS = 45_000;
+
 // ── Types ───────────────────────────────────────────────────────────────────
 
 export type ChainName = 'smart' | 'fast' | 'ocr';
@@ -53,7 +61,10 @@ export interface StreamRoundOptions {
   temperature?: number;
   /** Which chain to run. Default: 'smart'. */
   chain?: ChainName;
+  /** Overall caller deadline. When aborted, the fallback chain stops immediately. */
   signal?: AbortSignal;
+  /** Per-provider timeout override (defaults to PER_PROVIDER_TIMEOUT_MS). For tests. */
+  perProviderTimeoutMs?: number;
 }
 
 export interface StreamCallbacks {
@@ -63,6 +74,17 @@ export interface StreamCallbacks {
 
 // ── Error helpers (exported for tests) ──────────────────────────────────────
 
+/**
+ * Build an Error classified as an abort. A plain Error with name='AbortError'
+ * (not a DOMException) so upstream `error.name === 'AbortError'` checks match
+ * the existing convention and the timeout user message fires.
+ */
+function makeAbortError(message: string): Error {
+  const err = new Error(message);
+  err.name = 'AbortError';
+  return err;
+}
+
 /** Provider-down: 5xx, timeout, network. Means "try next", retrying same provider is hopeless. */
 function isProviderDown(error: unknown): boolean {
   if (error instanceof OpenAI.APIError && error.status !== undefined && error.status >= 500) {
@@ -304,11 +326,29 @@ export async function aiStreamRound(
   }
 
   const providerErrors: Array<{ name: string; error: Error }> = [];
+  const perProviderTimeoutMs = options.perProviderTimeoutMs ?? PER_PROVIDER_TIMEOUT_MS;
 
   for (const slot of chain) {
+    // Overall caller deadline already passed — trying more providers is pointless.
+    if (options.signal?.aborted) {
+      logger.warn('[AI_STREAM] Overall deadline exceeded, stopping fallback chain');
+      throw makeAbortError('AI overall deadline exceeded');
+    }
+
+    // Each provider gets its own fresh timeout combined with the overall signal.
+    // A slow provider aborts after perProviderTimeoutMs without poisoning the next one.
+    const perProviderController = new AbortController();
+    const perProviderTimeout = setTimeout(
+      () => perProviderController.abort(),
+      perProviderTimeoutMs,
+    );
+    const combinedSignal = options.signal
+      ? AbortSignal.any([options.signal, perProviderController.signal])
+      : perProviderController.signal;
+
     try {
       logger.info(`[AI_STREAM] Trying ${chainName} → ${slot.name}`);
-      return await slot.stream(options, wrappedCallbacks);
+      return await slot.stream({ ...options, signal: combinedSignal }, wrappedCallbacks);
     } catch (error) {
       lastError = error instanceof Error ? error : new Error(String(error));
       providerErrors.push({ name: slot.name, error: lastError });
@@ -322,10 +362,19 @@ export async function aiStreamRound(
         throw error;
       }
 
+      // Overall deadline fired (not just this provider's timeout) — stop the chain
+      // and signal a timeout so the caller surfaces the "time exceeded" message.
+      if (options.signal?.aborted) {
+        logger.warn('[AI_STREAM] Overall deadline exceeded mid-provider, stopping fallback chain');
+        throw makeAbortError('AI overall deadline exceeded');
+      }
+
       // Always try the next provider in the chain.
       // isRetryableError is for same-provider retry (backoff), not for fallback decisions.
       // Different providers have different quirks — one may fail where another succeeds.
       logger.warn(`[AI_STREAM] ${slot.name} failed, trying next provider`);
+    } finally {
+      clearTimeout(perProviderTimeout);
     }
   }
 