# Source: GitHub Gist cnndabbler/0bbca28dcd7de2e97060c7ac90a67126
# (author @cnndabbler, created December 21, 2024 00:55).
#
# This code implements an intelligent restaurant recommendation system that
# combines Pydantic type safety with persistent, per-customer context
# management. It uses a PydanticAI Agent to query an LLM (Large Language
# Model) while keeping type safety and data validation throughout the
# application.
"""
Restaurant recommendation system demonstrating context features with multiple customers.
Shows how context helps personalize and improve recommendations over time.
"""
from typing import List, Optional, Dict, Literal
from pydantic import BaseModel, Field
import asyncio
from datetime import datetime
from pydantic_ai import Agent, RunContext
import json
import logfire
# Set up Logfire observability at import time. configure() is called with no
# arguments, so its settings presumably come from the environment / local
# Logfire configuration — TODO confirm for the target deployment.
logfire.configure()
# Log a marker that this module has started loading.
logfire.info('Starting in pydantic_restaurant_context.py')
from pathlib import Path
import json
from datetime import datetime
import os
# Persistence settings: every customer's context is stored as a JSON file in
# this directory (relative to the current working directory). The directory
# is created when the module is imported if it does not exist yet.
CONTEXT_DIR: Path = Path("customer_contexts")
CONTEXT_DIR.mkdir(exist_ok=True)
class DateTimeEncoder(json.JSONEncoder):
    """JSON encoder that writes ``datetime`` values as ISO-8601 strings.

    Every other object the standard encoder cannot handle is passed to the
    base class, which raises ``TypeError``.
    """

    def default(self, obj):
        if not isinstance(obj, datetime):
            return super().default(obj)
        return obj.isoformat()
