diff --git a/src/conductor/engine/workflow.py b/src/conductor/engine/workflow.py index c6c7b22..ad5a893 100644 --- a/src/conductor/engine/workflow.py +++ b/src/conductor/engine/workflow.py @@ -2471,6 +2471,7 @@ async def _handle_partial_output( session, interrupt_result.guidance, agent_name=agent.name, + agent_model=agent.model, ) # Fallback: re-execute the agent with guidance appended to prompt diff --git a/src/conductor/providers/copilot.py b/src/conductor/providers/copilot.py index 3f5dcda..c9659de 100644 --- a/src/conductor/providers/copilot.py +++ b/src/conductor/providers/copilot.py @@ -121,6 +121,9 @@ class SDKResponse: cache_read_tokens: Tokens read from cache (if available). cache_write_tokens: Tokens written to cache (if available). partial: Whether this response is partial (from a mid-agent interrupt). + resolved_model: Model name the SDK reported in the assistant.usage event's + model field. None when that event is absent (error or interrupt paths) + or carries no usable model name (the model field is missing or empty). """ content: str @@ -129,6 +132,7 @@ class SDKResponse: cache_read_tokens: int | None = None cache_write_tokens: int | None = None partial: bool = False + resolved_model: str | None = None class CopilotProvider(AgentProvider): @@ -574,7 +578,13 @@ async def _execute_with_retry( output_tokens=output_tokens, cache_read_tokens=cache_read, cache_write_tokens=cache_write, - model=agent.model or self._default_model, + model=( + sdk_response.resolved_model + if sdk_response and agent.model in (None, "auto") + else None + ) + or agent.model + or self._default_model, partial=is_partial, ) except ProviderError as e: @@ -888,6 +898,7 @@ async def _execute_sdk_call( cache_read_tokens=sdk_response.cache_read_tokens, cache_write_tokens=sdk_response.cache_write_tokens, partial=True, + resolved_model=sdk_response.resolved_model, ) return partial_content, partial_usage @@ -896,6 +907,7 @@ async def _execute_sdk_call( total_output_tokens = sdk_response.output_tokens cache_read_tokens = sdk_response.cache_read_tokens cache_write_tokens = sdk_response.cache_write_tokens + current_resolved_model = sdk_response.resolved_model # If no output schema (or output_mode is raw), we're done if output_schema is None: @@ -905,6 +917,7 @@ async def _execute_sdk_call( output_tokens=total_output_tokens, cache_read_tokens=cache_read_tokens, cache_write_tokens=cache_write_tokens, + resolved_model=current_resolved_model, ) return {"result": response_content}, final_usage @@ -921,6 +934,7 @@ async def _execute_sdk_call( output_tokens=total_output_tokens, cache_read_tokens=cache_read_tokens, cache_write_tokens=cache_write_tokens, + resolved_model=current_resolved_model, ) return parsed_content, final_usage except (json.JSONDecodeError, ValueError) as e: @@ -965,6 +979,9 @@ async def _execute_sdk_call( total_output_tokens = ( total_output_tokens or 0 ) + recovery_response.output_tokens + # Keep the latest resolved model (recovery uses the same session/model) + if recovery_response.resolved_model: + current_resolved_model = recovery_response.resolved_model # All recovery attempts exhausted expected_fields = list(output_schema.keys()) @@ -1049,6 +1066,9 @@ async def _send_and_wait( # Mutable container for usage data: [input_tokens, output_tokens, cache_read, cache_write] usage_ref: list[int | None] = [None, None, None, None] + # Mutable container for the resolved model name (from assistant.usage event) + resolved_model_ref: list[str | None] = [None] + # Mutable container for tool iteration counting tool_iteration_ref: list[int] = [0] @@ -1093,6 +1113,10 @@ def on_event(event: Any) -> None: usage_ref[2] = int(cache_read) if cache_write is not None: usage_ref[3] = int(cache_write) + # Capture the actual model resolved by the SDK (e.g., when model="auto") + sdk_model = getattr(event.data, "model", None) + if sdk_model: + resolved_model_ref[0] = sdk_model elif event_type == "session.idle": done.set() elif event_type == "error" or event_type == "session.error": @@ -1148,6 +1172,7 @@ def on_event(event: Any) -> None: cache_read_tokens=usage_ref[2], cache_write_tokens=usage_ref[3], partial=True, + resolved_model=resolved_model_ref[0], ) if error_message: @@ -1162,6 +1187,7 @@ def on_event(event: Any) -> None: output_tokens=usage_ref[1], cache_read_tokens=usage_ref[2], cache_write_tokens=usage_ref[3], + resolved_model=resolved_model_ref[0], ) async def _abort_session(self, session: Any, done: asyncio.Event) -> None: @@ -1219,6 +1245,7 @@ async def send_followup( session: Any, guidance: str, agent_name: str | None = None, + agent_model: str | None = None, ) -> AgentOutput: """Send follow-up guidance to an interrupted session. @@ -1235,6 +1262,8 @@ async def send_followup( interrupts only fire on sequential agents (for-each iterations do not forward ``interrupt_signal`` to the executor), so the tag, when supplied, is the unqualified agent name. + agent_model: Optional configured model for the interrupted agent, + used when the follow-up response does not report a resolved model. Returns: AgentOutput with the follow-up response content. @@ -1271,7 +1300,9 @@ async def send_followup( output_tokens=sdk_response.output_tokens, cache_read_tokens=sdk_response.cache_read_tokens, cache_write_tokens=sdk_response.cache_write_tokens, - model=self._default_model, + model=(sdk_response.resolved_model if agent_model in (None, "auto") else None) + or agent_model + or self._default_model, ) finally: await session.disconnect() diff --git a/tests/test_providers/test_copilot.py b/tests/test_providers/test_copilot.py index 6cd1dac..e316321 100644 --- a/tests/test_providers/test_copilot.py +++ b/tests/test_providers/test_copilot.py @@ -1747,3 +1747,247 @@ async def test_tier_and_reasoning_compose(self, monkeypatch: pytest.MonkeyPatch) await provider.execute(agent=agent, context={}, rendered_prompt="Analyze") assert captured["create_session_kwargs"]["context_tier"] == "long_context" assert captured["create_session_kwargs"]["reasoning_effort"] == "high" + + +class TestCopilotProviderResolvedModel: + """Tests for SDKResponse.resolved_model propagation into AgentOutput.model.""" + + # These are mocked tests; model availability in the live Copilot environment is not required. + # claude-sonnet-4 is used in pricing/usage tests as the canonical model and exists in the + # Conductor pricing table, so it satisfies both propagation and cost-calculation tests. + _RESOLVED_MODEL_FROM_SDK = "claude-sonnet-4" + _PRICEABLE_MODEL = "gpt-4o" + _UNPRICEABLE_RESOLVED_MODEL = "claude-sonnet-4.5" + + class _FakeSession: + session_id = "session-fake" + + async def disconnect(self) -> None: + return None + + class _FakeClient: + async def create_session(self, **kwargs: Any) -> Any: + return TestCopilotProviderResolvedModel._FakeSession() + + @pytest.mark.asyncio + async def test_resolved_model_from_sdk_overrides_auto( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + """SDKResponse.resolved_model propagates into AgentOutput.model for model='auto'.""" + provider = CopilotProvider(retry_config=RetryConfig(max_attempts=1)) + provider._client = self._FakeClient() + agent = AgentDef(name="a", model="auto", prompt="p") + + async def fake_send(*args: Any, **kwargs: Any) -> SDKResponse: + return SDKResponse( + content='{"result":"ok"}', resolved_model=self._RESOLVED_MODEL_FROM_SDK + ) + + async def noop() -> None: + return None + + monkeypatch.setattr(provider, "_ensure_client_started", noop) + monkeypatch.setattr(provider, "_send_and_wait", fake_send) + + result = await provider.execute(agent=agent, context={}, rendered_prompt="p") + assert result.model == self._RESOLVED_MODEL_FROM_SDK + + @pytest.mark.asyncio + async def test_resolved_model_fallback_when_absent( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + """When resolved_model is None, AgentOutput.model falls back to agent.model.""" + provider = CopilotProvider(retry_config=RetryConfig(max_attempts=1)) + provider._client = self._FakeClient() + agent = AgentDef(name="a", model="auto", prompt="p") + + async def fake_send(*args: Any, **kwargs: Any) -> SDKResponse: + return SDKResponse(content='{"result":"ok"}', resolved_model=None) + + async def noop() -> None: + return None + + monkeypatch.setattr(provider, "_ensure_client_started", noop) + monkeypatch.setattr(provider, "_send_and_wait", fake_send) + + result = await provider.execute(agent=agent, context={}, rendered_prompt="p") + assert result.model == "auto" + + @pytest.mark.asyncio + async def test_resolved_model_enables_auto_model_cost_calculation( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + """A resolved priceable model from execute() produces non-null cost_usd.""" + from conductor.engine.usage import UsageTracker + + provider = CopilotProvider(retry_config=RetryConfig(max_attempts=1)) + provider._client = self._FakeClient() + agent = AgentDef(name="a", model="auto", prompt="p") + + async def fake_send(*args: Any, **kwargs: Any) -> SDKResponse: + return SDKResponse( + content='{"result":"ok"}', + input_tokens=1000, + output_tokens=500, + resolved_model=self._RESOLVED_MODEL_FROM_SDK, + ) + + async def noop() -> None: + return None + + monkeypatch.setattr(provider, "_ensure_client_started", noop) + monkeypatch.setattr(provider, "_send_and_wait", fake_send) + + result = await provider.execute(agent=agent, context={}, rendered_prompt="p") + tracker = UsageTracker() + usage = tracker.record("agent", result, elapsed=1.0) + assert usage.cost_usd is not None + assert usage.cost_usd > 0 + + @pytest.mark.asyncio + async def test_auto_model_without_resolved_model_remains_unpriced( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + """model='auto' remains unpriceable when the SDK does not report a model.""" + from conductor.engine.usage import UsageTracker + + provider = CopilotProvider(retry_config=RetryConfig(max_attempts=1)) + provider._client = self._FakeClient() + agent = AgentDef(name="a", model="auto", prompt="p") + + async def fake_send(*args: Any, **kwargs: Any) -> SDKResponse: + return SDKResponse( + content='{"result":"ok"}', + input_tokens=1000, + output_tokens=500, + resolved_model=None, + ) + + async def noop() -> None: + return None + + monkeypatch.setattr(provider, "_ensure_client_started", noop) + monkeypatch.setattr(provider, "_send_and_wait", fake_send) + + result = await provider.execute(agent=agent, context={}, rendered_prompt="p") + usage = UsageTracker().record("agent", result, elapsed=1.0) + assert result.model == "auto" + assert usage.cost_usd is None + + @pytest.mark.asyncio + async def test_explicit_priceable_model_ignores_unpriceable_resolved_model( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + """Explicit configured models keep their pricing name over SDK aliases.""" + from conductor.engine.usage import UsageTracker + + provider = CopilotProvider(retry_config=RetryConfig(max_attempts=1)) + provider._client = self._FakeClient() + agent = AgentDef(name="a", model=self._PRICEABLE_MODEL, prompt="p") + + async def fake_send(*args: Any, **kwargs: Any) -> SDKResponse: + return SDKResponse( + content='{"result":"ok"}', + input_tokens=1000, + output_tokens=500, + resolved_model=self._UNPRICEABLE_RESOLVED_MODEL, + ) + + async def noop() -> None: + return None + + monkeypatch.setattr(provider, "_ensure_client_started", noop) + monkeypatch.setattr(provider, "_send_and_wait", fake_send) + + result = await provider.execute(agent=agent, context={}, rendered_prompt="p") + usage = UsageTracker().record("agent", result, elapsed=1.0) + assert result.model == self._PRICEABLE_MODEL + assert usage.cost_usd is not None + assert usage.cost_usd > 0 + + @pytest.mark.asyncio + async def test_followup_preserves_explicit_model_over_unpriceable_resolved_model( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + """Follow-up turns keep explicit pricing names over SDK aliases.""" + provider = CopilotProvider(retry_config=RetryConfig(max_attempts=1)) + + async def fake_send(*args: Any, **kwargs: Any) -> SDKResponse: + return SDKResponse( + content='{"result":"ok"}', + resolved_model=self._UNPRICEABLE_RESOLVED_MODEL, + ) + + monkeypatch.setattr(provider, "_send_and_wait", fake_send) + + result = await provider.send_followup( + self._FakeSession(), + guidance="continue", + agent_model=self._PRICEABLE_MODEL, + ) + assert result.model == self._PRICEABLE_MODEL + + @pytest.mark.asyncio + async def test_followup_uses_resolved_model_for_auto_model( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + """Follow-up turns use SDK resolved model for auto-routed agents.""" + provider = CopilotProvider(retry_config=RetryConfig(max_attempts=1)) + + async def fake_send(*args: Any, **kwargs: Any) -> SDKResponse: + return SDKResponse( + content='{"result":"ok"}', + resolved_model=self._RESOLVED_MODEL_FROM_SDK, + ) + + monkeypatch.setattr(provider, "_send_and_wait", fake_send) + + result = await provider.send_followup( + self._FakeSession(), + guidance="continue", + agent_model="auto", + ) + assert result.model == self._RESOLVED_MODEL_FROM_SDK + + @pytest.mark.asyncio + async def test_send_and_wait_captures_model_from_usage_event(self) -> None: + """_send_and_wait extracts event.data.model from assistant.usage into resolved_model.""" + from unittest.mock import Mock as _Mock + + provider = CopilotProvider(retry_config=RetryConfig(max_attempts=1)) + captured_cb: list[Any] = [] + + # Build assistant.usage event with explicit token counts and model name. + usage_ev = _Mock() + usage_ev.type.value = "assistant.usage" + usage_ev.data.input_tokens = 100 + usage_ev.data.output_tokens = 50 + usage_ev.data.cache_read_tokens = None + usage_ev.data.cache_write_tokens = None + usage_ev.data.model = self._RESOLVED_MODEL_FROM_SDK + + # session.idle tells _send_and_wait the turn is complete. + idle_ev = _Mock() + idle_ev.type.value = "session.idle" + + def on_event(callback: Any) -> None: + captured_cb.append(callback) + + session = _Mock() + session.on = on_event + + async def fake_send(prompt: str) -> None: + assert captured_cb, "Expected _send_and_wait to register a session event callback" + callback = captured_cb[0] + for ev in (usage_ev, idle_ev): + callback(ev) + + session.send = fake_send + + result = await provider._send_and_wait( + session=session, + prompt="hello", + verbose_enabled=False, + full_enabled=False, + ) + assert result.resolved_model == self._RESOLVED_MODEL_FROM_SDK