From 3ae90d7950a8d825307f9b4f15c70db864da276f Mon Sep 17 00:00:00 2001 From: Anand Choubey Date: Fri, 22 May 2026 17:44:22 +0530 Subject: [PATCH 1/2] feat(agent): add custom_validator hook with retry to AsyncAgent.query MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds an optional `custom_validator` callable (and `max_validation_retries`) parameter to `AsyncAgent.query`. When supplied, the validator is invoked with each LLM result; if it returns a non-empty feedback string, the prompt is re-issued with the feedback appended so the LLM can self-correct. Bounded by `max_validation_retries` (default 2 extra calls). This is useful for callers that need a structural / semantic check that strict_json's schema enforcement can't express — for example, validating that every UUID the LLM emits belongs to a known input set. Defaults preserve existing behavior (no validator, single call). Bumps version 3.4.3 -> 3.4.4. --- pyproject.toml | 2 +- setup.py | 2 +- taskgen/agent.py | 55 +++++++++++++++++++++++++++++++++++++++--------- 3 files changed, 47 insertions(+), 12 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3dbe238..8ea9be5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "taskgen-ai" -version = "3.4.3" +version = "3.4.4" authors = [ { name="John Tan Chong Min", email="tanchongmin@gmail.com" }, ] diff --git a/setup.py b/setup.py index b8283bc..f94ce75 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name="taskgen", - version="3.4.3", + version="3.4.4", packages=find_packages(), install_requires=[ "openai>=1.59.6", diff --git a/taskgen/agent.py b/taskgen/agent.py index e892563..912cf53 100644 --- a/taskgen/agent.py +++ b/taskgen/agent.py @@ -9,7 +9,7 @@ import re import sys import base64 -from typing import Any, cast +from typing import Any, Awaitable, Callable, Optional, Union, cast from termcolor import colored import requests @@ -1235,11 +1235,22 @@ async def query( output_format: dict, provide_function_list: bool = False, task: str = "", + custom_validator: Optional[Callable[[Any], Union[Optional[str], Awaitable[Optional[str]]]]] = None, + max_validation_retries: int = 2, ): """Queries the agent with a query and outputs in output_format. If task is provided, we will filter the functions according to the task If you want to provide the agent with the context of functions available to it, set provide_function_list to True (default: False) - If task is given, then we will use it to do RAG over functions""" + If task is given, then we will use it to do RAG over functions + + Optional validation + retry: + custom_validator: callable invoked with the LLM result. Returns None + if the result is acceptable, or a feedback string describing + what was wrong; the feedback is appended to the next prompt so + the LLM can self-correct. May be sync or async. + max_validation_retries: max number of additional LLM calls after + the first when custom_validator rejects the result (default 2). + """ # if we have a task to focus on, we can filter the functions (other than use_llm and end_task) by that task filtered_fn_list = [] @@ -1252,14 +1263,38 @@ async def query( global_context_output = await self.get_global_context(self) if self.get_global_context is not None else "" input_user_query_object: list[str|dict] = self.prepare_input_user_prompt_for_query(query, provide_function_list, task, filtered_fn_list, global_context_output) - res = await strict_json_async( - system_prompt="", - user_prompt=input_user_query_object, - output_format=output_format, - verbose=self.debug, - llm=self.llm, - **self.kwargs, - ) + attempts = (max_validation_retries + 1) if custom_validator is not None else 1 + res: Any = None + prompt_for_attempt: list[str|dict] = input_user_query_object + for attempt in range(attempts): + res = await strict_json_async( + system_prompt="", + user_prompt=prompt_for_attempt, + output_format=output_format, + verbose=self.debug, + llm=self.llm, + **self.kwargs, + ) + + if custom_validator is None: + break + + feedback = custom_validator(res) + if inspect.isawaitable(feedback): + feedback = await feedback + if not feedback: + break + + if attempt == attempts - 1: + # Final attempt also failed validation; return what we have. + # Callers are responsible for handling an invalid final result. + break + + prompt_for_attempt = list(input_user_query_object) + [ + "\n\nThe previous response was rejected by the validator. " + f"Reason: {feedback}\n" + "Produce a new response that fixes the issue. Do not repeat the same mistake.", + ] return res From 4ad3a056e401b0598d48d4fa126405e51168b267 Mon Sep 17 00:00:00 2001 From: Anand Choubey Date: Mon, 8 Jun 2026 18:54:04 +0530 Subject: [PATCH 2/2] fix(agent): keep validator retries within strict json budget --- pyproject.toml | 2 +- setup.py | 2 +- taskgen/agent.py | 49 +++++++++++-------------------------------- taskgen/base_async.py | 28 ++++++++++++++++++++++--- 4 files changed, 39 insertions(+), 42 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8ea9be5..3dbe238 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "taskgen-ai" -version = "3.4.4" +version = "3.4.3" authors = [ { name="John Tan Chong Min", email="tanchongmin@gmail.com" }, ] diff --git a/setup.py b/setup.py index f94ce75..b8283bc 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name="taskgen", - version="3.4.4", + version="3.4.3", packages=find_packages(), install_requires=[ "openai>=1.59.6", diff --git a/taskgen/agent.py b/taskgen/agent.py index 912cf53..3191936 100644 --- a/taskgen/agent.py +++ b/taskgen/agent.py @@ -1236,7 +1236,6 @@ async def query( provide_function_list: bool = False, task: str = "", custom_validator: Optional[Callable[[Any], Union[Optional[str], Awaitable[Optional[str]]]]] = None, - max_validation_retries: int = 2, ): """Queries the agent with a query and outputs in output_format. If task is provided, we will filter the functions according to the task @@ -1247,9 +1246,8 @@ async def query( custom_validator: callable invoked with the LLM result. Returns None if the result is acceptable, or a feedback string describing what was wrong; the feedback is appended to the next prompt so - the LLM can self-correct. May be sync or async. - max_validation_retries: max number of additional LLM calls after - the first when custom_validator rejects the result (default 2). + the LLM can self-correct within strict_json_async's retry budget. + May be sync or async. """ # if we have a task to focus on, we can filter the functions (other than use_llm and end_task) by that task @@ -1263,38 +1261,15 @@ async def query( global_context_output = await self.get_global_context(self) if self.get_global_context is not None else "" input_user_query_object: list[str|dict] = self.prepare_input_user_prompt_for_query(query, provide_function_list, task, filtered_fn_list, global_context_output) - attempts = (max_validation_retries + 1) if custom_validator is not None else 1 - res: Any = None - prompt_for_attempt: list[str|dict] = input_user_query_object - for attempt in range(attempts): - res = await strict_json_async( - system_prompt="", - user_prompt=prompt_for_attempt, - output_format=output_format, - verbose=self.debug, - llm=self.llm, - **self.kwargs, - ) - - if custom_validator is None: - break - - feedback = custom_validator(res) - if inspect.isawaitable(feedback): - feedback = await feedback - if not feedback: - break - - if attempt == attempts - 1: - # Final attempt also failed validation; return what we have. - # Callers are responsible for handling an invalid final result. - break - - prompt_for_attempt = list(input_user_query_object) + [ - "\n\nThe previous response was rejected by the validator. " - f"Reason: {feedback}\n" - "Produce a new response that fixes the issue. Do not repeat the same mistake.", - ] + res = await strict_json_async( + system_prompt="", + user_prompt=input_user_query_object, + output_format=output_format, + verbose=self.debug, + llm=self.llm, + custom_validator=custom_validator, + **self.kwargs, + ) return res @@ -1805,4 +1780,4 @@ async def wrapper(*args, **kwargs): result = await method(*args, **all_kwargs) return result - return wrapper \ No newline at end of file + return wrapper diff --git a/taskgen/base_async.py b/taskgen/base_async.py index 3a056da..a7e3c1a 100644 --- a/taskgen/base_async.py +++ b/taskgen/base_async.py @@ -1,8 +1,9 @@ import asyncio +import inspect import json import re import ast -from typing import Tuple +from typing import Any, Awaitable, Callable, Optional, Tuple, Union from taskgen.base import convert_to_dict, parse_response_llm_check, type_check_and_convert, wrap_with_angle_brackets from taskgen.utils import ensure_awaitable @@ -229,7 +230,19 @@ async def chat_async(system_prompt: str, user_prompt: str|list[str|dict], model: ### Main Functions ### -async def strict_json_async(system_prompt: str, user_prompt: str|list[str|dict], output_format: dict, return_as_json = False, custom_checks: dict|None = None, check_data = None, delimiter: str = '###', num_tries: int = 3, openai_json_mode: bool = False, **kwargs): +async def strict_json_async( + system_prompt: str, + user_prompt: str|list[str|dict], + output_format: dict, + return_as_json = False, + custom_checks: dict|None = None, + check_data = None, + delimiter: str = '###', + num_tries: int = 3, + openai_json_mode: bool = False, + custom_validator: Optional[Callable[[Any], Union[Optional[str], Awaitable[Optional[str]]]]] = None, + **kwargs +): r""" Ensures that OpenAI will always adhere to the desired output JSON format defined in output_format. Uses rule-based iterative feedback to ask GPT to self-correct. Keeps trying up to num_tries it it does not. Returns empty JSON if unable to after num_tries iterations. @@ -246,6 +259,8 @@ async def strict_json_async(system_prompt: str, user_prompt: str|list[str|dict], - delimiter: String (Default: '###'). This is the delimiter to surround the keys. With delimiter ###, key becomes ###key### - num_tries: Integer (default: 3). The number of tries to iteratively prompt GPT to generate correct json format - openai_json_mode: Boolean (default: False). Whether or not to use OpenAI JSON Mode + - custom_validator: Optional callable invoked with the parsed output. Returns None if valid, + or a feedback string describing what to fix. May be sync or async. - **kwargs: Dict. Additional arguments for LLM chat Output: @@ -348,6 +363,13 @@ async def strict_json_async(system_prompt: str, user_prompt: str|list[str|dict], raise Exception(f'Output field of "{key}" does not meet requirement "{requirement}". Action needed: "{action_needed}"') else: print('Requirement met\n\n') + if custom_validator is not None: + feedback = custom_validator(end_dict) + if inspect.isawaitable(feedback): + feedback = await feedback + if feedback: + print(f'Custom validator rejected output. Action needed: "{feedback}"\n\n') + raise Exception(f'Output does not meet custom validator requirement. Action needed: "{feedback}"') if return_as_json: return json.dumps(end_dict, ensure_ascii=False) else: @@ -363,4 +385,4 @@ async def strict_json_async(system_prompt: str, user_prompt: str|list[str|dict], ### Legacy Support ### # alternative names for strict_json strict_text_async = strict_json_async -strict_output_async = strict_json_async \ No newline at end of file +strict_output_async = strict_json_async