diff --git a/python/packages/autogen-core/src/autogen_core/tools/_static_workbench.py b/python/packages/autogen-core/src/autogen_core/tools/_static_workbench.py index 40b1ce47d991..a2f1544e4f67 100644 --- a/python/packages/autogen-core/src/autogen_core/tools/_static_workbench.py +++ b/python/packages/autogen-core/src/autogen_core/tools/_static_workbench.py @@ -182,7 +182,10 @@ async def call_tool_stream( cancellation_token: CancellationToken | None = None, call_id: str | None = None, ) -> AsyncGenerator[Any | ToolResult, None]: - tool = next((tool for tool in self._tools if tool.name == name), None) + # Check if the name is an override name and map it back to the original + original_name = self._override_name_to_original.get(name, name) + + tool = next((tool for tool in self._tools if tool.name == original_name), None) if tool is None: yield ToolResult( name=name, @@ -210,7 +213,7 @@ async def call_tool_stream( yield previous_result # Then yield the error result result_str = self._format_errors(e) - yield ToolResult(name=tool.name, result=[TextResultContent(content=result_str)], is_error=True) + yield ToolResult(name=name, result=[TextResultContent(content=result_str)], is_error=True) return else: # If the tool is not a stream tool, we run it normally and yield the result @@ -222,4 +225,4 @@ async def call_tool_stream( except Exception as e: result_str = self._format_errors(e) is_error = True - yield ToolResult(name=tool.name, result=[TextResultContent(content=result_str)], is_error=is_error) + yield ToolResult(name=name, result=[TextResultContent(content=result_str)], is_error=is_error) diff --git a/python/packages/autogen-core/tests/test_static_workbench_overrides.py b/python/packages/autogen-core/tests/test_static_workbench_overrides.py index 37cf1b752f35..fbef21ff7b43 100644 --- a/python/packages/autogen-core/tests/test_static_workbench_overrides.py +++ b/python/packages/autogen-core/tests/test_static_workbench_overrides.py @@ -2,7 +2,14 @@ import pytest from autogen_core.code_executor import ImportFromModule -from autogen_core.tools import FunctionTool, StaticWorkbench, ToolOverride, Workbench +from autogen_core.tools import ( + FunctionTool, + StaticStreamWorkbench, + StaticWorkbench, + ToolOverride, + ToolResult, + Workbench, +) @pytest.mark.asyncio @@ -283,3 +290,31 @@ def test_tool_func_3(x: Annotated[int, "Number"]) -> int: } workbench_self = StaticWorkbench(tools=[tool1, tool2, tool3], tool_overrides=overrides_self) assert "tool1" not in workbench_self._override_name_to_original # type: ignore[reportPrivateUsage] + + +@pytest.mark.asyncio +async def test_static_stream_workbench_call_tool_stream_honors_override() -> None: + """call_tool_stream must resolve override names like call_tool does.""" + + def double(x: Annotated[int, "The number to double."]) -> int: + return x * 2 + + tool = FunctionTool( + double, + name="double", + description="A test tool that doubles a number.", + global_imports=[ImportFromModule(module="typing_extensions", imports=["Annotated"])], + ) + overrides: Dict[str, ToolOverride] = { + "double": ToolOverride(name="multiply_by_two", description="Multiplies a number by 2"), + } + + async with StaticStreamWorkbench(tools=[tool], tool_overrides=overrides) as workbench: + result: ToolResult | None = None + async for item in workbench.call_tool_stream("multiply_by_two", {"x": 5}): + if isinstance(item, ToolResult): + result = item + + assert result is not None + assert result.is_error is False + assert result.name == "multiply_by_two"