Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 47 additions & 4 deletions sdks/python/agenta/sdk/utils/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,17 @@
from copy import deepcopy
from dataclasses import dataclass
from enum import Enum
from typing import Annotated, ClassVar, List, Union, Optional, Dict, Literal, Any
from typing import (
Annotated,
ClassVar,
List,
Union,
Optional,
Dict,
Literal,
Any,
TypeAlias,
)

from pydantic import ConfigDict, BaseModel, HttpUrl, RootModel
from pydantic import Field, model_validator, AliasChoices
Expand Down Expand Up @@ -406,6 +416,33 @@ class ResponseFormatJSONSchema(BaseModel):
]


class ChatCompletionNamedToolChoiceFunction(BaseModel):
name: str


class ChatCompletionNamedToolChoice(BaseModel):
type: Literal["function"]
function: ChatCompletionNamedToolChoiceFunction


class ChatCompletionAllowedTools(BaseModel):
mode: Literal["auto", "required"]
tools: List[Dict[str, Any]]


class ChatCompletionAllowedToolChoice(BaseModel):
type: Literal["allowed_tools"]
allowed_tools: ChatCompletionAllowedTools


ToolChoice: TypeAlias = Union[
Literal["none", "auto", "required"],
ChatCompletionNamedToolChoice,
ChatCompletionAllowedToolChoice,
Dict[str, Any],
]


class ModelConfig(BaseModel):
"""Configuration for model parameters"""

Expand Down Expand Up @@ -468,7 +505,7 @@ class ModelConfig(BaseModel):
default=None,
description="A list of tools the model may call. Currently, only functions are supported as a tool",
)
tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = Field(
tool_choice: Optional[ToolChoice] = Field(
default=None, description="Controls which (if any) tool is called by the model"
)

Expand Down Expand Up @@ -596,7 +633,7 @@ class AgLLM(AgSchemaMixin):
json_schema_extra={"x-ag-type": "choice"},
)
chat_template_kwargs: Optional[Dict[str, Any]] = Field(default=None)
tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = Field(
tool_choice: Optional[ToolChoice] = Field(
default=None,
)
template_format: Literal["mustache", "curly", "fstring", "jinja2"] = Field(
Expand Down Expand Up @@ -1015,7 +1052,13 @@ def to_openai_kwargs(self, llm_config: Optional[ModelConfig] = None) -> dict:
kwargs["tools"] = llm_config.tools
# Only set tool_choice if tools are present
if llm_config.tool_choice is not None:
kwargs["tool_choice"] = llm_config.tool_choice
if isinstance(llm_config.tool_choice, BaseModel):
kwargs["tool_choice"] = llm_config.tool_choice.model_dump(
by_alias=True,
exclude_none=True,
)
else:
kwargs["tool_choice"] = llm_config.tool_choice

return kwargs

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from contextlib import nullcontext

import pytest
from pydantic import ValidationError

from agenta.sdk.engines.running.handlers import (
_coerce_fallback_policy,
Expand Down Expand Up @@ -182,6 +183,99 @@ def test_null_chat_template_kwargs_is_omitted_from_provider_kwargs():
assert "chat_template_kwargs" not in prompt.to_openai_kwargs()


def test_tool_choice_required_is_passed_through_with_tools():
tool = {
"type": "function",
"function": {
"name": "get_weather",
"description": "Get weather for a city.",
"parameters": {
"type": "object",
"properties": {"city": {"type": "string"}},
"required": ["city"],
},
},
}
prompt = PromptTemplate(
llm_config=ModelConfig(
model="gpt-4o-mini",
tools=[tool],
tool_choice="required",
)
)

kwargs = prompt.to_openai_kwargs()

assert kwargs["tools"] == [tool]
assert kwargs["tool_choice"] == "required"


def test_tool_choice_named_function_is_passed_through_with_tools():
tool = {
"type": "function",
"function": {"name": "get_weather", "parameters": {"type": "object"}},
}
prompt = PromptTemplate(
llm_config=ModelConfig(
model="gpt-4o-mini",
tools=[tool],
tool_choice={
"type": "function",
"function": {"name": "get_weather"},
},
)
)

assert prompt.to_openai_kwargs()["tool_choice"] == {
"type": "function",
"function": {"name": "get_weather"},
}


def test_tool_choice_allowed_tools_is_passed_through_with_tools():
allowed_tool = {
"type": "function",
"function": {"name": "get_weather"},
}
prompt = PromptTemplate(
llm_config=ModelConfig(
model="gpt-4o-mini",
tools=[allowed_tool],
tool_choice={
"type": "allowed_tools",
"allowed_tools": {
"mode": "required",
"tools": [allowed_tool],
},
},
)
)

assert prompt.to_openai_kwargs()["tool_choice"] == {
"type": "allowed_tools",
"allowed_tools": {
"mode": "required",
"tools": [allowed_tool],
},
}


def test_tool_choice_is_omitted_when_tools_are_absent():
prompt = PromptTemplate(
llm_config=ModelConfig(
model="gpt-4o-mini",
tool_choice="required",
)
)

assert "tool_choice" not in prompt.to_openai_kwargs()


def test_tool_choice_rejects_invalid_string_values():
with pytest.raises(ValidationError):
ModelConfig(model="gpt-4o-mini", tool_choice="always")


def test_fallback_config_uses_model_config_defaults():
prompt = PromptTemplate(fallback_configs=[{"temperature": 0.2}])

Expand Down
Loading