Code for the Medium post
Last active
July 6, 2024 06:05
-
-
Save stephenleo/4cf5f91d22c0110b51fd1d85a44dd123 to your computer and use it in GitHub Desktop.
[Medium] The Definitive Guide to Structured Data Parsing with GPT3.5 in Complex Problems
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Create and activate a dedicated mamba environment (Python 3.11) for the experiments
mamba create -n structured_parsing python=3.11
mamba activate structured_parsing
# Install the libraries imported by the Python snippets that follow
pip install instructor fructose mirascope langchain langchain-openai langchain_experimental
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
from tqdm import tqdm | |
from enum import Enum | |
from typing import Callable, Any, Optional, Type | |
import instructor | |
from openai import OpenAI | |
from fructose import Fructose | |
from dataclasses import dataclass, is_dataclass, asdict | |
from pydantic import BaseModel, EmailStr, field_validator | |
from langchain_core.prompts import ChatPromptTemplate | |
from langchain_core.pydantic_v1 import BaseModel as BaseModelV1, EmailStr as EmailStrV1 | |
from langchain_openai import ChatOpenAI | |
from langchain.prompts import FewShotPromptTemplate, PromptTemplate | |
from langchain_experimental.tabular_synthetic_data.openai import ( | |
create_openai_data_generator, | |
) | |
from langchain_experimental.tabular_synthetic_data.prompts import ( | |
SYNTHETIC_FEW_SHOT_PREFIX, | |
SYNTHETIC_FEW_SHOT_SUFFIX, | |
) | |
from langchain.output_parsers import OutputFixingParser, PydanticOutputParser | |
from mirascope.openai import OpenAICallParams, OpenAIExtractor | |
# Create clients
# setdefault keeps a key that is already exported in the environment; the "..."
# placeholder is only written when no key is set and must then be replaced by
# a real key before any call can succeed.
os.environ.setdefault("OPENAI_API_KEY", "...")
# Fructose client used through the @ai decorator
ai = Fructose(model="gpt-3.5-turbo-0125")
# OpenAI client patched by instructor so create() accepts response_model
instructor_client = instructor.patch(OpenAI())
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def response_parsing(response: Any) -> Any:
    """Normalize a framework response into plain, comparable Python data.

    Args:
        response (Any): A list (typically of Enum members), a dataclass
            instance, a Pydantic v1/v2 model, or any other value.

    Returns:
        Any: A set of member values for a list, a dict for a dataclass
        instance or Pydantic model, and the input unchanged otherwise.
    """
    if isinstance(response, list):
        # Enum members are reduced to their value; members without a `value`
        # attribute (e.g. plain strings) are kept as they are instead of raising.
        return {getattr(member, "value", member) for member in response}
    # is_dataclass() is also true for dataclass *types*, which asdict() rejects.
    if is_dataclass(response) and not isinstance(response, type):
        return asdict(response)
    if isinstance(response, (BaseModel, BaseModelV1)):
        # Pydantic v2 deprecates .dict() in favour of .model_dump();
        # v1 models only provide .dict().
        dump = getattr(response, "model_dump", None) or response.dict
        return dump()
    return response
