Created
August 2, 2023 06:55
-
-
Save lavishsaluja/4d6082dae43c5c5bc6dbea2d0f462156 to your computer and use it in GitHub Desktop.
code file to generate JSONs with Open AI GPT models
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Description: code file to get required JSONs for various applications using Open AI large language models | |
How-to-use: | |
1. replace the following variables for your use case | |
- sample_app_description | |
- sample_user_messages | |
- supported_animal_list | |
- json_examples | |
- json_key_value_descriptions | |
2. change the variables & description in json_schema_generation_prompt_template | |
3. replace the variables in values dict in main() function | |
4. pip install openai, langchain | |
5. python3 main.py | |
""" | |
# Sample data for application
# One-line description of the app the LLM is generating JSON for; it is not
# interpolated into the prompt below, it documents the intended use case.
sample_app_description = "an app which shares bed time stories along with images of animals if mentioned by the user if supported"

# Example end-user requests; main() sends sample_user_messages[0] to the model.
sample_user_messages = [
    "tell my daughter Lilly a story about frog & snail who were friends in school",
    "tell my son a story about animals in a jungle",
    "tell me a story about the king of the jungle",
    "what happened with rabbit and tortoise in their race? complete the story, my son is listening to you during bedtime"
]

# Whitelist of animals the app can illustrate; interpolated into the prompt so
# the model only reports animals from this list.
supported_animal_list = ["lion", "snail", "monkey", "tiger", "elephant", "giraffe", "zebra", "hippopotamus",
                         "kangaroo", "penguin", "bear", "wolf", "fox", "rhinoceros", "jaguar", "deer", "eagle", "hawk", "owl", "buffalo"]

# Example outputs: json_examples[0] is the canonical shape — its inferred JSON
# schema is what every model response is validated against (see lint_json).
json_examples = [
    {
        "is_this_request_asking_to_recite_a_story": True,
        "animals_mentioned": ["frog", "snail"]
    },
    {
        "is_this_request_asking_to_recite_a_story": False,
        "animals_mentioned": []
    },
    {
        "is_this_request_asking_to_recite_a_story": True,
        "animals_mentioned": []
    }
]

# Human-readable meaning of each key; interpolated verbatim into the prompt.
json_key_value_descriptions = {
    "is_this_request_asking_to_recite_a_story": "when the user's message is asking to share a story",
    "animals_mentioned": "this will be empty list when user has not mentioned any animals which fall under supported_animals_list otherwise this will be a list with animals mentioned by the user, this list should contain that animal's corresponding key in the supported_animal_list"
}

# Prompt template with four placeholders filled by LLMChain from the values
# dict built in main().
# NOTE(review): "fileds"/"descroption" are typos inside the runtime prompt
# string — left untouched here since the string is program behavior; consider
# fixing them deliberately.
json_schema_generation_prompt_template = """
given a user's message below, your task is to create a JSON response adhering to the JSON schema shared below.
user_message: {user_message}
supported_animal_list: {supported_animal_list}
JSON Schema: {json_schema_string}
Instructions
- your response should be formatted as a JSON instance that conforms to the provided JSON schema
- analyze the user_message carefully above to extract the necessary information needed to populate the JSON
- set the fileds in JSON based on the user_message & following mentioned descroption of keys
{json_key_value_descriptions}
return the required JSON schema from next line:
"""
# Importing necessary modules
import json
import logging
from langchain import LLMChain, PromptTemplate
from langchain.chat_models import ChatOpenAI

# Model call settings: temperature 0 for deterministic JSON output.
TEMPERATURE = 0.0
MODEL_NAME = "gpt-3.5-turbo"
# Maximum number of re-generation attempts on schema mismatch (KeyValueError).
MAX_RETRIES_COUNT = 2
# NOTE(review): hardcoded API key in source — prefer reading it from an
# environment variable (e.g. os.environ["OPENAI_API_KEY"]) before sharing.
OPENAI_API_KEY = "sk-***********************************"

# Log everything (DEBUG and up) to main.log, overwriting on each run.
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('main.log', mode='w')
    ]
)
logger = logging.getLogger()
def validate_json_examples(json_examples: list):
    """Check that every example dict has the same keys and value types.

    The first example is treated as the reference shape; every other example
    must have exactly the same key set and, for each key, a value of the same
    type. Returns True when all examples are consistent (or the list is
    empty), False otherwise (the inconsistency is logged).
    """
    if not json_examples:
        # Nothing to contradict the reference shape.
        return True
    expected_keys = set(json_examples[0].keys())
    # Bug fix: original called .get() on the list and iterated the dict
    # without .items(); derive the reference types from the first example.
    expected_types = {key: type(val) for key, val in json_examples[0].items()}
    for json_object in json_examples:
        if set(json_object.keys()) != expected_keys:
            logger.error("inconsistent json examples - check the keys")
            return False
        for key, val in json_object.items():
            if type(val) != expected_types[key]:
                logger.error("inconsistent json examples - check values")
                return False
    return True
def generate_json_schema(json_example):
    """Infer a minimal JSON Schema (dict form) from an example object.

    Each key of *json_example* becomes a property whose "type" is derived
    from the Python type of its value; nested dicts and lists of dicts
    recurse. Unknown types fall back to "object".
    """

    def _array_schema(items):
        # Branch order matters: an empty list satisfies every all() check,
        # so it takes the dict branch and yields object-typed items.
        if all(isinstance(i, dict) for i in items):
            return {
                "type": "array",
                "items": generate_json_schema(items[0]) if items else {"type": "object"},
            }
        if all(isinstance(i, str) for i in items):
            return {"type": "array", "items": {"type": "string"}}
        if all(isinstance(i, (int, float)) for i in items):
            return {"type": "array", "items": {"type": "number"}}
        # Mixed-type lists are reported as arrays of objects.
        return {"type": "array", "items": {"type": "object"}}

    def _value_schema(value):
        if value is None:
            return {"type": "null"}
        # bool must be tested before int: bool is a subclass of int.
        if isinstance(value, bool):
            return {"type": "boolean"}
        if isinstance(value, list):
            return _array_schema(value)
        if isinstance(value, str):
            return {"type": "string"}
        if isinstance(value, int):
            return {"type": "integer"}
        if isinstance(value, float):
            return {"type": "number"}
        if isinstance(value, dict):
            return generate_json_schema(value)
        return {"type": "object"}

    return {
        "type": "object",
        "properties": {key: _value_schema(val) for key, val in json_example.items()},
    }
def get_output_from_OpenAI(prompt_template: str, values: dict):
    """Render *prompt_template* with *values* and return the model's raw text.

    Builds a fresh ChatOpenAI model and LLMChain per call using the
    module-level TEMPERATURE / MODEL_NAME / OPENAI_API_KEY settings.
    """
    chat_model = ChatOpenAI(
        temperature=TEMPERATURE,
        model=MODEL_NAME,
        openai_api_key=OPENAI_API_KEY,
    )
    chain = LLMChain(
        llm=chat_model,
        prompt=PromptTemplate.from_template(prompt_template),
    )
    # apply() takes a batch of input dicts; we send a single one and unwrap
    # the generated text from the first (only) result.
    return chain.apply([values])[0]['text']
def lint_json(response, json_examples):
    """Return True when *response*'s inferred schema matches the reference.

    The reference schema is inferred from json_examples[0]; both schemas are
    logged at INFO level for debugging.
    """
    expected_schema = generate_json_schema(json_example=json_examples[0])
    actual_schema = generate_json_schema(json_example=response)
    logger.info("schema_sample_json_object: {}".format(expected_schema))
    logger.info("schema_response: {}".format(actual_schema))
    return expected_schema == actual_schema
def get_error_in_response(response):
    """Classify an LLM response against the expected JSON shape.

    Returns:
        None             — response is (or parses to) a dict matching the
                           reference schema of json_examples[0].
        "JsonTextError"  — response is not a dict and cannot be parsed into
                           one (malformed JSON text).
        "KeyValueError"  — response is valid JSON but its keys/value types do
                           not match the reference schema.
    """
    if not isinstance(response, dict):
        # Narrowed from a bare `except Exception`: json.loads raises
        # ValueError (JSONDecodeError) on bad text and TypeError on non-str.
        try:
            parsed = json.loads(response)
        except (TypeError, ValueError):
            logger.debug(
                "incorrect json-string returned by the LLM. output: {}".format(response))
            return "JsonTextError"
        if not isinstance(parsed, dict):
            # Valid JSON but not an object (e.g. a list or scalar); the
            # original code surfaced this as a text error too.
            logger.debug(
                "incorrect json-string returned by the LLM. output: {}".format(parsed))
            return "JsonTextError"
        response = parsed
    # Single validation path replaces the duplicated dict/str branches.
    if lint_json(response, json_examples=json_examples):
        logger.info("json response has been created successfully")
        return None
    logger.debug(
        "incorrect json-values returned by the LLM. output: {}".format(response))
    return "KeyValueError"
def extract_json_from_response(response):
    """Extract the substring between the first '{' and the last '}'.

    Dicts pass through unchanged. If no balanced brace pair exists, the
    original text is returned as-is so the caller's error detection sees the
    full output (the previous code sliced with find()'s -1 sentinel and
    silently produced an empty or truncated string).
    """
    if isinstance(response, dict):
        return response
    logger.debug("trying to extract json (if any) by finding first { & last }")
    first_brace_index = response.find("{")
    last_brace_index = response.rfind("}")
    # Guard against missing/inverted braces: find()/rfind() return -1.
    if first_brace_index == -1 or last_brace_index < first_brace_index:
        return response
    return response[first_brace_index:last_brace_index + 1]
def generate_required_json(prompt_template: str, values: dict, retry_count: int):
    """Ask the LLM for a JSON response and validate/repair/retry as needed.

    Flow: call the model; if the raw output fails validation, try to salvage
    a JSON substring; if the salvaged text is valid JSON but has the wrong
    keys/types (KeyValueError), re-query up to MAX_RETRIES_COUNT times.

    Returns the accepted response, or None when no valid JSON could be
    produced. (The original version fell through and returned the invalid
    response on a persistent JsonTextError; it now returns None, matching
    the exhausted-retries path.)
    """
    response = get_output_from_OpenAI(
        prompt_template=prompt_template, values=values)
    error = get_error_in_response(response=response)
    if not error:
        return response

    # Raw output rejected — try to salvage the JSON substring and recheck.
    # (The former second extraction was a no-op on an already-extracted
    # string, so it is dropped.)
    response = extract_json_from_response(response)
    error = get_error_in_response(response=response)
    if not error:
        return response

    if error == "KeyValueError":
        # Valid JSON, wrong shape: the model may do better on a retry.
        retry_count += 1
        if retry_count < MAX_RETRIES_COUNT:
            return generate_required_json(
                prompt_template=prompt_template, values=values, retry_count=retry_count)
        logger.error(
            "was not able to generate JSON output from the LLM for given user message. final response: {}".format(response))
        return None

    # JsonTextError even after extraction: nothing usable was produced.
    logger.error(
        "was not able to generate JSON output from the LLM for given user message. final response: {}".format(response))
    return None
def main():
    """Run the sample end-to-end flow for the first sample user message.

    Infers the JSON schema from the first example, fills the prompt-template
    variables, and returns the validated JSON response (or None on failure).
    """
    reference_schema = generate_json_schema(json_example=json_examples[0])
    values = {
        "user_message": sample_user_messages[0],
        "supported_animal_list": supported_animal_list,
        "json_schema_string": json.dumps(reference_schema, indent=2),
        "json_key_value_descriptions": json_key_value_descriptions,
    }
    return generate_required_json(
        prompt_template=json_schema_generation_prompt_template,
        values=values,
        retry_count=0,
    )


print(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment