Created
August 9, 2024 23:08
-
-
Save grahama1970/6b1d6d6fb82c95c91de79ca0c5e81483 to your computer and use it in GitHub Desktop.
create_pydantic_model_from_schema: build a Pydantic model dynamically for OpenAI structured responses
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from pydantic import BaseModel, create_model, ValidationError | |
from typing import Dict, Type, Any, List, Union | |
import json | |
from beta.llm_client.helpers.json_cleaner import clean_json_string | |
def infer_type(value: Any) -> Type:
    """
    Infers the Python type for one schema entry, given a type string or a sample value.

    Recognised type strings (surrounding whitespace is ignored):
    'str', 'int', 'float', 'bool', 'dict', 'list', and 'list[<inner>]', where
    <inner> is itself resolved recursively (e.g. 'list[list[int]]').

    Sample values are classified as bool, float, int, list or dict. A string that is
    not a type name is classified as int or float when int()/float() can parse it
    (e.g. '12' -> int, '1.5' -> float), otherwise as str. Values that neither int()
    nor float() accept and that are not list/dict (e.g. None) fall back to str.

    Args:
        value (Any): A type string or a sample value.

    Returns:
        Type: The inferred type; 'list[...]' type strings yield ``List[...]``.
    """
    type_map = {
        'str': str,
        'int': int,
        'float': float,
        'bool': bool,
        'dict': dict,
        'list': list,
    }
    if isinstance(value, str):
        token = value.strip()
        if token in type_map:
            return type_map[token]
        if token.startswith('list[') and token.endswith(']'):
            return List[infer_type(token[5:-1])]
    # bool is a subclass of int, so it must be checked before the int() probe below.
    if isinstance(value, bool):
        return bool
    # Real floats must be classified before int(): int(3.7) succeeds (which would
    # misreport floats as int) and int(float('inf')) raises OverflowError.
    if isinstance(value, float):
        return float
    try:
        int(value)
        return int
    except (ValueError, TypeError, OverflowError):
        pass
    try:
        float(value)
        return float
    except (ValueError, TypeError, OverflowError):
        pass
    if isinstance(value, list):
        return list
    if isinstance(value, dict):
        return dict
    return str
def create_pydantic_model_from_schema(
    data: Union[Dict[str, Any], str],
    model_name: str = 'DynamicModel',
) -> Type[BaseModel]:
    """
    Creates a Pydantic model class whose fields are inferred from ``data``.

    Each key of ``data`` becomes a required field of the model. The field's type is
    ``infer_type(value)`` for the corresponding value, which may be a type string
    such as 'list[str]' or a sample value.

    Args:
        data (Union[Dict[str, Any], str]): A dict, or a string that is passed to
            ``clean_json_string(..., return_dict=True)`` to obtain a dict.
        model_name (str): Class name of the generated model. Defaults to
            'DynamicModel', matching the previous hard-coded name.

    Returns:
        Type[BaseModel]: The dynamically created model class.

    Raises:
        ValueError: If ``data`` is not a dict, or a string that cleans into one.
        RuntimeError: If type inference or model creation raises; the original
            exception is chained as ``__cause__``.
        Exceptions raised by ``clean_json_string`` itself propagate unchanged.
    """
    if isinstance(data, str):
        # NOTE(review): assumes clean_json_string returns a dict on success; its
        # result is validated by the isinstance check below either way.
        data = clean_json_string(data, return_dict=True)
    if not isinstance(data, dict):
        raise ValueError("Input data must be a dictionary or a valid JSON string.")
    try:
        # `...` (Ellipsis) as the default marks each field as required in create_model.
        fields = {key: (infer_type(value), ...) for key, value in data.items()}
        return create_model(model_name, **fields)
    except Exception as e:
        # Chain the cause so the underlying inference/pydantic error stays visible.
        raise RuntimeError(f"Error creating dynamic model: {e}") from e
if __name__ == "__main__":
    # Demo: extract research-paper metadata with OpenAI structured outputs, using a
    # response model built at runtime from a schema string.
    import os

    from dotenv import load_dotenv
    from openai import OpenAI

    from beta.llm_client.helpers.get_project_root import get_project_root

    project_root = get_project_root()
    env_path = os.path.join(project_root, '.env')
    load_dotenv(env_path)
    client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

    # The schema is supplied at runtime (it may come from a caller) instead of a
    # hand-written class such as:
    #     class ResearchPaperExtraction(BaseModel):
    #         title: str
    #         authors: list[str]
    #         abstract: str
    #         keywords: list[str]
    # NOTE(review): this string is not strict JSON (unquoted keys and types); it relies
    # on clean_json_string to repair it into a dict -- confirm.
    schema = "{authors: list[str], title: str, abstract: str, keywords: list[str]}"
    system_message = (
        "You are an expert at structured data extraction. "
        "You will be given unstructured text from a research paper and should convert it into the given structure. "
        "The structure is as follows: "
        f"{schema}"
    )

    DynamicModel = create_pydantic_model_from_schema(schema)
    completion = client.beta.chat.completions.parse(
        model="gpt-4o-2024-08-06",
        messages=[
            {"role": "system", "content": system_message},
            {"role": "user", "content": "..."},
        ],
        response_format=DynamicModel,
    )

    message = completion.choices[0].message
    # With structured outputs the model may refuse; then `parsed` is not populated.
    if message.refusal:
        print(f"Model refused the request: {message.refusal}")
    else:
        research_paper = message.parsed
        print(research_paper)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment