Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "taskgen-ai"
version = "3.4.2"
version = "3.4.3"
authors = [
{ name="John Tan Chong Min", email="tanchongmin@gmail.com" },
]
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name="taskgen",
version="3.4.2",
version="3.4.3",
packages=find_packages(),
install_requires=[
"openai>=1.59.6",
Expand Down
278 changes: 160 additions & 118 deletions taskgen/agent.py

Large diffs are not rendered by default.

42 changes: 26 additions & 16 deletions taskgen/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import ast
from typing import Tuple

# TODO we never use sync flow anywhere, better to cleanup sync flow altogether

### Helper Functions ###

Expand All @@ -14,11 +15,11 @@ def convert_to_list(field: str, **kwargs) -> list:
res = chat(system_msg, user_msg, **kwargs)

# Extract out list items
field = re.findall(r'\(%item\)\s*(.*?)\n*(?=\(%item\)|$)', res, flags=re.DOTALL)
return field
items = re.findall(r'\(%item\)\s*(.*?)\n*(?=\(%item\)|$)', res, flags=re.DOTALL)
return items


def convert_to_dict(field: str, keys: dict, delimiter: str) -> dict:
def convert_to_dict(field: str, keys:list, delimiter: str) -> dict:
'''Converts the string field into a dictionary with keys by splitting on '{delimiter}{key}{delimiter}' '''
output_d = {}
for key in keys:
Expand Down Expand Up @@ -201,7 +202,7 @@ def type_check_and_convert(field, key, data_type, **kwargs):



def check_datatype(field, key: dict, data_type: str, **kwargs):
def check_datatype(field, key, data_type: str, **kwargs):
''' Ensures that output field of the key of JSON dictionary is of data_type
Currently supports int, float, str, code, enum, lists, nested lists, dict, dict with keys
Takes in **kwargs for the LLM model
Expand Down Expand Up @@ -265,7 +266,7 @@ def check_datatype(field, key: dict, data_type: str, **kwargs):



def check_key(field: str, output_format, new_output_format, delimiter: str, delimiter_num: int, **kwargs):
def check_key(field, output_format, new_output_format, delimiter: str, delimiter_num: int, **kwargs):
''' Check whether each key in dict, or elements in list of new_output_format is present in field, and whether they meet the right data type requirements, then convert field to the right data type
If needed, calls LLM model with parameters **kwargs to correct the output format for improperly formatted list
output_format is user-given output format at each level, new_output_format is with delimiters in keys, and angle brackets surrounding values
Expand All @@ -278,7 +279,7 @@ def check_key(field: str, output_format, new_output_format, delimiter: str, deli
# this is the processed output dictionary for that particular layer in the output structure
output_d = {}
# check key appears for each element in the output
output_d = convert_to_dict(field, output_format.keys(), cur_delimiter)
output_d = convert_to_dict(field, list(output_format.keys()), cur_delimiter)

# after creating dictionary, step into next layer
for key, value in output_d.items():
Expand Down Expand Up @@ -368,22 +369,23 @@ def wrap_with_angle_brackets(d: dict, delimiter: str, delimiter_num: int) -> dic
else:
return d

def chat(system_prompt: str, user_prompt: str, model: str = 'gpt-4o-mini', temperature: float = 0, verbose: bool = False, host: str = 'openai', llm = None, **kwargs):
def chat(system_prompt: str, user_prompt: str|list[str], model: str = 'gpt-4o-mini', temperature: float = 0, verbose: bool = False, host: str = 'openai', llm = None, **kwargs):
r"""Performs a chat with the host's LLM model with system prompt, user prompt, model, verbose and kwargs
Returns the output string res
- system_prompt: String. Write in whatever you want the LLM to become. e.g. "You are a \<purpose in life\>"
- user_prompt: String. The user input. Later, when we use it as a function, this is the function input
- user_prompt: String or list of strings. The user input. Later, when we use it as a function, this is the function input
- model: String. The LLM model to use for json generation
- verbose: Boolean (default: False). Whether or not to print out the system prompt, user prompt, GPT response
- host: String. The provider of the LLM
- llm: User-made llm function.
- Inputs:
- system_prompt: String. Write in whatever you want the LLM to become. e.g. "You are a \<purpose in life\>"
- user_prompt: String. The user input. Later, when we use it as a function, this is the function input
- user_prompt: String or list of strings. The user input. Later, when we use it as a function, this is the function input
- Output:
- res: String. The response of the LLM call
- **kwargs: Dict. Additional arguments for LLM chat
"""
res =""
if llm is not None:
''' If you specified your own LLM, then we just feed in the system and user prompt
LLM function should take in system prompt (str) and user prompt (str), and output a response (str) '''
Expand All @@ -401,6 +403,8 @@ def chat(system_prompt: str, user_prompt: str, model: str = 'gpt-4o-mini', tempe

from openai import OpenAI
client = OpenAI()
if isinstance(user_prompt, list):
user_prompt = "\n".join(user_prompt)
response = client.chat.completions.create(
model=model,
temperature = temperature,
Expand All @@ -421,14 +425,14 @@ def chat(system_prompt: str, user_prompt: str, model: str = 'gpt-4o-mini', tempe


### Main Functions ###
def strict_json(system_prompt: str, user_prompt: str, output_format: dict, return_as_json = False, custom_checks: dict = None, check_data = None, delimiter: str = '###', num_tries: int = 3, openai_json_mode: bool = False, **kwargs):
def strict_json(system_prompt: str, user_prompt: str|list[str], output_format: dict, return_as_json = False, custom_checks: dict|None = None, check_data = None, delimiter: str = '###', num_tries: int = 3, openai_json_mode: bool = False, **kwargs):
r""" Ensures that OpenAI will always adhere to the desired output JSON format defined in output_format.
Uses rule-based iterative feedback to ask GPT to self-correct.
Keeps trying up to num_tries it it does not. Returns empty JSON if unable to after num_tries iterations.

Inputs (compulsory):
- system_prompt: String. Write in whatever you want GPT to become. e.g. "You are a \<purpose in life\>"
- user_prompt: String. The user input. Later, when we use it as a function, this is the function input
- user_prompt: String or list of strings. The user input. Later, when we use it as a function, this is the function input
- output_format: Dict. JSON format with the key as the output key, and the value as the output description

Inputs (optional):
Expand Down Expand Up @@ -459,8 +463,11 @@ def strict_json(system_prompt: str, user_prompt: str, output_format: dict, retur

output_format_prompt = "\nOutput in the following json string format: " + str(output_format) + "\nBe concise."

my_system_prompt = str(system_prompt) + output_format_prompt
my_user_prompt = str(user_prompt)
my_system_prompt = str(system_prompt)
if isinstance(user_prompt, list):
my_user_prompt = user_prompt + [output_format_prompt]
else:
my_user_prompt = [user_prompt] + [output_format_prompt]

res = chat(my_system_prompt, my_user_prompt, response_format = {"type": "json_object"}, **kwargs)

Expand All @@ -487,8 +494,11 @@ def strict_json(system_prompt: str, user_prompt: str, output_format: dict, retur
Ensure the following output keys are present in the json: {' '.join(list(new_output_format.keys()))}'''

for i in range(num_tries):
my_system_prompt = str(system_prompt) + output_format_prompt + error_msg
my_user_prompt = str(user_prompt)
my_system_prompt = str(system_prompt)
if isinstance(user_prompt, list):
my_user_prompt = user_prompt + [output_format_prompt] + [error_msg]
else:
my_user_prompt = [user_prompt] + [output_format_prompt] + [error_msg]

# Use OpenAI to get a response
res = chat(my_system_prompt, my_user_prompt, **kwargs)
Expand All @@ -514,7 +524,7 @@ def strict_json(system_prompt: str, user_prompt: str, output_format: dict, retur

# do checks for keys and output format, remove escape characters so code can be run
end_dict = check_key(res, output_format, new_output_format, delimiter, delimiter_num = 1, **kwargs)

assert isinstance(end_dict, dict)
# run user defined custom checks now
for key in end_dict:
if key in custom_checks:
Expand Down
45 changes: 31 additions & 14 deletions taskgen/base_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ async def convert_to_list_async(field: str, **kwargs) -> list:
res = await chat_async(system_msg, user_msg, **kwargs)

# Extract out list items
field = re.findall(r'\(%item\)\s*(.*?)\n*(?=\(%item\)|$)', res, flags=re.DOTALL)
return field
items = re.findall(r'\(%item\)\s*(.*?)\n*(?=\(%item\)|$)', res, flags=re.DOTALL)
return items



Expand All @@ -41,7 +41,7 @@ async def llm_check_async(field, llm_check_msg: str, **kwargs) -> Tuple[bool, st
return requirement_met, action_needed


async def check_datatype_async(field, key: dict, data_type: str, **kwargs):
async def check_datatype_async(field, key, data_type: str, **kwargs):
''' Ensures that output field of the key of JSON dictionary is of data_type
Currently supports int, float, str, code, enum, lists, nested lists, dict, dict with keys
Takes in **kwargs for the LLM model
Expand Down Expand Up @@ -104,7 +104,7 @@ async def check_datatype_async(field, key: dict, data_type: str, **kwargs):



async def check_key_async(field: str, output_format, new_output_format, delimiter: str, delimiter_num: int, **kwargs):
async def check_key_async(field, output_format, new_output_format, delimiter: str, delimiter_num: int, **kwargs):
''' Check whether each key in dict, or elements in list of new_output_format is present in field, and whether they meet the right data type requirements, then convert field to the right data type
If needed, calls LLM model with parameters **kwargs to correct the output format for improperly formatted list
output_format is user-given output format at each level, new_output_format is with delimiters in keys, and angle brackets surrounding values
Expand All @@ -117,7 +117,7 @@ async def check_key_async(field: str, output_format, new_output_format, delimite
# this is the processed output dictionary for that particular layer in the output structure
output_d = {}
# check key appears for each element in the output
output_d = convert_to_dict(field, output_format.keys(), cur_delimiter)
output_d = convert_to_dict(field, list(output_format.keys()), cur_delimiter)

# after creating dictionary, step into next layer
for key, value in output_d.items():
Expand Down Expand Up @@ -169,11 +169,11 @@ async def check_key_async(field: str, output_format, new_output_format, delimite



async def chat_async(system_prompt: str, user_prompt: str, model: str = 'gpt-4o-mini', temperature: float = 0, verbose: bool = False, host: str = 'openai', llm= None, **kwargs):
async def chat_async(system_prompt: str, user_prompt: str|list[str|dict], model: str = 'gpt-4o-mini', temperature: float = 0, verbose: bool = False, host: str = 'openai', llm= None, **kwargs):
r"""Performs a chat with the host's LLM model with system prompt, user prompt, model, verbose and kwargs
Returns the output string res
- system_prompt: String. Write in whatever you want the LLM to become. e.g. "You are a \<purpose in life\>"
- user_prompt: String. The user input. Later, when we use it as a function, this is the function input
- user_prompt: String or list of strings. The user input. Later, when we use it as a function, this is the function input
- model: String. The LLM model to use for json generation
- verbose: Boolean (default: False). Whether or not to print out the system prompt, user prompt, GPT response
- host: String. The provider of the LLM
Expand All @@ -185,6 +185,7 @@ async def chat_async(system_prompt: str, user_prompt: str, model: str = 'gpt-4o-
- res: String. The response of the LLM call
- **kwargs: Dict. Additional arguments for LLM chat
"""
res = ""
if llm is not None:
ensure_awaitable(llm, 'llm')
''' If you specified your own LLM, then we just feed in the system and user prompt
Expand All @@ -203,6 +204,8 @@ async def chat_async(system_prompt: str, user_prompt: str, model: str = 'gpt-4o-

from openai import AsyncOpenAI
client = AsyncOpenAI()
if isinstance(user_prompt, list):
user_prompt = "\n".join([prompt.get("text", "") if isinstance(prompt, dict) and prompt.get("text", "") else str(prompt) for prompt in user_prompt])
response = await client.chat.completions.create(
model=model,
temperature = temperature,
Expand All @@ -226,14 +229,14 @@ async def chat_async(system_prompt: str, user_prompt: str, model: str = 'gpt-4o-
### Main Functions ###


async def strict_json_async(system_prompt: str, user_prompt: str, output_format: dict, return_as_json = False, custom_checks: dict = None, check_data = None, delimiter: str = '###', num_tries: int = 3, openai_json_mode: bool = False, **kwargs):
async def strict_json_async(system_prompt: str, user_prompt: str|list[str|dict], output_format: dict, return_as_json = False, custom_checks: dict|None = None, check_data = None, delimiter: str = '###', num_tries: int = 3, openai_json_mode: bool = False, **kwargs):
r""" Ensures that OpenAI will always adhere to the desired output JSON format defined in output_format.
Uses rule-based iterative feedback to ask GPT to self-correct.
Keeps trying up to num_tries it it does not. Returns empty JSON if unable to after num_tries iterations.

Inputs (compulsory):
- system_prompt: String. Write in whatever you want GPT to become. e.g. "You are a \<purpose in life\>"
- user_prompt: String. The user input. Later, when we use it as a function, this is the function input
- user_prompt: String or list of strings or dictionaries. The user input. Later, when we use it as a function, this is the function input
- output_format: Dict. JSON format with the key as the output key, and the value as the output description

Inputs (optional):
Expand Down Expand Up @@ -264,8 +267,14 @@ async def strict_json_async(system_prompt: str, user_prompt: str, output_format:

output_format_prompt = "\nOutput in the following json string format: " + str(output_format) + "\nBe concise."

my_system_prompt = str(system_prompt) + output_format_prompt
my_user_prompt = str(user_prompt)
my_system_prompt = str(system_prompt)
my_user_prompt: list[str | dict] = []
if isinstance(user_prompt, list):
my_user_prompt.extend(user_prompt)
my_user_prompt.append(output_format_prompt)
else:
my_user_prompt.append(user_prompt)
my_user_prompt.append(output_format_prompt)

res = await chat_async(my_system_prompt, my_user_prompt, response_format = {"type": "json_object"}, **kwargs)

Expand All @@ -292,9 +301,16 @@ async def strict_json_async(system_prompt: str, user_prompt: str, output_format:
Ensure the following output keys are present in the json: {' '.join(list(new_output_format.keys()))}'''

for i in range(num_tries):
my_system_prompt = str(system_prompt) + output_format_prompt + error_msg
my_user_prompt = str(user_prompt)

my_system_prompt = str(system_prompt)
my_user_prompt = []
if isinstance(user_prompt, list):
my_user_prompt.extend(user_prompt)
my_user_prompt.append(output_format_prompt)
my_user_prompt.append(error_msg)
else:
my_user_prompt.append(user_prompt)
my_user_prompt.append(output_format_prompt)
my_user_prompt.append(error_msg)
# Use OpenAI to get a response
res = await chat_async(my_system_prompt, my_user_prompt, **kwargs)

Expand All @@ -320,6 +336,7 @@ async def strict_json_async(system_prompt: str, user_prompt: str, output_format:
# do checks for keys and output format, remove escape characters so code can be run
end_dict = await check_key_async(res, output_format, new_output_format, delimiter, delimiter_num = 1, **kwargs)

assert isinstance(end_dict, dict)
# run user defined custom checks now
for key in end_dict:
if key in custom_checks:
Expand Down
Loading