diff --git a/sdks/python/agenta/sdk/utils/types.py b/sdks/python/agenta/sdk/utils/types.py index 8e629b92fb..01bc5003ad 100644 --- a/sdks/python/agenta/sdk/utils/types.py +++ b/sdks/python/agenta/sdk/utils/types.py @@ -2,7 +2,17 @@ from copy import deepcopy from dataclasses import dataclass from enum import Enum -from typing import Annotated, ClassVar, List, Union, Optional, Dict, Literal, Any +from typing import ( + Annotated, + ClassVar, + List, + Union, + Optional, + Dict, + Literal, + Any, + TypeAlias, +) from pydantic import ConfigDict, BaseModel, HttpUrl, RootModel from pydantic import Field, model_validator, AliasChoices @@ -406,6 +416,33 @@ class ResponseFormatJSONSchema(BaseModel): ] +class ChatCompletionNamedToolChoiceFunction(BaseModel): + name: str + + +class ChatCompletionNamedToolChoice(BaseModel): + type: Literal["function"] + function: ChatCompletionNamedToolChoiceFunction + + +class ChatCompletionAllowedTools(BaseModel): + mode: Literal["auto", "required"] + tools: List[Dict[str, Any]] + + +class ChatCompletionAllowedToolChoice(BaseModel): + type: Literal["allowed_tools"] + allowed_tools: ChatCompletionAllowedTools + + +ToolChoice: TypeAlias = Union[ + Literal["none", "auto", "required"], + ChatCompletionNamedToolChoice, + ChatCompletionAllowedToolChoice, + Dict[str, Any], +] + + class ModelConfig(BaseModel): """Configuration for model parameters""" @@ -468,7 +505,7 @@ class ModelConfig(BaseModel): default=None, description="A list of tools the model may call. Currently, only functions are supported as a tool", ) - tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = Field( + tool_choice: Optional[ToolChoice] = Field( default=None, description="Controls which (if any) tool is called by the model" ) @@ -596,7 +633,7 @@ class AgLLM(AgSchemaMixin): json_schema_extra={"x-ag-type": "choice"}, ) chat_template_kwargs: Optional[Dict[str, Any]] = Field(default=None) - tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = Field( + tool_choice: Optional[ToolChoice] = Field( default=None, ) template_format: Literal["mustache", "curly", "fstring", "jinja2"] = Field( @@ -1015,7 +1052,13 @@ def to_openai_kwargs(self, llm_config: Optional[ModelConfig] = None) -> dict: kwargs["tools"] = llm_config.tools # Only set tool_choice if tools are present if llm_config.tool_choice is not None: - kwargs["tool_choice"] = llm_config.tool_choice + if isinstance(llm_config.tool_choice, BaseModel): + kwargs["tool_choice"] = llm_config.tool_choice.model_dump( + by_alias=True, + exclude_none=True, + ) + else: + kwargs["tool_choice"] = llm_config.tool_choice return kwargs diff --git a/sdks/python/oss/tests/pytest/unit/test_prompt_template_extensions.py b/sdks/python/oss/tests/pytest/unit/test_prompt_template_extensions.py index d29a5246ef..eb620440b0 100644 --- a/sdks/python/oss/tests/pytest/unit/test_prompt_template_extensions.py +++ b/sdks/python/oss/tests/pytest/unit/test_prompt_template_extensions.py @@ -1,6 +1,7 @@ from contextlib import nullcontext import pytest +from pydantic import ValidationError from agenta.sdk.engines.running.handlers import ( _coerce_fallback_policy, @@ -182,6 +183,99 @@ def test_null_chat_template_kwargs_is_omitted_from_provider_kwargs(): assert "chat_template_kwargs" not in prompt.to_openai_kwargs() +def test_tool_choice_required_is_passed_through_with_tools(): + tool = { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather for a city.", + "parameters": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, + }, + } + prompt = PromptTemplate( + llm_config=ModelConfig( + model="gpt-4o-mini", + tools=[tool], + tool_choice="required", + ) + ) + + kwargs = prompt.to_openai_kwargs() + + assert kwargs["tools"] == [tool] + assert kwargs["tool_choice"] == "required" + + +def test_tool_choice_named_function_is_passed_through_with_tools(): + tool = { + "type": "function", + "function": {"name": "get_weather", "parameters": {"type": "object"}}, + } + prompt = PromptTemplate( + llm_config=ModelConfig( + model="gpt-4o-mini", + tools=[tool], + tool_choice={ + "type": "function", + "function": {"name": "get_weather"}, + }, + ) + ) + + assert prompt.to_openai_kwargs()["tool_choice"] == { + "type": "function", + "function": {"name": "get_weather"}, + } + + +def test_tool_choice_allowed_tools_is_passed_through_with_tools(): + allowed_tool = { + "type": "function", + "function": {"name": "get_weather"}, + } + prompt = PromptTemplate( + llm_config=ModelConfig( + model="gpt-4o-mini", + tools=[allowed_tool], + tool_choice={ + "type": "allowed_tools", + "allowed_tools": { + "mode": "required", + "tools": [allowed_tool], + }, + }, + ) + ) + + assert prompt.to_openai_kwargs()["tool_choice"] == { + "type": "allowed_tools", + "allowed_tools": { + "mode": "required", + "tools": [allowed_tool], + }, + } + + +def test_tool_choice_is_omitted_when_tools_are_absent(): + prompt = PromptTemplate( + llm_config=ModelConfig( + model="gpt-4o-mini", + tool_choice="required", + ) + ) + + assert "tool_choice" not in prompt.to_openai_kwargs() + + +def test_tool_choice_rejects_invalid_string_values(): + with pytest.raises(ValidationError): + ModelConfig(model="gpt-4o-mini", tool_choice="always") + + def test_fallback_config_uses_model_config_defaults(): prompt = PromptTemplate(fallback_configs=[{"temperature": 0.2}])