diff --git a/taskgen/agent.py b/taskgen/agent.py index e892563..3191936 100644 --- a/taskgen/agent.py +++ b/taskgen/agent.py @@ -9,7 +9,7 @@ import re import sys import base64 -from typing import Any, cast +from typing import Any, Awaitable, Callable, Optional, Union, cast from termcolor import colored import requests @@ -1235,11 +1235,20 @@ async def query( output_format: dict, provide_function_list: bool = False, task: str = "", + custom_validator: Optional[Callable[[Any], Union[Optional[str], Awaitable[Optional[str]]]]] = None, ): """Queries the agent with a query and outputs in output_format. If task is provided, we will filter the functions according to the task If you want to provide the agent with the context of functions available to it, set provide_function_list to True (default: False) - If task is given, then we will use it to do RAG over functions""" + If task is given, then we will use it to do RAG over functions + + Optional validation + retry: + custom_validator: callable invoked with the LLM result. Returns None + if the result is acceptable, or a feedback string describing + what was wrong; the feedback is appended to the next prompt so + the LLM can self-correct within strict_json_async's retry budget. + May be sync or async. + """ # if we have a task to focus on, we can filter the functions (other than use_llm and end_task) by that task filtered_fn_list = [] @@ -1258,6 +1267,7 @@ async def query( output_format=output_format, verbose=self.debug, llm=self.llm, + custom_validator=custom_validator, **self.kwargs, ) @@ -1770,4 +1780,4 @@ async def wrapper(*args, **kwargs): result = await method(*args, **all_kwargs) return result - return wrapper \ No newline at end of file + return wrapper diff --git a/taskgen/base_async.py b/taskgen/base_async.py index 3a056da..a7e3c1a 100644 --- a/taskgen/base_async.py +++ b/taskgen/base_async.py @@ -1,8 +1,9 @@ import asyncio +import inspect import json import re import ast -from typing import Tuple +from typing import Any, Awaitable, Callable, Optional, Tuple, Union from taskgen.base import convert_to_dict, parse_response_llm_check, type_check_and_convert, wrap_with_angle_brackets from taskgen.utils import ensure_awaitable @@ -229,7 +230,19 @@ async def chat_async(system_prompt: str, user_prompt: str|list[str|dict], model: ### Main Functions ### -async def strict_json_async(system_prompt: str, user_prompt: str|list[str|dict], output_format: dict, return_as_json = False, custom_checks: dict|None = None, check_data = None, delimiter: str = '###', num_tries: int = 3, openai_json_mode: bool = False, **kwargs): +async def strict_json_async( + system_prompt: str, + user_prompt: str|list[str|dict], + output_format: dict, + return_as_json = False, + custom_checks: dict|None = None, + check_data = None, + delimiter: str = '###', + num_tries: int = 3, + openai_json_mode: bool = False, + custom_validator: Optional[Callable[[Any], Union[Optional[str], Awaitable[Optional[str]]]]] = None, + **kwargs +): r""" Ensures that OpenAI will always adhere to the desired output JSON format defined in output_format. Uses rule-based iterative feedback to ask GPT to self-correct. Keeps trying up to num_tries it it does not. Returns empty JSON if unable to after num_tries iterations. @@ -246,6 +259,8 @@ async def strict_json_async(system_prompt: str, user_prompt: str|list[str|dict], - delimiter: String (Default: '###'). This is the delimiter to surround the keys. With delimiter ###, key becomes ###key### - num_tries: Integer (default: 3). The number of tries to iteratively prompt GPT to generate correct json format - openai_json_mode: Boolean (default: False). Whether or not to use OpenAI JSON Mode + - custom_validator: Optional callable invoked with the parsed output. Returns None if valid, + or a feedback string describing what to fix. May be sync or async. - **kwargs: Dict. Additional arguments for LLM chat Output: @@ -348,6 +363,13 @@ async def strict_json_async(system_prompt: str, user_prompt: str|list[str|dict], raise Exception(f'Output field of "{key}" does not meet requirement "{requirement}". Action needed: "{action_needed}"') else: print('Requirement met\n\n') + if custom_validator is not None: + feedback = custom_validator(end_dict) + if inspect.isawaitable(feedback): + feedback = await feedback + if feedback: + print(f'Custom validator rejected output. Action needed: "{feedback}"\n\n') + raise Exception(f'Output does not meet custom validator requirement. Action needed: "{feedback}"') if return_as_json: return json.dumps(end_dict, ensure_ascii=False) else: @@ -363,4 +385,4 @@ async def strict_json_async(system_prompt: str, user_prompt: str|list[str|dict], ### Legacy Support ### # alternative names for strict_json strict_text_async = strict_json_async -strict_output_async = strict_json_async \ No newline at end of file +strict_output_async = strict_json_async