def experiment(
    n_runs: int, expected_response: Any
) -> Callable[[Callable[..., Any]], Callable[..., tuple[list[Any], float, Optional[float]]]]:
    """Decorator to run an LLM call function multiple times and return the responses

    Args:
        n_runs (int): Number of times to run the function
        expected_response (Any): The expected response, e.g. a set of class names or a dict of
            extracted fields. If provided (not None), the decorator will calculate accuracy too.

    Returns:
        A decorator. The decorated function returns a tuple of: the list of outputs from the
        successful runs, the fraction of runs that succeeded, and the accuracy among the
        successful runs if expected_response is provided else None.
    """
    import logging
    from functools import wraps

    logger = logging.getLogger(__name__)
    # "Provided" means not None, so an empty expected set/dict is still evaluated.
    check_accuracy = expected_response is not None

    def experiment_decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            classes = []
            accurate = 0
            for _ in tqdm(range(n_runs)):
                try:
                    response = func(*args, **kwargs)
                    if check_accuracy:
                        response = response_parsing(response)
                        # Only a dict can carry the wrapped "classes" payload;
                        # a set of labels must not be subscripted.
                        if isinstance(response, dict) and "classes" in response:
                            response = response_parsing(response["classes"])
                        if response == expected_response:
                            accurate += 1
                except Exception:
                    # Best effort: a failed run simply lowers the success rate.
                    # KeyboardInterrupt/SystemExit are deliberately not caught.
                    logger.debug("Experiment run failed", exc_info=True)
                    continue
                classes.append(response)

            num_successful = len(classes)
            percent_successful = num_successful / n_runs if n_runs else 0.0
            if not check_accuracy:
                return classes, percent_successful, None
            accuracy = accurate / num_successful if num_successful else 0
            return classes, percent_successful, accuracy

        return wrapper

    return experiment_decorator
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Full label vocabulary offered to the classifier.
multilabel_classes = [
    "audio_volume_other", "play_music", "iot_hue_lighton", "general_greet", "calendar_set",
    "audio_volume_down", "social_query", "audio_volume_mute", "iot_wemo_on", "iot_hue_lightup",
    "audio_volume_up", "iot_coffee", "takeaway_query", "qa_maths", "play_game",
    "cooking_query", "iot_hue_lightdim", "iot_wemo_off", "music_settings", "weather_query",
    "news_query", "alarm_remove", "social_post", "recommendation_events", "transport_taxi",
    "takeaway_order", "music_query", "calendar_query", "lists_query", "qa_currency",
    "recommendation_movies", "general_joke", "recommendation_locations", "email_querycontact", "lists_remove",
    "play_audiobook", "email_addcontact", "lists_createoradd", "play_radio", "qa_stock",
    "alarm_query", "email_sendemail", "general_quirky", "music_likeness", "cooking_recipe",
    "email_query", "datetime_query", "transport_traffic", "play_podcasts", "iot_hue_lightchange",
    "calendar_remove", "transport_query", "transport_ticket", "qa_factoid", "iot_cleaning",
    "alarm_set", "datetime_convert", "iot_hue_lightoff", "qa_definition", "music_dislikeness",
]

# Utterance under test and the three labels it is expected to receive.
text_to_classify = "wake me up at nine am, turn on the light in the bathroom and start my playlist"
expected_response = {"alarm_set", "iot_hue_lighton", "play_music"}

print(f"Number of classes: {len(multilabel_classes)}")
print(f"Sample classes: {multilabel_classes[:3]}")
# Number of classes: 60
# Sample classes: ['audio_volume_other', 'play_music', 'iot_hue_lighton']
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# One (annotation, default) pair per label; not referenced again in the visible code.
fields = dict.fromkeys(multilabel_classes, (Optional[str], None))
# Enum whose member names and member values are both the label string.
Intents = Enum("Intents", [(name, name) for name in multilabel_classes])
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@experiment(n_runs=20, expected_response=expected_response)
def run_instructor_classification_experiment(text: str):
    """Request the list of intents found in ``text`` via the instructor-patched client."""
    user_message = {"role": "user", "content": f"Classify the following text: {text}"}
    completion = instructor_client.chat.completions.create(
        model="gpt-3.5-turbo-0125",
        response_model=list[Intents],
        max_retries=2,
        messages=[user_message],
    )
    return completion


# Run the 20 trials on the sample utterance and report the aggregate scores.
predictions, percent_successful, accuracy = run_instructor_classification_experiment(text_to_classify)
print(f"Percent of successful API calls: {percent_successful:.4f}")
print(f"Accuracy: {accuracy:.4f}")
# 100%|██████████| 20/20 [00:19<00:00, 1.01it/s]
# Percent of successful API calls: 1.0000
# Accuracy: 1.0000
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Pydantic response model: the extractor must return a list of Intents
class TaskDetails(BaseModel):
    classes: list[Intents]


# Mirascope template: the prompt is filled from the `text_to_classify` field
class TaskExtractor(OpenAIExtractor[TaskDetails]):
    extract_schema: Type[TaskDetails] = TaskDetails
    prompt_template = """
Classify the following text: {text_to_classify}
"""
    text_to_classify: str = ""


TaskExtractor.call_params = OpenAICallParams(model="gpt-3.5-turbo-0125")
extractor = TaskExtractor()


@experiment(n_runs=20, expected_response=expected_response)
def run_mirascope_classification_experiment(text_to_classify: str):
    """Run one extraction on the shared extractor and return the label values as a set."""
    extractor.text_to_classify = text_to_classify
    extracted = extractor.extract(retries=2)
    labels = set()
    for intent in extracted.classes:
        labels.add(intent.value)
    return labels


