Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions taskgen/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import re
import sys
import base64
from typing import Any, cast
from typing import Any, Awaitable, Callable, Optional, Union, cast

from termcolor import colored
import requests
Expand Down Expand Up @@ -1235,11 +1235,20 @@ async def query(
output_format: dict,
provide_function_list: bool = False,
task: str = "",
custom_validator: Optional[Callable[[Any], Union[Optional[str], Awaitable[Optional[str]]]]] = None,
):
"""Queries the agent with a query and outputs in output_format.
If task is provided, we will filter the functions according to the task
If you want to provide the agent with the context of functions available to it, set provide_function_list to True (default: False)
If task is given, then we will use it to do RAG over functions"""
If task is given, then we will use it to do RAG over functions

Optional validation + retry:
custom_validator: callable invoked with the LLM result. Returns None
if the result is acceptable, or a feedback string describing
what was wrong; the feedback is appended to the next prompt so
the LLM can self-correct within strict_json_async's retry budget.
May be sync or async.
"""

# if we have a task to focus on, we can filter the functions (other than use_llm and end_task) by that task
filtered_fn_list = []
Expand All @@ -1258,6 +1267,7 @@ async def query(
output_format=output_format,
verbose=self.debug,
llm=self.llm,
custom_validator=custom_validator,
**self.kwargs,
)

Expand Down Expand Up @@ -1770,4 +1780,4 @@ async def wrapper(*args, **kwargs):
result = await method(*args, **all_kwargs)
return result

return wrapper
return wrapper
28 changes: 25 additions & 3 deletions taskgen/base_async.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import asyncio
import inspect
import json
import re
import ast
from typing import Tuple
from typing import Any, Awaitable, Callable, Optional, Tuple, Union
from taskgen.base import convert_to_dict, parse_response_llm_check, type_check_and_convert, wrap_with_angle_brackets

from taskgen.utils import ensure_awaitable
Expand Down Expand Up @@ -229,7 +230,19 @@ async def chat_async(system_prompt: str, user_prompt: str|list[str|dict], model:
### Main Functions ###


async def strict_json_async(system_prompt: str, user_prompt: str|list[str|dict], output_format: dict, return_as_json = False, custom_checks: dict|None = None, check_data = None, delimiter: str = '###', num_tries: int = 3, openai_json_mode: bool = False, **kwargs):
async def strict_json_async(
system_prompt: str,
user_prompt: str|list[str|dict],
output_format: dict,
return_as_json = False,
custom_checks: dict|None = None,
check_data = None,
delimiter: str = '###',
num_tries: int = 3,
openai_json_mode: bool = False,
custom_validator: Optional[Callable[[Any], Union[Optional[str], Awaitable[Optional[str]]]]] = None,
**kwargs
):
r""" Ensures that OpenAI will always adhere to the desired output JSON format defined in output_format.
Uses rule-based iterative feedback to ask GPT to self-correct.
Keeps trying up to num_tries it it does not. Returns empty JSON if unable to after num_tries iterations.
Expand All @@ -246,6 +259,8 @@ async def strict_json_async(system_prompt: str, user_prompt: str|list[str|dict],
- delimiter: String (Default: '###'). This is the delimiter to surround the keys. With delimiter ###, key becomes ###key###
- num_tries: Integer (default: 3). The number of tries to iteratively prompt GPT to generate correct json format
- openai_json_mode: Boolean (default: False). Whether or not to use OpenAI JSON Mode
- custom_validator: Optional callable invoked with the parsed output. Returns None if valid,
or a feedback string describing what to fix. May be sync or async.
- **kwargs: Dict. Additional arguments for LLM chat

Output:
Expand Down Expand Up @@ -348,6 +363,13 @@ async def strict_json_async(system_prompt: str, user_prompt: str|list[str|dict],
raise Exception(f'Output field of "{key}" does not meet requirement "{requirement}". Action needed: "{action_needed}"')
else:
print('Requirement met\n\n')
if custom_validator is not None:
feedback = custom_validator(end_dict)
if inspect.isawaitable(feedback):
feedback = await feedback
if feedback:
print(f'Custom validator rejected output. Action needed: "{feedback}"\n\n')
raise Exception(f'Output does not meet custom validator requirement. Action needed: "{feedback}"')
if return_as_json:
return json.dumps(end_dict, ensure_ascii=False)
else:
Expand All @@ -363,4 +385,4 @@ async def strict_json_async(system_prompt: str, user_prompt: str|list[str|dict],
### Legacy Support ###
# alternative names for strict_json
strict_text_async = strict_json_async
strict_output_async = strict_json_async
strict_output_async = strict_json_async
Loading