class ContextPersistence:
    """Save and load customer contexts as one JSON file per customer.

    Files live in ``CONTEXT_DIR`` and are named
    ``<lower-cased customer name>_context.json``.
    """

    @staticmethod
    def get_context_file(customer_name: str) -> Path:
        """Return the path of *customer_name*'s context file (name is lower-cased)."""
        return CONTEXT_DIR / f"{customer_name.lower()}_context.json"

    @staticmethod
    def save_context(customer_ctx: 'CustomerContext') -> None:
        """Write *customer_ctx* to its context file.

        The JSON is first written (UTF-8) to a temporary sibling file and then
        moved over the target with ``Path.replace``. An interrupted or failed
        write therefore leaves the previous file intact instead of a truncated
        one that ``load_context`` could not parse. Errors from the filesystem
        or from JSON encoding propagate after the temporary file is removed.
        """
        context_file = ContextPersistence.get_context_file(customer_ctx.name)
        # Convert context to serializable format
        context_data = {
            'name': customer_ctx.name,
            'preferences': customer_ctx.preferences.model_dump(),
            'visit_history': [visit.model_dump() for visit in customer_ctx.visit_history],
            'recommendation_history': [rest.model_dump() for rest in customer_ctx.recommendation_history],
            'context_updates': customer_ctx.context_updates,
            'last_updated': datetime.now().isoformat()
        }
        tmp_file = context_file.with_name(context_file.name + ".tmp")
        try:
            # DateTimeEncoder turns the datetime values in visit_history into ISO strings.
            with open(tmp_file, 'w', encoding='utf-8') as f:
                json.dump(context_data, f, indent=2, cls=DateTimeEncoder)
            tmp_file.replace(context_file)
        except BaseException:
            # Never leave a half-written temp file behind; re-raise the original error.
            tmp_file.unlink(missing_ok=True)
            raise
        print(f"\n💾 Saved context for {customer_ctx.name} to {context_file}")

    @staticmethod
    def load_context(customer_name: str) -> Optional['CustomerContext']:
        """Load *customer_name*'s saved context from disk.

        Sections that may be absent from an older or hand-edited file
        (visit history, recommendation history, context updates, last-updated
        stamp) default to empty values instead of aborting the load.

        Returns:
            The reconstructed ``CustomerContext``, or ``None`` when no file
            exists or the file cannot be read, parsed or validated. A message
            is printed in every case.
        """
        context_file = ContextPersistence.get_context_file(customer_name)
        if not context_file.exists():
            print(f"\n📝 No saved context found for {customer_name}")
            return None
        try:
            with open(context_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            visits = data.get('visit_history', [])
            # Convert ISO format strings back to datetime
            for visit in visits:
                if 'date' in visit:
                    visit['date'] = datetime.fromisoformat(visit['date'])
            # Create new context with loaded data
            preferences = CustomerPreferences(**data['preferences'])
            context = CustomerContext(data['name'], preferences)
            context.visit_history = [DiningExperience(**visit) for visit in visits]
            context.recommendation_history = [
                Restaurant(**rest) for rest in data.get('recommendation_history', [])
            ]
            context.context_updates = data.get('context_updates', [])
            print(f"\n📂 Loaded saved context for {customer_name}")
            print(f"Last updated: {data.get('last_updated', 'unknown')}")
            return context
        except Exception as e:
            # Deliberate best-effort boundary: a corrupt or incompatible file must
            # not stop start-up; the caller falls back to a fresh context.
            print(f"\n❌ Error loading context for {customer_name}: {str(e)}")
            return None
# Models for restaurant recommendations
class Restaurant(BaseModel):
    """Restaurant details."""
    # Also used as the recommendation agent's result_type, so changing these
    # fields changes what the LLM is asked to return.
    name: str
    cuisine: str
    price_range: Literal["$", "$$", "$$$", "$$$$"]  # only these four tiers validate
    vegetarian_friendly: bool = False
    gluten_free_options: bool = False
    average_rating: float  # shown elsewhere as "x/5"; the range is not validated here
    specialties: List[str]
    location: str
    typical_wait_time: str  # free-form text, e.g. "20-30 minutes"
class DiningExperience(BaseModel):
    """Record of a customer's dining experience.

    ``rating`` is validated to the 1-5 scale the module relies on: prompts
    render it as "x/5", a rating >= 4 marks a favorite and <= 2 a dislike.
    ``party_size`` must be at least 1 because it feeds the running
    typical-party-size average.
    """
    restaurant: str  # restaurant name
    date: datetime
    rating: int = Field(ge=1, le=5)
    liked_dishes: List[str]
    disliked_dishes: List[str]
    comments: str
    party_size: int = Field(ge=1)
    occasion: Optional[str] = None
class CustomerPreferences(BaseModel):
    """Customer's dining preferences.

    Instances are mutated in place as visits are recorded; snapshots taken
    with ``model_dump()`` are persisted and diffed to build the context log.
    """
    # NOTE: field declaration order is also the key order of model_dump(),
    # which is what gets written to the context files.
    favorite_cuisines: List[str] = Field(default_factory=list)
    dietary_restrictions: List[str] = Field(default_factory=list)
    preferred_price_range: List[str] = Field(default_factory=list)  # e.g. "$", "$$"; not validated against Restaurant.price_range
    typical_party_size: int = 2  # updated per visit as a 2:1 weighted integer average
    favorite_restaurants: List[str] = Field(default_factory=list)
    disliked_restaurants: List[str] = Field(default_factory=list)
    preferred_locations: List[str] = Field(default_factory=list)
    usual_occasions: List[str] = Field(default_factory=list)
class RecommendationContext(BaseModel):
    """Context for restaurant recommendations.

    Passed to the recommendation agent as its dependencies (``deps``); the
    dynamic system prompt is assembled from these fields.
    """
    customer_name: str
    preferences: CustomerPreferences
    dining_history: List[DiningExperience] = Field(default_factory=list)
    current_request: str
    previous_recommendations: List[Restaurant] = Field(default_factory=list)  # not read by add_recommendation_context
    rejected_recommendations: List[Restaurant] = Field(default_factory=list)  # listed in the prompt as suggestions to avoid
    current_occasion: Optional[str] = None
    party_size: Optional[int] = None  # party size for this request (vs. preferences.typical_party_size)
    time_of_day: Optional[str] = None
class CustomerContext:
    """Track one customer's evolving dining context across restaurant visits.

    Holds the customer's preferences (mutated in place), every recorded
    visit, the restaurant associated with each visit, and an audit log of how
    each visit changed the preferences. ``add_visit`` persists the whole
    context through ``ContextPersistence.save_context``.
    """

    def __init__(self, name: str, initial_preferences: 'CustomerPreferences'):
        self.name = name
        # Used as-is (not copied): later updates mutate the caller's object too.
        self.preferences = initial_preferences
        self.visit_history: List['DiningExperience'] = []
        self.recommendation_history: List['Restaurant'] = []
        self.context_updates: List[dict] = []  # one audit entry per add_visit call

    def add_visit(self, experience: 'DiningExperience', restaurant: 'Restaurant'):
        """Record a visit, fold it into the preferences, log the diff and save.

        Args:
            experience: The dining experience to record.
            restaurant: The restaurant that was visited.
        """
        self.visit_history.append(experience)
        self.recommendation_history.append(restaurant)
        # Snapshot preferences around the update so the change can be logged.
        before_state = self.preferences.model_dump()
        self._apply_experience(experience, restaurant)
        after_state = self.preferences.model_dump()
        self.context_updates.append({
            'visit_number': len(self.visit_history),
            'restaurant': restaurant.name,
            'before': before_state,
            'after': after_state,
            'changes': self._get_context_changes(before_state, after_state),
            'timestamp': datetime.now().isoformat()
        })
        # Save context after update
        ContextPersistence.save_context(self)

    @staticmethod
    def _discard(items: list, value) -> None:
        """Remove every occurrence of *value* from *items*, keeping the same list object."""
        items[:] = [item for item in items if item != value]

    def _apply_experience(self, experience: 'DiningExperience', restaurant: 'Restaurant') -> None:
        """Update ``self.preferences`` in place from a single dining experience.

        Rating >= 4: the restaurant becomes a favorite (and its cuisine a
        favorite cuisine) and is removed from the disliked list. Rating <= 2:
        the restaurant becomes disliked and is removed from the favorites.
        The two lists are therefore never contradictory.
        """
        prefs = self.preferences
        if experience.rating >= 4:
            if restaurant.name not in prefs.favorite_restaurants:
                prefs.favorite_restaurants.append(restaurant.name)
            if restaurant.cuisine not in prefs.favorite_cuisines:
                prefs.favorite_cuisines.append(restaurant.cuisine)
            self._discard(prefs.disliked_restaurants, restaurant.name)
        elif experience.rating <= 2:
            if restaurant.name not in prefs.disliked_restaurants:
                prefs.disliked_restaurants.append(restaurant.name)
            self._discard(prefs.favorite_restaurants, restaurant.name)
        # Weighted average (previous value counts twice); floor division keeps an int.
        prefs.typical_party_size = (
            (prefs.typical_party_size * 2 + experience.party_size) // 3
        )
        if experience.occasion and experience.occasion not in prefs.usual_occasions:
            prefs.usual_occasions.append(experience.occasion)

    def _get_context_changes(self, before: dict, after: dict) -> dict:
        """Describe the differences between two preference snapshots.

        Keys are visited in a stable order (``before``'s keys, then keys only
        present in ``after``). List additions and removals keep the order of
        first appearance. The recorded diff, and the JSON it is saved to, is
        therefore reproducible across runs.

        Returns:
            ``{key: change}``, where change is one of
            ``{'type': 'added', 'value': ...}``,
            ``{'type': 'removed', 'value': ...}``,
            ``{'type': 'list_changed', 'added': [...], 'removed': [...]}`` or
            ``{'type': 'value_changed', 'from': ..., 'to': ...}``.
            Unchanged keys are omitted.
        """
        changes = {}
        keys = list(before) + [key for key in after if key not in before]
        for key in keys:
            if key not in before:
                changes[key] = {'type': 'added', 'value': after[key]}
            elif key not in after:
                changes[key] = {'type': 'removed', 'value': before[key]}
            elif before[key] != after[key]:
                if isinstance(before[key], list):
                    old_items, new_items = set(before[key]), set(after[key])
                    # dict.fromkeys de-duplicates while preserving order.
                    added = [item for item in dict.fromkeys(after[key]) if item not in old_items]
                    removed = [item for item in dict.fromkeys(before[key]) if item not in new_items]
                    # A pure reordering yields no additions/removals and is not reported.
                    if added or removed:
                        changes[key] = {
                            'type': 'list_changed',
                            'added': added,
                            'removed': removed
                        }
                else:
                    changes[key] = {
                        'type': 'value_changed',
                        'from': before[key],
                        'to': after[key]
                    }
        return changes

    def print_context_journey(self):
        """Print, visit by visit, the recorded preference changes and resulting state."""
        print(f"\n🗺️ Context Journey for {self.name}")
        print("=" * 50)
        for i, update in enumerate(self.context_updates, 1):
            print(f"\n📍 Visit #{i}: {update['restaurant']}")
            for key, change in update['changes'].items():
                kind = change['type']
                if kind == 'added':
                    print(f"+ Added {key}: {change['value']}")
                elif kind == 'removed':
                    print(f"- Removed {key}: {change['value']}")
                elif kind == 'list_changed':
                    if change['added']:
                        print(f"+ Added to {key}: {change['added']}")
                    if change['removed']:
                        print(f"- Removed from {key}: {change['removed']}")
                elif kind == 'value_changed':
                    print(f"~ Changed {key}: {change['from']} → {change['to']}")
            print("\nCurrent preferences after visit:")
            visit_prefs = update['after']
            print(f"- Favorite cuisines: {visit_prefs['favorite_cuisines']}")
            print(f"- Favorite restaurants: {visit_prefs['favorite_restaurants']}")
            if visit_prefs['disliked_restaurants']:
                print(f"- Disliked restaurants: {visit_prefs['disliked_restaurants']}")
            print(f"- Typical party size: {visit_prefs['typical_party_size']}")
            print("-" * 50)
# Initialize recommendation agent.
# The model can be overridden with the RESTAURANT_AGENT_MODEL environment
# variable (an unset or empty value falls back to the local Ollama model the
# demo was written for). Results are validated as Restaurant; per-run context
# arrives as RecommendationContext deps.
recommendation_agent = Agent(
    model=os.environ.get("RESTAURANT_AGENT_MODEL") or "ollama:qwen2.5:32b",
    result_type=Restaurant,
    deps_type=RecommendationContext,
    system_prompt="""You are a knowledgeable restaurant recommendation agent.
Analyze customer preferences, dining history, and current context to suggest restaurants they'll love.
Consider dietary restrictions, price preferences, and past experiences."""
)
@recommendation_agent.system_prompt
async def add_recommendation_context(ctx: RunContext[RecommendationContext]) -> str:
    """Build the per-run prompt describing the customer and the current request.

    Registered with ``@recommendation_agent.system_prompt``, so it is called
    with the run's ``RecommendationContext`` deps. Each section is emitted
    only when its data is present. Restaurants the customer previously rated
    poorly are listed explicitly so the model can steer away from them.

    Args:
        ctx: Run context whose ``deps`` holds the customer's recommendation context.

    Returns:
        The assembled prompt text.
    """
    deps = ctx.deps
    prompt = f"""Recommend a restaurant for {deps.customer_name}.
Current request: {deps.current_request}
Customer preferences:"""
    prefs = deps.preferences
    if prefs.favorite_cuisines:
        prompt += f"\n- Favorite cuisines: {', '.join(prefs.favorite_cuisines)}"
    if prefs.dietary_restrictions:
        prompt += f"\n- Dietary restrictions: {', '.join(prefs.dietary_restrictions)}"
    if prefs.preferred_price_range:
        prompt += f"\n- Price range: {', '.join(prefs.preferred_price_range)}"
    if prefs.typical_party_size:
        prompt += f"\n- Usually dines with: {prefs.typical_party_size} people"
    if prefs.favorite_restaurants:
        prompt += f"\n- Favorite restaurants: {', '.join(prefs.favorite_restaurants)}"
    # Disliked restaurants are tracked in the context; without this line the
    # model had no way to know it should not suggest them again.
    if prefs.disliked_restaurants:
        prompt += f"\n- Restaurants to avoid (rated poorly before): {', '.join(prefs.disliked_restaurants)}"
    if prefs.preferred_locations:
        prompt += f"\n- Preferred locations: {', '.join(prefs.preferred_locations)}"
    if deps.dining_history:
        prompt += "\n\nRecent dining experiences:"
        # The three most recent visits, newest first.
        recent_experiences = sorted(deps.dining_history, key=lambda x: x.date, reverse=True)[:3]
        for exp in recent_experiences:
            prompt += f"\n- {exp.restaurant} ({exp.date.strftime('%Y-%m-%d')})"
            prompt += f"\n Rating: {exp.rating}/5"
            if exp.liked_dishes:
                prompt += f"\n Liked: {', '.join(exp.liked_dishes)}"
            if exp.disliked_dishes:
                prompt += f"\n Disliked: {', '.join(exp.disliked_dishes)}"
    if deps.rejected_recommendations:
        prompt += "\n\nPreviously rejected suggestions:"
        for rest in deps.rejected_recommendations:
            prompt += f"\n- {rest.name} ({rest.cuisine}, {rest.price_range})"
        prompt += "\nPlease suggest different restaurants from these."
    current_context = []
    if deps.current_occasion:
        current_context.append(f"Occasion: {deps.current_occasion}")
    if deps.party_size:
        current_context.append(f"Party size: {deps.party_size}")
    if deps.time_of_day:
        current_context.append(f"Time: {deps.time_of_day}")
    if current_context:
        prompt += "\n\nCurrent context:"
        prompt += "\n- " + "\n- ".join(current_context)
    return prompt
def print_context_update(title: str, before: dict, after: dict) -> None:
    """Print the differences between two preference snapshots.

    Output is deterministic. Keys are reported in ``before``'s order, then
    keys that appear only in ``after``. List additions and removals are
    printed as lists in order of first appearance, matching the list output
    of the context-journey report.

    Args:
        title: Heading printed above the change list.
        before: Snapshot before the update (e.g. a ``model_dump()``).
        after: Snapshot after the update.
    """
    print(f"\n📊 {title}")
    print("Changes detected:")
    keys = list(before) + [key for key in after if key not in before]
    for key in keys:
        if key not in before:
            print(f"+ Added {key}: {after[key]}")
        elif key not in after:
            print(f"- Removed {key}: {before[key]}")
        elif before[key] != after[key]:
            if isinstance(before[key], list):
                # Handle list changes; dict.fromkeys de-duplicates preserving order.
                old_items, new_items = set(before[key]), set(after[key])
                added = [item for item in dict.fromkeys(after[key]) if item not in old_items]
                removed = [item for item in dict.fromkeys(before[key]) if item not in new_items]
                if added:
                    print(f"+ Added to {key}: {added}")
                if removed:
                    print(f"- Removed from {key}: {removed}")
            else:
                print(f"~ Changed {key}: {before[key]} → {after[key]}")
# Example restaurant database. Entries are keyed by each restaurant's own
# ``name``, so lookups by the agent's suggested name hit the right record.
RESTAURANTS = {
    entry.name: entry
    for entry in (
        Restaurant(
            name="Bella Italia",
            cuisine="Italian",
            price_range="$$",
            vegetarian_friendly=True,
            gluten_free_options=True,
            average_rating=4.5,
            specialties=["Margherita Pizza", "Eggplant Parmesan", "Homemade Pasta"],
            location="Downtown",
            typical_wait_time="20-30 minutes",
        ),
        Restaurant(
            name="Sushi Master",
            cuisine="Japanese",
            price_range="$$$",
            vegetarian_friendly=True,
            gluten_free_options=True,
            average_rating=4.8,
            specialties=["Omakase", "Dragon Roll", "Fresh Sashimi"],
            location="Marina District",
            typical_wait_time="30-45 minutes",
        ),
        Restaurant(
            name="The Prime Cut",
            cuisine="Steakhouse",
            price_range="$$$$",
            vegetarian_friendly=False,
            gluten_free_options=True,
            average_rating=4.7,
            specialties=["Ribeye Steak", "Wagyu Beef", "Lobster Tail"],
            location="Financial District",
            typical_wait_time="45-60 minutes",
        ),
        Restaurant(
            name="Thai Spice",
            cuisine="Thai",
            price_range="$$",
            vegetarian_friendly=True,
            gluten_free_options=True,
            average_rating=4.4,
            specialties=["Pad Thai", "Green Curry", "Mango Sticky Rice"],
            location="Hayes Valley",
            typical_wait_time="15-20 minutes",
        ),
    )
}
# Customer database simulation: each known customer's starting preferences.
CUSTOMERS = dict(
    Alice=CustomerPreferences(
        favorite_cuisines=["Italian", "Thai"],
        dietary_restrictions=["Vegetarian"],
        preferred_price_range=["$", "$$"],
        typical_party_size=4,
        preferred_locations=["Downtown", "Hayes Valley"],
        usual_occasions=["Family dinner", "Casual dining"],
    ),
    Bob=CustomerPreferences(
        favorite_cuisines=["Japanese", "Steakhouse"],
        dietary_restrictions=["Gluten-free"],
        preferred_price_range=["$$$", "$$$$"],
        typical_party_size=2,
        preferred_locations=["Financial District", "Marina District"],
        usual_occasions=["Business dinner", "Special occasion"],
    ),
)
# Build the live context for every known customer. A previously saved context
# is preferred; otherwise a fresh one is seeded with the starting preferences
# from CUSTOMERS.
CUSTOMER_CONTEXTS = {
    customer: ContextPersistence.load_context(customer) or CustomerContext(customer, starting_prefs)
    for customer, starting_prefs in CUSTOMERS.items()
}
async def update_customer_preferences(
    preferences: CustomerPreferences,
    restaurant: Restaurant,
    experience: DiningExperience
) -> CustomerPreferences:
    """Update *preferences* in place from one dining experience and return it.

    Rules:
        * rating >= 4: the restaurant becomes a favorite, its cuisine a
          favorite cuisine, and it is removed from the disliked list;
        * rating <= 2: the restaurant becomes disliked and is removed from
          the favorites (the two lists never contradict each other);
        * typical party size becomes a 2:1 weighted integer average;
        * a new occasion is appended to ``usual_occasions``.

    The resulting changes are printed via ``print_context_update``.
    """
    # Store original state for comparison
    original_state = preferences.model_dump()
    if experience.rating >= 4:
        if restaurant.name not in preferences.favorite_restaurants:
            preferences.favorite_restaurants.append(restaurant.name)
        if restaurant.cuisine not in preferences.favorite_cuisines:
            preferences.favorite_cuisines.append(restaurant.cuisine)
        # Slice-assign so the model keeps the same list object.
        preferences.disliked_restaurants[:] = [
            name for name in preferences.disliked_restaurants if name != restaurant.name
        ]
    elif experience.rating <= 2:
        if restaurant.name not in preferences.disliked_restaurants:
            preferences.disliked_restaurants.append(restaurant.name)
        preferences.favorite_restaurants[:] = [
            name for name in preferences.favorite_restaurants if name != restaurant.name
        ]
    # Update typical party size (weighted average; floor division keeps an int)
    preferences.typical_party_size = (
        (preferences.typical_party_size * 2 + experience.party_size) // 3
    )
    if experience.occasion and experience.occasion not in preferences.usual_occasions:
        preferences.usual_occasions.append(experience.occasion)
    print_context_update(
        f"Context Updates for {experience.restaurant}",
        original_state,
        preferences.model_dump()
    )
    return preferences
async def get_restaurant_recommendation(
    customer_name: str,
    request: str,
    occasion: Optional[str] = None,
    party_size: Optional[int] = None,
    time_of_day: Optional[str] = None
) -> Optional[Restaurant]:
    """Get a personalized restaurant recommendation for a customer.

    The agent receives the customer's *current* context: the preferences held
    in ``CUSTOMER_CONTEXTS`` (loaded from disk and updated after every visit),
    the visit history and the previous recommendations. The static starting
    profile in ``CUSTOMERS`` is only a fallback for customers without a
    context object. Without this, learned preferences and past visits would
    never reach the agent.

    The agent's suggestion is mapped onto the local ``RESTAURANTS`` table:
    first by exact name, otherwise to the first entry with the same cuisine
    and price range.

    Args:
        customer_name: Key into ``CUSTOMER_CONTEXTS`` / ``CUSTOMERS``.
        request: Free-text request forwarded to the agent.
        occasion: Optional occasion for this visit.
        party_size: Optional party size for this visit.
        time_of_day: Optional time-of-day hint.

    Returns:
        The matching ``Restaurant`` from ``RESTAURANTS``, or ``None`` when the
        customer is unknown, the agent call fails, or nothing in the table
        matches. Errors are printed rather than raised.
    """
    try:
        customer_ctx = CUSTOMER_CONTEXTS.get(customer_name)
        if customer_ctx is not None:
            preferences = customer_ctx.preferences
            dining_history = list(customer_ctx.visit_history)
            previous_recommendations = list(customer_ctx.recommendation_history)
        else:
            preferences = CUSTOMERS.get(customer_name)
            dining_history = []
            previous_recommendations = []
        if preferences is None:
            raise ValueError(f"Customer {customer_name} not found")
        # Show initial context
        print(f"\n📋 Initial Context for {customer_name}:")
        print(f"Favorite cuisines: {preferences.favorite_cuisines}")
        print(f"Dietary restrictions: {preferences.dietary_restrictions}")
        print(f"Price range: {preferences.preferred_price_range}")
        print(f"Typical party size: {preferences.typical_party_size}")
        if preferences.favorite_restaurants:
            print(f"Favorite restaurants: {preferences.favorite_restaurants}")
        if preferences.disliked_restaurants:
            print(f"Disliked restaurants: {preferences.disliked_restaurants}")
        context = RecommendationContext(
            customer_name=customer_name,
            preferences=preferences,
            dining_history=dining_history,
            current_request=request,
            previous_recommendations=previous_recommendations,
            current_occasion=occasion,
            party_size=party_size,
            time_of_day=time_of_day
        )
        # Show current request context
        print(f"\n🔍 Request Context:")
        print(f"Request: {request}")
        if occasion:
            print(f"Occasion: {occasion}")
        if party_size:
            print(f"Party size: {party_size}")
        if time_of_day:
            print(f"Time: {time_of_day}")
        print(f"\n🤔 [Recommendation Agent] Finding the perfect restaurant for {customer_name}...")
        result = await recommendation_agent.run(
            user_prompt=request,
            deps=context
        )
        suggestion = result.data
        # For demo purposes, map the recommendation to our restaurant database
        recommended = RESTAURANTS.get(suggestion.name)
        if recommended is None:
            recommended = next(
                (
                    restaurant for restaurant in RESTAURANTS.values()
                    if restaurant.cuisine == suggestion.cuisine
                    and restaurant.price_range == suggestion.price_range
                ),
                None,
            )
        if recommended is None:
            # Previously this case returned None without any explanation.
            print(
                f"\n⚠️ No restaurant in the local database matches the suggestion "
                f"{suggestion.name} ({suggestion.cuisine}, {suggestion.price_range})"
            )
            return None
        print(f"\n✨ Found recommendation: {recommended.name}")
        print(f"Cuisine: {recommended.cuisine}")
        print(f"Price Range: {recommended.price_range}")
        print(f"Location: {recommended.location}")
        print(f"Specialties: {', '.join(recommended.specialties)}")
        print(f"Average Rating: {recommended.average_rating}/5")
        print(f"Wait Time: {recommended.typical_wait_time}")
        dietary_info = []
        if recommended.vegetarian_friendly:
            dietary_info.append("Vegetarian-friendly")
        if recommended.gluten_free_options:
            dietary_info.append("Gluten-free options")
        if dietary_info:
            print(f"Dietary Options: {', '.join(dietary_info)}")
        return recommended
    except Exception as e:
        # Deliberate demo boundary: report and continue with the next scenario.
        print(f"\n❌ Error getting recommendation: {str(e)}")
        return None
async def simulate_dining_experience(
    customer_name: str,
    restaurant: Restaurant,
    rating: int,
    liked_dishes: List[str],
    disliked_dishes: List[str],
    comments: str,
    party_size: int,
    occasion: Optional[str] = None
) -> None:
    """Record a simulated visit for *customer_name* and report the outcome.

    Builds a ``DiningExperience`` time-stamped now and hands it to the
    customer's ``CustomerContext.add_visit``, which updates and saves the
    context. It then prints a short summary. Any error is printed instead of
    raised.
    """
    try:
        visit = DiningExperience(
            restaurant=restaurant.name,
            date=datetime.now(),
            rating=rating,
            liked_dishes=liked_dishes,
            disliked_dishes=disliked_dishes,
            comments=comments,
            party_size=party_size,
            occasion=occasion,
        )
        CUSTOMER_CONTEXTS[customer_name].add_visit(visit, restaurant)
        print(f"\n✍️ Updated {customer_name}'s preferences based on experience at {restaurant.name}")
        if rating >= 4:
            verdict = "Added to favorite restaurants!"
        elif rating <= 2:
            verdict = "Added to disliked restaurants."
        else:
            verdict = None
        if verdict is not None:
            print(verdict)
    except Exception as e:
        print(f"\n❌ Error updating preferences: {str(e)}")
async def main():
    """Run the two-customer demo of context-aware recommendations.

    Saved context files are deliberately left in place, so each run builds on
    whatever earlier runs stored. For each scenario the demo asks for a
    recommendation and, if one is found, records a simulated visit. It then
    prints both customers' context journeys and a summary.
    """
    print("\n=== Restaurant Recommendation System with Context ===\n")
    # Don't clean up context files this time - we want to load existing context
    scenarios = [
        {  # Scenario 1: Alice looking for lunch
            "banner": "\n👩 Alice's New Request:",
            "blurb": "Looking for lunch after previous visits to Bella Italia",
            "request": dict(
                customer_name="Alice",
                request="Want a light lunch, something different from Italian",
                occasion="Lunch",
                party_size=2,
                time_of_day="Afternoon",
            ),
            "experience": dict(
                rating=4,
                liked_dishes=["Green Curry", "Spring Rolls"],
                disliked_dishes=[],
                comments="Nice change from Italian, loved the vegetarian options!",
                party_size=2,
                occasion="Lunch",
            ),
        },
        {  # Scenario 2: Bob trying Japanese
            "banner": "\n👨 Bob's New Request:",
            "blurb": "Looking for Japanese after previous steakhouse experience",
            "request": dict(
                customer_name="Bob",
                request="Want to try some high-end Japanese food",
                occasion="Date night",
                party_size=2,
                time_of_day="Evening",
            ),
            "experience": dict(
                rating=5,
                liked_dishes=["Omakase", "Fresh Sashimi"],
                disliked_dishes=[],
                comments="Amazing Japanese food, perfect for date night",
                party_size=2,
                occasion="Date night",
            ),
        },
    ]
    for position, scenario in enumerate(scenarios):
        if position > 0:
            # Separator between scenarios (not after the last one).
            print("\n" + "="*50)
        print(scenario["banner"])
        print(scenario["blurb"])
        chosen = await get_restaurant_recommendation(**scenario["request"])
        if chosen:
            await simulate_dining_experience(
                customer_name=scenario["request"]["customer_name"],
                restaurant=chosen,
                **scenario["experience"],
            )
    # Print context journeys showing all visits including previous ones
    print("\n=== Complete Context Journey ===")
    for position, customer in enumerate(("Alice", "Bob")):
        if position > 0:
            print("\n")
        CUSTOMER_CONTEXTS[customer].print_context_journey()
    print("\n" + "="*50 + "\n")
    summary_lines = (
        "Demo completed! Showed how context persists and improves over time:",
        "1. Loaded previous dining experiences",
        "2. Made recommendations considering past visits",
        "3. Updated preferences with new experiences",
        "4. Maintained complete dining history",
        "5. Evolved cuisine preferences",
    )
    for line in summary_lines:
        print(line)
    # Show where context files are saved
    print("\n📁 Context files are saved in:", CONTEXT_DIR.absolute())
    print("You can find the following files:")
    for context_file in CONTEXT_DIR.glob("*_context.json"):
        print(f"- {context_file.name}")
# Script entry point: run the async demo on a fresh event loop (skipped on import).
if __name__ == "__main__":
    demo = main()
    asyncio.run(demo)
# Comments on this gist are hosted on GitHub (sign in there to join the discussion).