Last active
November 13, 2024 17:16
-
-
Save ricklamers/54ca231853c07f60005fafbd84b102a0 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json | |
import os | |
import time | |
from typing import List | |
from groq import Groq | |
from pydantic import BaseModel, Field | |
# One ingredient line of a recipe. Intentionally no class docstring: Pydantic
# copies a class docstring into the JSON schema, and that schema is sent to
# the model as part of the prompt.
class Ingredient(BaseModel):
    name: str = Field(
        min_length=2,
        max_length=100,
        description="Name of the ingredient",
    )
    # Must be strictly positive.
    amount: float = Field(gt=0, description="Quantity of the ingredient")
    # Letters and spaces only, 1-20 characters (so "fl oz" passes, "g." fails).
    unit: str = Field(
        pattern="^[a-zA-Z ]{1,20}$",
        description="Unit of measurement",
    )
    notes: str | None = Field(
        default=None,
        description="Optional notes about the ingredient",
    )
    # default_factory gives each instance its own fresh list.
    substitutes: list[str] = Field(
        default_factory=list,
        description="Possible ingredient substitutions",
    )
# A complete recipe, as the model is asked to return it. Its JSON schema is
# embedded in the prompt, so field descriptions and constraints double as
# instructions to the model. Intentionally no class docstring: Pydantic would
# copy it into that schema.
class Recipe(BaseModel):
    title: str = Field(description="Title of the recipe", min_length=3, max_length=200)
    description: str = Field(
        description="Brief description of the recipe", min_length=10
    )
    # Pydantic v2 spells list-length bounds `min_length`. The v1 name
    # `min_items` is still accepted but emits a DeprecationWarning.
    ingredients: List[Ingredient] = Field(
        description="List of ingredients needed", min_length=1
    )
    instructions: List[str] = Field(
        description="Step by step cooking instructions", min_length=2
    )
    prep_time_mins: int = Field(description="Preparation time in minutes", gt=0, le=240)
    cook_time_mins: int = Field(description="Cooking time in minutes", gt=0, le=480)
    servings: int = Field(description="Number of servings", gt=0, le=20)
    # Constrained by regex to exactly one of: easy, medium, hard.
    difficulty: str = Field(
        description="Recipe difficulty level", pattern="^(easy|medium|hard)$"
    )
    cuisine_type: str = Field(description="Type of cuisine", min_length=3)
    tags: List[str] = Field(description="Recipe categories", default_factory=list)
# Module-wide Groq client shared by every request. The key is looked up
# eagerly, so loading this module without GROQ_API_KEY set raises KeyError.
client = Groq(
    api_key=os.environ["GROQ_API_KEY"],
)
def _build_recipe_messages(schema_text, error_msg=""):
    """Build the chat messages for one attempt, appending any earlier error."""
    user_content = "Provide a sample recipe."
    if error_msg:
        user_content += f"\nErrors: {error_msg}"
    return [
        {
            "role": "system",
            "content": f"Provide a cooking recipe in JSON format following this schema:\n{schema_text}",
        },
        {"role": "user", "content": user_content},
    ]


def _parse_recipe(response_text):
    """Parse and validate a model reply: return (recipe, "") or (None, error_text)."""
    try:
        # message.content may be None; treat that as an empty, invalid reply
        # instead of letting json.loads raise TypeError.
        json_data = json.loads(response_text or "")
        return Recipe.model_validate(json_data), ""
    except json.JSONDecodeError as e:
        return None, f"Invalid JSON format: {str(e)}"
    except ValueError as e:
        # Pydantic's ValidationError is a ValueError subclass.
        return None, f"Schema validation error: {str(e)}"


def get_recipe_json(error_msg="", raise_on_error=False, max_retries=None):
    """Request a recipe from the model and return it as a validated ``Recipe``.

    The system prompt carries the ``Recipe`` JSON schema, and the request uses
    JSON response mode. A reply that is not valid JSON, or that fails
    ``Recipe`` validation, is retried with the error text appended to the user
    message so the model can correct itself. A failed API call (for example a
    429 rate limit or a 400 JSON-mode error) is retried after a one-second
    pause, without error feedback.

    Retries run in a loop, not by recursion, so a long run of failures cannot
    exhaust the interpreter's recursion limit.

    Args:
        error_msg: Error text from an earlier attempt, fed back to the model.
            Empty for a fresh request.
        raise_on_error: When true, raise as soon as a parse/validation error
            is pending instead of retrying.
        max_retries: Maximum number of retries after the first attempt.
            ``None`` (the default) retries indefinitely.

    Returns:
        A ``Recipe`` instance that passed Pydantic validation.

    Raises:
        Exception: carrying the error text, when ``raise_on_error`` is true and
            a parse/validation error is pending, or when ``max_retries`` is
            exhausted.
    """
    # The schema is identical for every attempt, so it is computed once.
    # json.dumps renders real JSON (double quotes, null/true) instead of a
    # Python dict repr, matching the JSON the model is asked to produce.
    schema_text = json.dumps(Recipe.model_json_schema())
    attempts = 0
    while True:
        if error_msg:
            print(f"Retrying due to error: {error_msg}")
            if raise_on_error:
                raise Exception(error_msg)

        attempts += 1
        try:
            completion = client.chat.completions.create(
                model="llama-3.1-8b-instant",
                messages=_build_recipe_messages(schema_text, error_msg),
                temperature=1,
                max_tokens=1024,
                stream=False,
                response_format={"type": "json_object"},
            )
        except Exception as exc:
            # Could be a 429 rate limit or a 400 JSON-format error. As before,
            # the API error is not fed back to the model and does not trigger
            # raise_on_error; only the optional retry limit applies.
            api_failed = True
            error_msg = ""
            last_error = f"API error: {exc}"
        else:
            api_failed = False
            recipe, error_msg = _parse_recipe(completion.choices[0].message.content)
            if recipe is not None:
                return recipe
            last_error = error_msg

        if max_retries is not None and attempts > max_retries:
            raise Exception(f"Giving up after {attempts} attempts: {last_error}")
        if api_failed:
            time.sleep(1)  # short back-off before calling the API again
def run_recipe_generator(show_error=False):
    """Drive ``get_recipe_json`` and print the outcome.

    Default mode: print the first recipe obtained as indented JSON, then
    return. If a call raises, print the error and try again.

    ``show_error`` mode: ask ``get_recipe_json`` to raise on its first
    parse/validation error. Successful recipes are discarded without output,
    and generation repeats until an exception is caught. That exception is
    printed before returning.
    """
    while True:
        try:
            recipe = get_recipe_json(raise_on_error=show_error)
            if show_error:
                # A success is not what this mode is looking for; go again.
                continue
            print(recipe.model_dump_json(indent=2))
            return
        except Exception as exc:
            print(f"Error occurred: {str(exc)}")
            if show_error:
                return
if __name__ == "__main__":
    # Command-line entry point. --show-error switches the generator into the
    # mode that loops until a generation error surfaces.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--show-error",
        action="store_true",
        help="Keep running until an error occurs",
    )
    run_recipe_generator(show_error=cli.parse_args().show_error)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment