From e1e6dd44f3b88ddb1830437ff6759554a8823fe2 Mon Sep 17 00:00:00 2001 From: Yufeng He <40085740+he-yufeng@users.noreply.github.com> Date: Sat, 30 May 2026 09:57:19 +0800 Subject: [PATCH] fix: complete realtime tool failures --- src/agents/realtime/session.py | 94 +++++++++++++++++++++++++++++---- tests/realtime/test_session.py | 96 ++++++++++++++++++++++++++++++---- 2 files changed, 171 insertions(+), 19 deletions(-) diff --git a/src/agents/realtime/session.py b/src/agents/realtime/session.py index ca809dd9c4..bd269b3799 100644 --- a/src/agents/realtime/session.py +++ b/src/agents/realtime/session.py @@ -22,7 +22,13 @@ from ..logger import logger from ..run_config import ToolErrorFormatterArgs from ..run_context import RunContextWrapper, TContext -from ..tool import DEFAULT_APPROVAL_REJECTION_MESSAGE, FunctionTool, invoke_function_tool +from ..tool import ( + DEFAULT_APPROVAL_REJECTION_MESSAGE, + FunctionTool, + default_tool_error_function, + invoke_function_tool, + maybe_invoke_function_tool_failure_error_function, +) from ..tool_context import ToolContext from ..util._approvals import evaluate_needs_approval_setting from .agent import RealtimeAgent @@ -714,6 +720,54 @@ async def reject_tool_call( finally: self._finish_tool_call(call_id, mark_completed=mark_completed) + async def _send_function_tool_failure_output( + self, + event: RealtimeModelToolCallEvent, + *, + tool: FunctionTool, + tool_context: ToolContext[Any], + agent: RealtimeAgent, + error: Exception, + ) -> bool: + output = await maybe_invoke_function_tool_failure_error_function( + function_tool=tool, + context=tool_context, + error=error, + ) + if output is None: + return False + + await self._send_tool_output_completion( + _PendingToolOutput( + tool_call=event, + output=output, + start_response=True, + tool_end_event=RealtimeToolEnd( + info=self._event_info, + tool=tool, + output=output, + agent=agent, + arguments=event.arguments, + ), + ) + ) + return True + + async def _send_handoff_failure_output( + self, + event: RealtimeModelToolCallEvent, + *, + tool_context: ToolContext[Any], + error: Exception, + ) -> None: + await self._send_tool_output_completion( + _PendingToolOutput( + tool_call=event, + output=default_tool_error_function(tool_context, error), + start_response=True, + ) + ) + async def _handle_tool_call( self, event: RealtimeModelToolCallEvent, @@ -773,11 +827,22 @@ async def _handle_tool_call( tool_arguments=event.arguments, agent=agent, ) - result = await invoke_function_tool( - function_tool=func_tool, - context=tool_context, - arguments=event.arguments, - ) + try: + result = await invoke_function_tool( + function_tool=func_tool, + context=tool_context, + arguments=event.arguments, + ) + except Exception as exc: + if await self._send_function_tool_failure_output( + event, + tool=func_tool, + tool_context=tool_context, + agent=agent, + error=exc, + ): + mark_completed = True + raise await self._send_tool_output_completion( _PendingToolOutput( @@ -806,11 +871,20 @@ async def _handle_tool_call( ) # Execute the handoff to get the new agent - result = await handoff.on_invoke_handoff(self._context_wrapper, event.arguments) - if not isinstance(result, RealtimeAgent): - raise UserError( - f"Handoff {handoff.tool_name} returned invalid result: {type(result)}" + try: + result = await handoff.on_invoke_handoff(self._context_wrapper, event.arguments) + if not isinstance(result, RealtimeAgent): + raise UserError( + f"Handoff {handoff.tool_name} returned invalid result: {type(result)}" + ) + except Exception as exc: + await self._send_handoff_failure_output( + event, + tool_context=tool_context, + error=exc, ) + mark_completed = True + raise # Store previous agent for event previous_agent = agent diff --git a/tests/realtime/test_session.py b/tests/realtime/test_session.py index 03148c739a..dcb8596ad7 100644 --- a/tests/realtime/test_session.py +++ b/tests/realtime/test_session.py @@ -220,6 +220,11 @@ async def test_handle_tool_call_handoff_invalid_result_raises(): RealtimeModelToolCallEvent(name="switch", call_id="c1", arguments="{}") ) + outputs = [event for event in model.events if isinstance(event, RealtimeModelSendToolOutput)] + assert len(outputs) == 1 + assert outputs[0].start_response is True + assert "Handoff switch returned invalid result" in outputs[0].output + @pytest.mark.asyncio async def test_on_guardrail_task_done_emits_error_event(): @@ -1209,13 +1214,49 @@ async def invoke_slow_tool(_ctx: ToolContext[Any], _arguments: str) -> str: with pytest.raises(ToolTimeoutError, match="timed out"): await session._handle_tool_call(tool_call_event) - assert len(mock_model.sent_tool_outputs) == 0 - assert session._event_queue.qsize() == 1 + assert len(mock_model.sent_tool_outputs) == 1 + sent_call, sent_output, start_response = mock_model.sent_tool_outputs[0] + assert sent_call == tool_call_event + assert start_response is True + assert "timed out" in sent_output.lower() + assert session._event_queue.qsize() == 2 tool_start_event = await session._event_queue.get() assert isinstance(tool_start_event, RealtimeToolStart) assert tool_start_event.tool == timeout_tool assert tool_start_event.arguments == "{}" + tool_end_event = await session._event_queue.get() + assert isinstance(tool_end_event, RealtimeToolEnd) + assert "timed out" in str(tool_end_event.output).lower() + + @pytest.mark.asyncio + async def test_function_tool_exception_sends_model_visible_output(self, mock_model, mock_agent): + async def failing_tool(_ctx: ToolContext[Any], _arguments: str) -> str: + raise ValueError("tool failed") + + function_tool = FunctionTool( + name="failing_tool", + description="fails", + params_json_schema={"type": "object", "properties": {}}, + on_invoke_tool=failing_tool, + ) + mock_agent.get_all_tools.return_value = [function_tool] + + session = RealtimeSession(mock_model, mock_agent, None) + tool_call_event = RealtimeModelToolCallEvent( + name="failing_tool", + call_id="call_fails", + arguments="{}", + ) + + with pytest.raises(ValueError, match="tool failed"): + await session._handle_tool_call(tool_call_event) + + assert len(mock_model.sent_tool_outputs) == 1 + sent_call, sent_output, start_response = mock_model.sent_tool_outputs[0] + assert sent_call == tool_call_event + assert start_response is True + assert "tool failed" in sent_output @pytest.mark.asyncio async def test_function_tool_timeout_uses_async_error_function_result( @@ -1296,7 +1337,11 @@ async def invoke_slow_tool(_ctx: ToolContext[Any], _arguments: str) -> str: assert isinstance(session._stored_exception, ToolTimeoutError) assert session._stored_exception.tool_name == "slow_tool" - assert len(mock_model.sent_tool_outputs) == 0 + assert len(mock_model.sent_tool_outputs) == 1 + sent_call, sent_output, start_response = mock_model.sent_tool_outputs[0] + assert sent_call == tool_call_event + assert start_response is True + assert "timed out" in sent_output.lower() events = [] while True: @@ -1310,6 +1355,7 @@ async def invoke_slow_tool(_ctx: ToolContext[Any], _arguments: str) -> str: for event in events ) assert any(isinstance(event, RealtimeToolStart) for event in events) + assert any(isinstance(event, RealtimeToolEnd) for event in events) error_event = next(event for event in events if isinstance(event, RealtimeError)) assert "Tool call task failed" in error_event.error["message"] @@ -1386,6 +1432,34 @@ async def test_handoff_tool_handling(self, mock_model): # Verify agent was updated assert session._current_agent == second_agent + @pytest.mark.asyncio + async def test_handoff_tool_exception_sends_model_visible_output(self, mock_model): + handoff = Handoff( + tool_name="transfer_to_broken_agent", + tool_description="broken handoff", + input_json_schema={}, + on_invoke_handoff=AsyncMock(side_effect=RuntimeError("handoff failed")), + input_filter=None, + agent_name="broken_agent", + is_enabled=True, + ) + agent = RealtimeAgent(name="agent", handoffs=[handoff]) + session = RealtimeSession(mock_model, agent, None) + tool_call_event = RealtimeModelToolCallEvent( + name="transfer_to_broken_agent", + call_id="call_handoff_fails", + arguments="{}", + ) + + with pytest.raises(RuntimeError, match="handoff failed"): + await session._handle_tool_call(tool_call_event) + + assert len(mock_model.sent_tool_outputs) == 1 + sent_call, sent_output, start_response = mock_model.sent_tool_outputs[0] + assert sent_call == tool_call_event + assert start_response is True + assert "handoff failed" in sent_output + @pytest.mark.asyncio async def test_handoff_session_update_preserves_custom_voice(self, mock_model): custom_voice = {"id": "voice_test"} @@ -1868,7 +1942,7 @@ async def invoke_tool(_ctx: ToolContext[Any], _arguments: str) -> str: async def test_function_tool_exception_handling( self, mock_model, mock_agent, mock_function_tool ): - """Test that exceptions in function tools are handled (currently they propagate)""" + """Test that function tool exceptions notify the model before propagating locally.""" # Set up tool to raise exception mock_function_tool.on_invoke_tool.side_effect = ValueError("Tool error") mock_agent.get_all_tools.return_value = [mock_function_tool] @@ -1879,18 +1953,22 @@ async def test_function_tool_exception_handling( name="test_function", call_id="call_error", arguments="{}" ) - # Currently exceptions propagate (no error handling implemented) with pytest.raises(ValueError, match="Tool error"): await session._handle_tool_call(tool_call_event) - # Tool start event should have been queued before the error - assert session._event_queue.qsize() == 1 + assert session._event_queue.qsize() == 2 tool_start_event = await session._event_queue.get() assert isinstance(tool_start_event, RealtimeToolStart) assert tool_start_event.arguments == "{}" - # But no tool output should have been sent and no end event queued - assert len(mock_model.sent_tool_outputs) == 0 + tool_end_event = await session._event_queue.get() + assert isinstance(tool_end_event, RealtimeToolEnd) + assert "Tool error" in str(tool_end_event.output) + + assert len(mock_model.sent_tool_outputs) == 1 + _sent_call, sent_output, start_response = mock_model.sent_tool_outputs[0] + assert "Tool error" in sent_output + assert start_response is True @pytest.mark.asyncio async def test_tool_call_with_complex_arguments(