Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,10 @@ async def call_tool_stream(
cancellation_token: CancellationToken | None = None,
call_id: str | None = None,
) -> AsyncGenerator[Any | ToolResult, None]:
tool = next((tool for tool in self._tools if tool.name == name), None)
# Check if the name is an override name and map it back to the original
original_name = self._override_name_to_original.get(name, name)

tool = next((tool for tool in self._tools if tool.name == original_name), None)
if tool is None:
yield ToolResult(
name=name,
Expand Down Expand Up @@ -210,7 +213,7 @@ async def call_tool_stream(
yield previous_result
# Then yield the error result
result_str = self._format_errors(e)
yield ToolResult(name=tool.name, result=[TextResultContent(content=result_str)], is_error=True)
yield ToolResult(name=name, result=[TextResultContent(content=result_str)], is_error=True)
return
else:
# If the tool is not a stream tool, we run it normally and yield the result
Expand All @@ -222,4 +225,4 @@ async def call_tool_stream(
except Exception as e:
result_str = self._format_errors(e)
is_error = True
yield ToolResult(name=tool.name, result=[TextResultContent(content=result_str)], is_error=is_error)
yield ToolResult(name=name, result=[TextResultContent(content=result_str)], is_error=is_error)
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,14 @@

import pytest
from autogen_core.code_executor import ImportFromModule
from autogen_core.tools import FunctionTool, StaticWorkbench, ToolOverride, Workbench
from autogen_core.tools import (
FunctionTool,
StaticStreamWorkbench,
StaticWorkbench,
ToolOverride,
ToolResult,
Workbench,
)


@pytest.mark.asyncio
Expand Down Expand Up @@ -283,3 +290,31 @@ def test_tool_func_3(x: Annotated[int, "Number"]) -> int:
}
workbench_self = StaticWorkbench(tools=[tool1, tool2, tool3], tool_overrides=overrides_self)
assert "tool1" not in workbench_self._override_name_to_original # type: ignore[reportPrivateUsage]


@pytest.mark.asyncio
async def test_static_stream_workbench_call_tool_stream_honors_override() -> None:
"""call_tool_stream must resolve override names like call_tool does."""

def double(x: Annotated[int, "The number to double."]) -> int:
return x * 2

tool = FunctionTool(
double,
name="double",
description="A test tool that doubles a number.",
global_imports=[ImportFromModule(module="typing_extensions", imports=["Annotated"])],
)
overrides: Dict[str, ToolOverride] = {
"double": ToolOverride(name="multiply_by_two", description="Multiplies a number by 2"),
}

async with StaticStreamWorkbench(tools=[tool], tool_overrides=overrides) as workbench:
result: ToolResult | None = None
async for item in workbench.call_tool_stream("multiply_by_two", {"x": 5}):
if isinstance(item, ToolResult):
result = item

assert result is not None
assert result.is_error is False
assert result.name == "multiply_by_two"