Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 7 additions & 9 deletions backend/app/llm/provider_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,9 @@ def chat(
messages_dict = [{"role": m.role, "content": m.content} for m in messages]

start_time = time.time()
retries = 0
last_error = None
last_error: Optional[Exception] = None

while retries <= self.settings.MAX_RETRIES:
for attempt in range(1, self.settings.MAX_RETRIES + 1):
try:
response = litellm.completion(
model=model_string,
Expand Down Expand Up @@ -147,20 +146,19 @@ def chat(
)
except Exception as e:
last_error = e
retries += 1
if retries <= self.settings.MAX_RETRIES:
backoff = self.settings.RETRY_BACKOFF * (2 ** (retries - 1))
if attempt < self.settings.MAX_RETRIES:
backoff = self.settings.RETRY_BACKOFF * (2 ** (attempt - 1))
logger.warning(
"Provider request failed (attempt %s/%s): %s. Retrying in %ss...",
retries,
attempt,
self.settings.MAX_RETRIES,
e,
backoff,
)
time.sleep(backoff)

error_msg = str(last_error)
logger.error("Provider request failed after %s retries: %s", self.settings.MAX_RETRIES, error_msg)
error_msg = str(last_error) if last_error else "unknown error"
logger.error("Provider request failed after %s attempt(s): %s", self.settings.MAX_RETRIES, error_msg)
raise ProviderError(
f"Provider '{self.provider_name}' request failed: {error_msg}",
self.provider_name,
Expand Down
103 changes: 103 additions & 0 deletions backend/tests/test_provider_client_retry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
"""Tests for ProviderClient.chat retry behavior.

These tests pin down the attempt-count semantics: with MAX_RETRIES=N the
client should make at most N attempts, and the failure log should say
"attempt(s)" rather than the misleading "retries".
"""
from __future__ import annotations

import sys
import time
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import patch

import pytest

# Make the backend importable without installing the package.
ROOT = Path(__file__).resolve().parents[2]
sys.path.insert(0, str(ROOT / "backend"))

from app.llm.provider_client import ProviderClient, ProviderError # noqa: E402


class _FakeSettings:
MAX_RETRIES = 2
RETRY_BACKOFF = 0.0 # no sleeping in tests

def get_active_model(self, provider):
return "fake-model"

def get_runtime_api_key(self, provider):
return "test-key"

def get_runtime_base_url(self, provider):
return None

def _get_default_base_url(self, provider):
return None


class _FakeConfig:
api_format = "openai"

def get_api_key(self):
return "test-key"

def get_base_url(self):
return None


def _make_client() -> ProviderClient:
return ProviderClient(provider_name="openai")


def _ok_response():
usage = SimpleNamespace(prompt_tokens=1, completion_tokens=2, total_tokens=3)
choice = SimpleNamespace(
message=SimpleNamespace(content="ok"),
finish_reason="stop",
)
return SimpleNamespace(choices=[choice], usage=usage)


def test_chat_makes_exactly_max_retries_attempts_then_succeeds(monkeypatch):
client = _make_client()
client.settings = _FakeSettings()
client.config = _FakeConfig()
calls = {"n": 0}

def fake_completion(**kwargs):
calls["n"] += 1
if calls["n"] < 2:
raise RuntimeError("transient")
return _ok_response()

monkeypatch.setattr(client, "_get_litellm", lambda: SimpleNamespace(completion=fake_completion))
monkeypatch.setattr(client, "_get_api_config", lambda: {"api_key": "k", "api_base": None, "timeout": 5})

out = client.chat(messages=[SimpleNamespace(role="user", content="hi")])
assert out.text == "ok"
assert calls["n"] == 2 # MAX_RETRIES=2, succeeded on attempt 2


def test_chat_stops_after_max_retries_attempts(monkeypatch, caplog):
client = _make_client()
client.settings = _FakeSettings()
client.config = _FakeConfig()
calls = {"n": 0}

def fake_completion(**kwargs):
calls["n"] += 1
raise RuntimeError("always fails")

monkeypatch.setattr(client, "_get_litellm", lambda: SimpleNamespace(completion=fake_completion))
monkeypatch.setattr(client, "_get_api_config", lambda: {"api_key": "k", "api_base": None, "timeout": 5})

with caplog.at_level("ERROR", logger="app.llm.provider_client"):
with pytest.raises(ProviderError) as ei:
client.chat(messages=[SimpleNamespace(role="user", content="hi")])
assert calls["n"] == _FakeSettings.MAX_RETRIES
assert "attempt" in caplog.text
assert "retries" not in caplog.text
assert ei.value.status_code == 502