# Run the 20 trials on the sample utterance and report the aggregate scores.
predictions, percent_successful, accuracy = run_mirascope_classification_experiment(text_to_classify)
print(f"Percent of successful API calls: {percent_successful:.4f}")
print(f"Accuracy: {accuracy:.4f}")
# 100%|██████████| 20/20 [00:38<00:00, 1.93s/it]
# Percent of successful API calls: 1.0000
# Accuracy: 1.0000
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# The function has no body of its own: the Fructose client `ai` decorates it and
# `experiment` repeats the call 20 times. The docstring is presumably what the
# model is prompted with (confirm against Fructose), so its wording is unchanged.
@experiment(n_runs=20, expected_response=expected_response)
@ai
def fructose_classify_text(text: str) -> list[Intents]:
    """Classify the given text into all the classes present.
    The text might have more than one class present.
    Args:
    text (str): The input text to classify
    Returns:
    list[Intents]: The classes present in the text
    """
    ...


# The experiment wrapper returns (responses, success rate, accuracy)
predictions, percent_successful, accuracy = fructose_classify_text(
    text=text_to_classify
)
print(f"Percent of successful API calls: {percent_successful:.4f}")
print(f"Accuracy: {accuracy:.4f}")
# 100%|██████████| 20/20 [00:22<00:00, 1.10s/it]
# Percent of successful API calls: 1.0000
# Accuracy: 1.0000
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Prompt: the passage to tag is injected through the {text} variable
tagging_prompt = ChatPromptTemplate.from_template(
    """
Extract the desired information from the following passage.
Only extract the properties mentioned in the 'Classification' function.
Passage:
{text}
"""
)


# Langchain needs Pydantic V1
class PydanticV1Intents(BaseModelV1):
    classes: list[Intents]


# LLM constrained to return a PydanticV1Intents instance
llm = ChatOpenAI(
    temperature=0,
    model="gpt-3.5-turbo-0125",
).with_structured_output(PydanticV1Intents)

# Chain: prompt output is piped into the structured-output model
tagging_chain = tagging_prompt | llm


# Experiment
@experiment(n_runs=20, expected_response=expected_response)
def run_langchain_classification_experiment(text: str):
    """Invoke the tagging chain once for ``text``."""
    chain_input = {"text": text}
    return tagging_chain.invoke(chain_input)


predictions, percent_successful, accuracy = run_langchain_classification_experiment(text_to_classify)
print(f"Percent of successful API calls: {percent_successful:.4f}")
print(f"Accuracy: {accuracy:.4f}")
# 100%|██████████| 20/20 [00:18<00:00, 1.08it/s]
# Percent of successful API calls: 1.0000
# Accuracy: 1.0000
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Sample passage and expected extraction for the entity-extraction experiments.
# A concrete, syntactically valid address is used in both places: the
# "[email protected]" obfuscation placeholder is not an email address, so it can
# neither pass EmailStr validation nor be matched against a real extraction.
text_to_classify = "Hello I'm John Doe, 28 years old, from New York, USA. My email is john.doe@example.com"
expected_response = {
    "name": "John Doe",
    "age": 28,
    "email": "john.doe@example.com",
    "addresses": [
        {
            # Attributes the passage does not mention are expected to be None
            "street": None,
            "city": "New York",
            "postal_code": None,
            "country": {"name": "USA", "code": None},
        }
    ],
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Pydantic schema for the entity-extraction experiments. Every field is
# annotated Optional so an attribute missing from the text can be None.
# Comments are used instead of class docstrings in case the LLM libraries
# forward docstrings to the model as part of the schema (TODO confirm).
class PydanticCountry(BaseModel):
    name: Optional[str]
    code: Optional[str]


# One postal address; `country` nests the model above
class PydanticAddress(BaseModel):
    street: Optional[str]
    city: Optional[str]
    postal_code: Optional[str]
    country: Optional[PydanticCountry]


# Top-level extraction target; `email` is validated by EmailStr when present
class PydanticUser(BaseModel):
    name: Optional[str]
    age: Optional[int]
    email: Optional[EmailStr]
    addresses: Optional[list[PydanticAddress]]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@experiment(n_runs=20, expected_response=expected_response)
def run_instructor_experiment():
    """Extract a PydanticUser from the module-level sample passage via instructor."""
    user_message = {
        "role": "user",
        "content": f"Extract and resolve a list of entities from the following text: {text_to_classify}",
    }
    extracted_user = instructor_client.chat.completions.create(
        model="gpt-3.5-turbo-0125",
        response_model=PydanticUser,
        max_retries=2,
        messages=[user_message],
    )
    return extracted_user


# Run the 20 trials and report the aggregate scores.
predictions, percent_successful, accuracy = run_instructor_experiment()
print(f"Percent of successful API calls: {percent_successful:.4f}")
print(f"Accuracy: {accuracy:.4f}")
# 100%|██████████| 20/20 [00:34<00:00, 1.73s/it]
# Percent of successful API calls: 1.0000
# Accuracy: 0.0000
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Mirascope template
class TaskExtractor(OpenAIExtractor[PydanticUser]):
    extract_schema: Type[PydanticUser] = PydanticUser
    prompt_template = """
Extract and resolve a list of entities from the following text: {text_to_classify}
"""
    # The default keeps the argument-less TaskExtractor() call below valid (a
    # field without a default would be required at construction); the real
    # text is assigned before every extract() call.
    text_to_classify: str = ""


TaskExtractor.call_params = OpenAICallParams(model="gpt-3.5-turbo-0125")
extractor = TaskExtractor()


@experiment(n_runs=20, expected_response=expected_response)
def run_mirascope_experiment(text_to_classify: str):
    """Set the passage on the shared extractor and run one extraction with up to 2 retries."""
    extractor.text_to_classify = text_to_classify
    return extractor.extract(retries=2)


predictions, percent_successful, accuracy = run_mirascope_experiment(text_to_classify)
print(f"Percent of successful API calls: {percent_successful:.4f}")
print(f"Accuracy: {accuracy:.4f}")
# 100%|██████████| 20/20 [00:56<00:00, 2.81s/it]
# Percent of successful API calls: 1.0000
# Accuracy: 0.0000
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Dataclass counterparts of the extraction schema, used as the return type of
# the Fructose extraction function. Every field is annotated Optional so an
# attribute missing from the text can be None; no defaults are declared, so
# all fields must still be passed to the constructor.
@dataclass
class DataclassCountry:
    name: Optional[str]
    code: Optional[str]


# One postal address; `country` nests the dataclass above
@dataclass
class DataclassAddress:
    street: Optional[str]
    city: Optional[str]
    postal_code: Optional[str]
    country: Optional[DataclassCountry]


# Top-level extraction target; `email` is a plain string here (no validation)
@dataclass
class DataclassUser:
    name: Optional[str]
    age: Optional[int]
    email: Optional[str]
    addresses: Optional[list[DataclassAddress]]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# The function has no body of its own: the Fructose client `ai` decorates it and
# `experiment` repeats the call 20 times, comparing each result with
# expected_response. The docstring is presumably the prompt (confirm against
# Fructose), so its wording is unchanged.
@experiment(
    n_runs=20,
    expected_response=expected_response,
)
@ai
def extract_person_data(text: str) -> DataclassUser:
    """
    Given an input text, extract out all the available attributes.
    """
    ...


# The experiment wrapper returns (responses, success rate, accuracy)
predictions, percent_successful, accuracy = extract_person_data(text_to_classify)
print(f"Percent of successful API calls: {percent_successful:.4f}")
print(f"Accuracy: {accuracy:.4f}")
# 100%|██████████| 20/20 [00:40<00:00, 2.03s/it]
# Percent of successful API calls: 1.0000
# Accuracy: 1.0000
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Pydantic V1 models
# Every field is Optional with a None default: the system prompt below asks the
# model to return null for unknown attributes and expected_response contains
# None values, which required non-null fields could never validate or match.
class LCCountry(BaseModelV1):
    name: Optional[str] = None
    code: Optional[str] = None


class LCAddress(BaseModelV1):
    street: Optional[str] = None
    city: Optional[str] = None
    postal_code: Optional[str] = None
    country: Optional[LCCountry] = None


class LCUser(BaseModelV1):
    name: Optional[str] = None
    age: Optional[int] = None
    email: Optional[EmailStrV1] = None
    addresses: Optional[list[LCAddress]] = None


# Chain
prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are an expert extraction algorithm. "
            "Only extract relevant information from the text. "
            "If you do not know the value of an attribute asked to extract, "
            "return null for the attribute's value.",
        ),
        ("human", "{text}"),
    ]
)
# LLM constrained to return an LCUser instance
llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo-0125").with_structured_output(
    schema=LCUser
)
chain = prompt | llm


