Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/conductor/engine/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2471,6 +2471,7 @@ async def _handle_partial_output(
session,
interrupt_result.guidance,
agent_name=agent.name,
agent_model=agent.model,
)

# Fallback: re-execute the agent with guidance appended to prompt
Expand Down
35 changes: 33 additions & 2 deletions src/conductor/providers/copilot.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@ class SDKResponse:
cache_read_tokens: Tokens read from cache (if available).
cache_write_tokens: Tokens written to cache (if available).
partial: Whether this response is partial (from a mid-agent interrupt).
resolved_model: Model name the SDK reported in the assistant.usage event's
model field. None when that event is absent (error or interrupt paths)
or carries no usable model name (the model field is missing or empty).
"""

content: str
Expand All @@ -129,6 +132,7 @@ class SDKResponse:
cache_read_tokens: int | None = None
cache_write_tokens: int | None = None
partial: bool = False
resolved_model: str | None = None


class CopilotProvider(AgentProvider):
Expand Down Expand Up @@ -574,7 +578,13 @@ async def _execute_with_retry(
output_tokens=output_tokens,
cache_read_tokens=cache_read,
cache_write_tokens=cache_write,
model=agent.model or self._default_model,
model=(
sdk_response.resolved_model
if sdk_response and agent.model in (None, "auto")
else None
)
or agent.model
or self._default_model,
partial=is_partial,
)
except ProviderError as e:
Expand Down Expand Up @@ -888,6 +898,7 @@ async def _execute_sdk_call(
cache_read_tokens=sdk_response.cache_read_tokens,
cache_write_tokens=sdk_response.cache_write_tokens,
partial=True,
resolved_model=sdk_response.resolved_model,
)
return partial_content, partial_usage

Expand All @@ -896,6 +907,7 @@ async def _execute_sdk_call(
total_output_tokens = sdk_response.output_tokens
cache_read_tokens = sdk_response.cache_read_tokens
cache_write_tokens = sdk_response.cache_write_tokens
current_resolved_model = sdk_response.resolved_model

# If no output schema (or output_mode is raw), we're done
if output_schema is None:
Expand All @@ -905,6 +917,7 @@ async def _execute_sdk_call(
output_tokens=total_output_tokens,
cache_read_tokens=cache_read_tokens,
cache_write_tokens=cache_write_tokens,
resolved_model=current_resolved_model,
)
return {"result": response_content}, final_usage

Expand All @@ -921,6 +934,7 @@ async def _execute_sdk_call(
output_tokens=total_output_tokens,
cache_read_tokens=cache_read_tokens,
cache_write_tokens=cache_write_tokens,
resolved_model=current_resolved_model,
)
return parsed_content, final_usage
except (json.JSONDecodeError, ValueError) as e:
Expand Down Expand Up @@ -965,6 +979,9 @@ async def _execute_sdk_call(
total_output_tokens = (
total_output_tokens or 0
) + recovery_response.output_tokens
# Keep the latest resolved model (recovery uses the same session/model)
if recovery_response.resolved_model:
current_resolved_model = recovery_response.resolved_model

# All recovery attempts exhausted
expected_fields = list(output_schema.keys())
Expand Down Expand Up @@ -1049,6 +1066,9 @@ async def _send_and_wait(
# Mutable container for usage data: [input_tokens, output_tokens, cache_read, cache_write]
usage_ref: list[int | None] = [None, None, None, None]

# Mutable container for the resolved model name (from assistant.usage event)
resolved_model_ref: list[str | None] = [None]

# Mutable container for tool iteration counting
tool_iteration_ref: list[int] = [0]

Expand Down Expand Up @@ -1093,6 +1113,10 @@ def on_event(event: Any) -> None:
usage_ref[2] = int(cache_read)
if cache_write is not None:
usage_ref[3] = int(cache_write)
# Capture the actual model resolved by the SDK (e.g., when model="auto")
sdk_model = getattr(event.data, "model", None)
if sdk_model:
resolved_model_ref[0] = sdk_model
elif event_type == "session.idle":
done.set()
elif event_type == "error" or event_type == "session.error":
Expand Down Expand Up @@ -1148,6 +1172,7 @@ def on_event(event: Any) -> None:
cache_read_tokens=usage_ref[2],
cache_write_tokens=usage_ref[3],
partial=True,
resolved_model=resolved_model_ref[0],
)

if error_message:
Expand All @@ -1162,6 +1187,7 @@ def on_event(event: Any) -> None:
output_tokens=usage_ref[1],
cache_read_tokens=usage_ref[2],
cache_write_tokens=usage_ref[3],
resolved_model=resolved_model_ref[0],
)

async def _abort_session(self, session: Any, done: asyncio.Event) -> None:
Expand Down Expand Up @@ -1219,6 +1245,7 @@ async def send_followup(
session: Any,
guidance: str,
agent_name: str | None = None,
agent_model: str | None = None,
) -> AgentOutput:
"""Send follow-up guidance to an interrupted session.

Expand All @@ -1235,6 +1262,8 @@ async def send_followup(
interrupts only fire on sequential agents (for-each iterations
do not forward ``interrupt_signal`` to the executor), so the
tag, when supplied, is the unqualified agent name.
agent_model: Optional configured model for the interrupted agent,
used when the follow-up response does not report a resolved model.

Returns:
AgentOutput with the follow-up response content.
Expand Down Expand Up @@ -1271,7 +1300,9 @@ async def send_followup(
output_tokens=sdk_response.output_tokens,
cache_read_tokens=sdk_response.cache_read_tokens,
cache_write_tokens=sdk_response.cache_write_tokens,
model=self._default_model,
model=(sdk_response.resolved_model if agent_model in (None, "auto") else None)
or agent_model
or self._default_model,
)
finally:
await session.disconnect()
Expand Down
244 changes: 244 additions & 0 deletions tests/test_providers/test_copilot.py
Original file line number Diff line number Diff line change
Expand Up @@ -1747,3 +1747,247 @@ async def test_tier_and_reasoning_compose(self, monkeypatch: pytest.MonkeyPatch)
await provider.execute(agent=agent, context={}, rendered_prompt="Analyze")
assert captured["create_session_kwargs"]["context_tier"] == "long_context"
assert captured["create_session_kwargs"]["reasoning_effort"] == "high"


class TestCopilotProviderResolvedModel:
    """Check how ``SDKResponse.resolved_model`` is surfaced as ``AgentOutput.model``.

    Every test except the last replaces the provider's ``_send_and_wait`` with a stub
    that returns a canned ``SDKResponse``, so no live Copilot session (and no real
    model availability) is required.
    """

    # Model name the stubbed SDK reports for auto-routed agents. It is the canonical
    # model of the pricing/usage tests and is present in the Conductor pricing table,
    # which lets it serve both the propagation and the cost-calculation tests.
    _RESOLVED_MODEL_FROM_SDK = "claude-sonnet-4"
    # Explicitly configured model that the tests expect to yield a cost.
    _PRICEABLE_MODEL = "gpt-4o"
    # SDK-reported name that must never displace an explicitly configured model.
    _UNPRICEABLE_RESOLVED_MODEL = "claude-sonnet-4.5"
    # Response payload shared by every stubbed SDK reply.
    _JSON_CONTENT = '{"result":"ok"}'

    class _FakeSession:
        """Session stand-in exposing only an id and a no-op ``disconnect``."""

        session_id = "session-fake"

        async def disconnect(self) -> None:
            return None

    class _FakeClient:
        """Client stand-in whose ``create_session`` always yields a ``_FakeSession``."""

        async def create_session(self, **kwargs: Any) -> Any:
            return TestCopilotProviderResolvedModel._FakeSession()

    @staticmethod
    def _new_provider() -> Any:
        """Build a provider limited to a single attempt (no retries)."""
        return CopilotProvider(retry_config=RetryConfig(max_attempts=1))

    @staticmethod
    def _stub_send(monkeypatch: pytest.MonkeyPatch, provider: Any, **fields: Any) -> None:
        """Patch ``provider._send_and_wait`` to return ``SDKResponse(**fields)``.

        A fresh ``SDKResponse`` is constructed on each call of the stub.
        """

        async def fake_send(*args: Any, **kwargs: Any) -> SDKResponse:
            return SDKResponse(**fields)

        monkeypatch.setattr(provider, "_send_and_wait", fake_send)

    async def _run_execute(
        self, monkeypatch: pytest.MonkeyPatch, model: str, **fields: Any
    ) -> Any:
        """Run ``execute()`` for an agent configured with ``model`` against a stubbed SDK.

        The provider gets a fake client and a no-op ``_ensure_client_started`` so that
        nothing real is started; ``fields`` populate the canned ``SDKResponse``.
        """
        provider = self._new_provider()
        provider._client = self._FakeClient()

        async def noop() -> None:
            return None

        monkeypatch.setattr(provider, "_ensure_client_started", noop)
        self._stub_send(monkeypatch, provider, **fields)

        agent = AgentDef(name="a", model=model, prompt="p")
        return await provider.execute(agent=agent, context={}, rendered_prompt="p")

    async def _run_followup(
        self, monkeypatch: pytest.MonkeyPatch, agent_model: str, resolved_model: str
    ) -> Any:
        """Run ``send_followup()`` on a fake session with a stubbed SDK reply."""
        provider = self._new_provider()
        self._stub_send(
            monkeypatch,
            provider,
            content=self._JSON_CONTENT,
            resolved_model=resolved_model,
        )
        return await provider.send_followup(
            self._FakeSession(),
            guidance="continue",
            agent_model=agent_model,
        )

    @pytest.mark.asyncio
    async def test_resolved_model_from_sdk_overrides_auto(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """With model='auto', the SDK-reported model becomes ``AgentOutput.model``."""
        result = await self._run_execute(
            monkeypatch,
            "auto",
            content=self._JSON_CONTENT,
            resolved_model=self._RESOLVED_MODEL_FROM_SDK,
        )
        assert result.model == self._RESOLVED_MODEL_FROM_SDK

    @pytest.mark.asyncio
    async def test_resolved_model_fallback_when_absent(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Without a resolved model, ``AgentOutput.model`` stays at ``agent.model``."""
        result = await self._run_execute(
            monkeypatch,
            "auto",
            content=self._JSON_CONTENT,
            resolved_model=None,
        )
        assert result.model == "auto"

    @pytest.mark.asyncio
    async def test_resolved_model_enables_auto_model_cost_calculation(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """A priceable resolved model lets the usage tracker compute a positive cost."""
        from conductor.engine.usage import UsageTracker

        result = await self._run_execute(
            monkeypatch,
            "auto",
            content=self._JSON_CONTENT,
            input_tokens=1000,
            output_tokens=500,
            resolved_model=self._RESOLVED_MODEL_FROM_SDK,
        )
        tracker = UsageTracker()
        usage = tracker.record("agent", result, elapsed=1.0)
        assert usage.cost_usd is not None
        assert usage.cost_usd > 0

    @pytest.mark.asyncio
    async def test_auto_model_without_resolved_model_remains_unpriced(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """If the SDK reports no model, model='auto' is kept and no cost is produced."""
        from conductor.engine.usage import UsageTracker

        result = await self._run_execute(
            monkeypatch,
            "auto",
            content=self._JSON_CONTENT,
            input_tokens=1000,
            output_tokens=500,
            resolved_model=None,
        )
        usage = UsageTracker().record("agent", result, elapsed=1.0)
        assert result.model == "auto"
        assert usage.cost_usd is None

    @pytest.mark.asyncio
    async def test_explicit_priceable_model_ignores_unpriceable_resolved_model(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """An explicitly configured model wins over the SDK-reported name and is priced."""
        from conductor.engine.usage import UsageTracker

        result = await self._run_execute(
            monkeypatch,
            self._PRICEABLE_MODEL,
            content=self._JSON_CONTENT,
            input_tokens=1000,
            output_tokens=500,
            resolved_model=self._UNPRICEABLE_RESOLVED_MODEL,
        )
        usage = UsageTracker().record("agent", result, elapsed=1.0)
        assert result.model == self._PRICEABLE_MODEL
        assert usage.cost_usd is not None
        assert usage.cost_usd > 0

    @pytest.mark.asyncio
    async def test_followup_preserves_explicit_model_over_unpriceable_resolved_model(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """A follow-up turn reports the explicit agent model, not the SDK-reported one."""
        result = await self._run_followup(
            monkeypatch,
            agent_model=self._PRICEABLE_MODEL,
            resolved_model=self._UNPRICEABLE_RESOLVED_MODEL,
        )
        assert result.model == self._PRICEABLE_MODEL

    @pytest.mark.asyncio
    async def test_followup_uses_resolved_model_for_auto_model(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """A follow-up turn for an auto-routed agent reports the SDK-resolved model."""
        result = await self._run_followup(
            monkeypatch,
            agent_model="auto",
            resolved_model=self._RESOLVED_MODEL_FROM_SDK,
        )
        assert result.model == self._RESOLVED_MODEL_FROM_SDK

    @pytest.mark.asyncio
    async def test_send_and_wait_captures_model_from_usage_event(self) -> None:
        """``_send_and_wait`` copies ``event.data.model`` of assistant.usage to the result."""
        from unittest.mock import Mock as _Mock

        def build_event(kind: str) -> Any:
            # Events expose their kind through ``event.type.value``.
            event = _Mock()
            event.type.value = kind
            return event

        # assistant.usage event carrying explicit token counts and the model name.
        usage_ev = build_event("assistant.usage")
        usage_ev.data.input_tokens = 100
        usage_ev.data.output_tokens = 50
        usage_ev.data.cache_read_tokens = None
        usage_ev.data.cache_write_tokens = None
        usage_ev.data.model = self._RESOLVED_MODEL_FROM_SDK

        # session.idle marks the turn as complete for _send_and_wait.
        idle_ev = build_event("session.idle")

        captured_cb: list[Any] = []

        def on_event(callback: Any) -> None:
            # Remember the handler that _send_and_wait registers on the session.
            captured_cb.append(callback)

        async def fake_send(prompt: str) -> None:
            # Replay the two events through the registered handler, usage first.
            assert captured_cb, "Expected _send_and_wait to register a session event callback"
            deliver = captured_cb[0]
            deliver(usage_ev)
            deliver(idle_ev)

        session = _Mock()
        session.on = on_event
        session.send = fake_send

        provider = self._new_provider()
        result = await provider._send_and_wait(
            session=session,
            prompt="hello",
            verbose_enabled=False,
            full_enabled=False,
        )
        assert result.resolved_model == self._RESOLVED_MODEL_FROM_SDK
Loading