# Experiment
@experiment(
    n_runs=20,
    expected_response=expected_response,
)
def run_lc_experiment():
    """Invoke the extraction chain once on the module-level sample passage."""
    return chain.invoke({"text": text_to_classify})


predictions, percent_successful, accuracy = run_lc_experiment()
print(f"Percent of successful API calls: {percent_successful:.4f}")
print(f"Accuracy: {accuracy:.4f}")
# Output recorded with the earlier all-required schema:
# 100%|██████████| 20/20 [00:48<00:00, 2.41s/it]
# Percent of successful API calls: 1.0000
# Accuracy: 0.0000
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Stricter Pydantic schema for the synthetic-data experiments: every field is
# required and non-null. Comments are used instead of class docstrings in case
# the LLM libraries forward docstrings to the model (TODO confirm).
class PydanticCountry(BaseModel):
    name: str
    code: str


class PydanticAddress(BaseModel):
    street: str
    city: str
    postal_code: str
    country: PydanticCountry

    # Reject any postal code that is not exactly six digit characters
    @field_validator("postal_code")
    def postal_code_must_be_6_digits(cls, v):
        if not (v.isdigit() and len(v) == 6):
            raise ValueError("postal_code must be a 6-digit number")
        return v


# Top-level generation target; `email` is validated by EmailStr
class PydanticUser(BaseModel):
    name: str
    age: int
    email: EmailStr
    addresses: list[PydanticAddress]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@experiment(n_runs=20, expected_response=None)
def run_instructor_experiment():
    """Generate one random PydanticUser through the instructor-patched client."""
    return instructor_client.chat.completions.create(
        model="gpt-3.5-turbo-0125",
        response_model=PydanticUser,
        max_retries=2,
        messages=[
            {
                "role": "user",
                "content": "Generate a random person's information. The name must be chosen at random. Make it something you wouldn't normally choose.",
            }
        ],
    )


predictions, percent_successful, _ = run_instructor_experiment()
# Share of distinct names among the successful generations. Guarded because
# predictions is empty when every call failed, which would divide by zero.
variety = len({pred.name for pred in predictions}) / len(predictions) if predictions else 0
print(f"Percent of successful API calls: {percent_successful:.4f}")
print(f"Variety of names: {variety:.4f}")
# 100%|██████████| 20/20 [00:57<00:00, 2.89s/it]
# Percent of successful API calls: 0.9500
# Variety of names: 0.6842
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Mirascope template: fixed prompt, no template variables
class TaskExtractor(OpenAIExtractor[PydanticUser]):
    extract_schema: Type[PydanticUser] = PydanticUser
    prompt_template = """
Generate a random person's information. The name must be chosen at random. Make it something you wouldn't normally choose.
"""


TaskExtractor.call_params = OpenAICallParams(model="gpt-3.5-turbo-0125")
extractor = TaskExtractor()


@experiment(n_runs=20, expected_response=None)
def run_mirascope_experiment():
    """Run a single extraction on the shared extractor (no retries requested)."""
    generated_user = extractor.extract()
    return generated_user


predictions, percent_successful, _ = run_mirascope_experiment()
# Share of distinct names among the successful generations; 0 when there are none
variety = len({pred.name for pred in predictions}) / len(predictions) if predictions else 0
print(f"Percent of successful API calls: {percent_successful:.4f}")
print(f"Variety of names: {variety:.4f}")
# 100%|██████████| 20/20 [00:55<00:00, 2.76s/it]
# Percent of successful API calls: 0.0000
# Variety of names: 0.0000
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Strict dataclass schema for the synthetic-data experiments: no Optional fields.
@dataclass
class DataclassCountry:
    name: str
    code: str


# Address whose postal code is checked right after construction.
@dataclass
class DataclassAddress:
    street: str
    city: str
    postal_code: str
    country: DataclassCountry

    def __post_init__(self):
        """Raise ValueError unless postal_code is exactly six digit characters."""
        code = self.postal_code
        is_six_digits = len(code) == 6 and code.isdigit()
        if not is_six_digits:
            raise ValueError("Postal code must be a 6-digit number")


# Top-level generation target; `email` is a plain string (no validation).
@dataclass
class DataclassUser:
    name: str
    age: int
    email: str
    addresses: list[DataclassAddress]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# The function has no body of its own: the Fructose client `ai` decorates it and
# `experiment` repeats the call 20 times. The docstring is presumably the prompt
# (confirm against Fructose), so its wording is unchanged.
@experiment(n_runs=20, expected_response=None)
@ai
def generate_fake_person_data() -> DataclassUser:
    """
    Generate a random person's information. The name must be chosen at random. Make it something you wouldn't normally choose.
    """
    ...


predictions, percent_successful, _ = generate_fake_person_data()
# Share of distinct names among the successful generations. Guarded because
# predictions is empty when every call failed, which would divide by zero.
variety = len({pred.name for pred in predictions}) / len(predictions) if predictions else 0
print(f"Percent of successful API calls: {percent_successful:.4f}")
print(f"Variety of names: {variety:.4f}")
# 100%|██████████| 20/20 [02:27<00:00, 7.39s/it]
# Percent of successful API calls: 0.9000
# Variety of names: 1.0000
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Pydantic V1 schema for the synthetic-data experiment: every field is required
class LCCountry(BaseModelV1):
    name: str
    code: str


class LCAddress(BaseModelV1):
    street: str
    city: str
    postal_code: str
    country: LCCountry

    # Langchain's RetryOutputParser doesn't work with create_openai_data_generator
    # @validator('postal_code')
    # def postal_code_must_be_6_digits(cls, v):
    #     if not (v.isdigit() and len(v) == 6):
    #         raise ValueError('postal_code must be a 6-digit number')
    #     return v


class LCUser(BaseModelV1):
    name: str
    age: int
    email: EmailStrV1
    addresses: list[LCAddress]


# Parser for LCUser, wrapped so malformed output is sent back to an LLM for repair
parser = PydanticOutputParser(pydantic_object=LCUser)
fix_parser = OutputFixingParser.from_llm(parser=parser, llm=ChatOpenAI(), max_retries=2)

# Few-shot examples. Braces are doubled so they survive prompt formatting.
# "addresses" is a JSON array to match LCUser.addresses (list[LCAddress]), and
# the emails are syntactically valid so the examples themselves satisfy
# EmailStr; the "[email protected]" obfuscation placeholder does not.
examples = [
    {
        "example": """{{"name":"Alice", "age":28, "email":"alice@example.com", "addresses":[{{"street":"123 Main St","city":"Wonderland","postal_code":"456789","country":{{"name":"Fantasyland","code":"FL"}}}}]}}"""
    },
    {
        "example": """{{"name":"Emily","age":28,"email":"emily@example.com","addresses":[{{"street":"123 Oak Street","city":"New York","postal_code":"100010","country":{{"name":"United States","code":"US"}}}}]}}"""
    },
]
OPENAI_TEMPLATE = PromptTemplate(input_variables=["example"], template="{example}")
prompt_template = FewShotPromptTemplate(
    prefix=SYNTHETIC_FEW_SHOT_PREFIX,
    examples=examples,
    suffix=SYNTHETIC_FEW_SHOT_SUFFIX,
    input_variables=["subject", "extra"],
    example_prompt=OPENAI_TEMPLATE,
)
synthetic_data_generator = create_openai_data_generator(
    output_schema=LCUser,
    llm=ChatOpenAI(temperature=0.5, model="gpt-3.5-turbo-0125"),
    prompt=prompt_template,
    output_parser=fix_parser,
)


@experiment(n_runs=20, expected_response=None)
def run_langchain_experiment():
    """Ask the synthetic data generator for a single record (runs=1)."""
    return synthetic_data_generator.generate(
        subject="customer information",
        extra="the name must be chosen at random. Make it something you wouldn't normally choose.",
        runs=1,
    )


predictions, percent_successful, _ = run_langchain_experiment()
print(f"Percent of successful API calls: {percent_successful:.4f}")
# Output recorded with the earlier few-shot examples:
# 100%|██████████| 20/20 [00:12<00:00, 1.54it/s]
# Percent of successful API calls: 0.0500
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment