|
#!/usr/bin/env python3 |
|
""" |
|
DSPy Grug Translator: end-to-end example from LearnByBuilding.ai tutorial |
|
Complete implementation of translating text to "Grug speak" using DSPy framework |
|
Updated with CLI interface and modular utilities |
|
""" |
|
|
|
import argparse
import logging
import os
import random
import re
import time
from functools import wraps
from random import shuffle
from typing import Optional

import dspy
import requests
from bs4 import BeautifulSoup
from dotenv import load_dotenv

# Import utilities from our utils module
from utils_dspy_grug import (
    translate_grug,
    overall_metric,
    similarity_metric,
    ari_metric
)
|
|
|
# Load environment variables (e.g. API keys such as OPENAI_API_KEY) from a
# local .env file, if present, before any provider is configured.
load_dotenv()

# Setup logging: timestamped INFO-level messages for the whole script.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

logger = logging.getLogger(__name__)
|
|
|
# Retry decorator |
|
# Retry decorator
def retry(max_attempts=3, delay=1):
    """Return a decorator that retries the wrapped function on any exception.

    Args:
        max_attempts: Total number of attempts before the last exception
            is re-raised.
        delay: Seconds to sleep between attempts.
    """
    def decorator(func):
        last = max_attempts - 1

        @wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(max_attempts):
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    # On the final attempt, propagate instead of retrying.
                    if attempt == last:
                        raise
                    logger.warning(f"Attempt {attempt + 1} failed: {e}. Retrying in {delay} seconds...")
                    time.sleep(delay)
            return None  # unreachable; kept for clarity

        return wrapper

    return decorator
|
|
|
# Function to split data for train/test |
|
# Function to split data for train/test
def split_data(data, train_ratio=0.8, seed=42):
    """Deterministically shuffle *data* and split it into train/test sets.

    Args:
        data: Sequence of examples; it is copied, not mutated.
        train_ratio: Fraction of examples that go into the training set.
        seed: RNG seed so the split is reproducible.

    Returns:
        Tuple of (train_data, test_data) lists.
    """
    random.seed(seed)
    pool = data.copy()
    shuffle(pool)

    cut = int(len(pool) * train_ratio)
    return pool[:cut], pool[cut:]
|
|
|
@retry(max_attempts=3)
def fetch_reddit_posts(num_posts=50):
    """Fetch sentence snippets from hot r/explainlikeimfive post titles.

    Titles shorter than 21 characters or starting with '[' (tag posts) are
    skipped; the remaining titles are split into sentences, and sentences
    longer than 10 characters are collected.

    Args:
        num_posts: Maximum number of sentences to collect (also caps how
            many posts are scanned).

    Returns:
        List of sentence strings, at most *num_posts* long.

    Raises:
        requests.HTTPError: If the Reddit request fails (after retries).
    """
    logger.info(f"Fetching {num_posts} Reddit posts...")

    url = "https://www.reddit.com/r/explainlikeimfive/hot/.json"
    headers = {'User-Agent': 'DSPy Grug Translator 1.0'}

    # Timeout added: requests.get without one can block forever on a
    # stalled connection, defeating the retry decorator.
    response = requests.get(url, headers=headers, timeout=30)
    response.raise_for_status()

    data = response.json()
    posts = []

    for post_data in data['data']['children'][:num_posts]:
        title = post_data['data']['title']

        # Skip empty/short titles and tagged posts like "[Serious] ...".
        if not title or len(title) <= 20 or title.startswith('['):
            continue

        # Split by sentence-ending punctuation or newlines.
        # (Bug fix: `re` was used here without being imported at module level.)
        sentences = [s.strip() for s in re.split(r'[.!?\n]', title) if s.strip()]

        for sentence in sentences:
            if len(sentence) > 10:
                posts.append(sentence)
            if len(posts) >= num_posts:
                break

        if len(posts) >= num_posts:
            break

    logger.info(f"Successfully fetched {len(posts)} posts")
    return posts
|
|
|
def configure_dspy(provider, model_name=None, eval_provider=None, eval_model_name=None):
    """Configure DSPy's global LM and optionally build a separate eval LM.

    Args:
        provider: Main provider name, "openai" or "ollama".
        model_name: Ollama model for the main LM; defaults to "llama3.2".
        eval_provider: Optional separate provider for evaluation.
        eval_model_name: Ollama model for evaluation; defaults to "llama3.2".

    Returns:
        Tuple of (lm, eval_lm); eval_lm is None when no usable separate
        evaluation provider was given.

    Raises:
        ValueError: If *provider* is not a supported name.
    """
    logger.info(f"Configuring DSPy with provider: {provider}")

    def _build(which, name):
        # Build an LM for a supported provider; None for anything else.
        if which == "openai":
            return dspy.OpenAI(model="gpt-3.5-turbo")
        if which == "ollama":
            return dspy.Ollama(model=name or "llama3.2")
        return None

    lm = _build(provider, model_name)
    if lm is None:
        raise ValueError(f"Unsupported provider: {provider}")

    eval_lm = None
    if eval_provider:
        logger.info(f"Configuring evaluation provider: {eval_provider}")
        eval_lm = _build(eval_provider, eval_model_name)
        if eval_lm is None:
            # Unknown eval provider: warn and fall back to the main LM.
            logger.warning(f"Unsupported eval provider: {eval_provider}, using main provider")

    dspy.configure(lm=lm)
    return lm, eval_lm
|
|
|
@retry(max_attempts=3)
def run_optimization(train_data, test_data, num_threads=1):
    """
    Run DSPy optimization on the training data

    Args:
        train_data: list of dspy.Example items used to bootstrap demos.
        test_data: accepted but not used here; evaluation is performed
            separately by evaluate_program().
        num_threads: forwarded to the teleprompter constructor.

    Returns:
        The compiled (optimized) program produced by the teleprompter.
    """
    logger.info("Starting optimization process...")

    # Create optimizer
    # NOTE(review): BootstrapFewShot in recent dspy releases does not take a
    # num_threads constructor argument (that belongs to
    # BootstrapFewShotWithRandomSearch) — confirm against the installed
    # dspy version; if it raises TypeError, drop num_threads here.
    teleprompter = dspy.BootstrapFewShot(
        metric=overall_metric,
        max_bootstrapped_demos=4,
        max_labeled_demos=16,
        num_threads=num_threads
    )

    # Optimize
    # NOTE(review): translate_grug comes from utils_dspy_grug; compile()
    # expects a dspy module/program as the student — presumably it is one.
    optimized_program = teleprompter.compile(
        translate_grug,
        trainset=train_data
    )

    logger.info("Optimization completed")
    return optimized_program
|
|
|
@retry(max_attempts=3)
def evaluate_program(program, test_data, eval_lm=None):
    """
    Evaluate the program on test data.

    Args:
        program: Compiled DSPy program to score.
        test_data: List of dspy.Example items used as the dev set.
        eval_lm: Optional LM to use just for this evaluation; whatever LM
            was configured before is restored afterwards.

    Returns:
        The score reported by dspy.Evaluate.
    """
    logger.info("Starting evaluation...")

    # If we have a separate evaluation LM, configure it temporarily.
    # Track the swap with a flag: the original code restored only when the
    # saved LM was truthy, so if no LM had been configured yet, eval_lm
    # would silently remain the global LM after this call.
    swapped = False
    original_lm = None
    if eval_lm:
        original_lm = dspy.settings.lm
        dspy.configure(lm=eval_lm)
        swapped = True

    try:
        evaluator = dspy.Evaluate(
            devset=test_data,
            metric=overall_metric,
            num_threads=1,
            display_progress=True
        )

        result = evaluator(program)
        logger.info(f"Evaluation completed with score: {result}")
        return result

    finally:
        # Restore the previous LM whenever we swapped, even if it was None.
        if swapped:
            dspy.configure(lm=original_lm)
|
|
|
def main():
    """CLI entry point: parse arguments, configure DSPy, then either run an
    interactive translation loop or the full fetch/split/optimize/evaluate
    pipeline."""
    parser = argparse.ArgumentParser(description="DSPy Grug Translator with CLI interface")

    # Provider configuration
    parser.add_argument("--provider", choices=["openai", "ollama"], default="openai",
                        help="LLM provider to use (default: openai)")
    parser.add_argument("--ollama-model", default="llama3.2",
                        help="Ollama model name (default: llama3.2)")

    # Evaluation provider configuration
    parser.add_argument("--eval-provider", choices=["openai", "ollama"],
                        help="Separate LLM provider for evaluation (optional)")
    parser.add_argument("--eval-ollama-model", default="llama3.2",
                        help="Ollama model for evaluation (default: llama3.2)")

    # Data and processing options
    parser.add_argument("--num-posts", type=int, default=100,
                        help="Number of Reddit posts to fetch (default: 100)")
    parser.add_argument("--train-ratio", type=float, default=0.8,
                        help="Ratio of data to use for training (default: 0.8)")
    parser.add_argument("--num-threads", type=int, default=1,
                        help="Number of threads for optimization (default: 1)")

    # Random seed
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed for reproducibility (default: 42)")

    # Skip optimization flag
    parser.add_argument("--skip-optimization", action="store_true",
                        help="Skip the optimization step")

    # Interactive mode
    parser.add_argument("--interactive", action="store_true",
                        help="Run in interactive mode for manual testing")

    args = parser.parse_args()

    # Set random seed (split_data re-seeds with the same value later, so
    # this mainly covers any other random use before the split)
    random.seed(args.seed)

    try:
        # Configure DSPy
        # NOTE(review): lm is not used again below — configure_dspy already
        # registers it globally via dspy.configure; eval_lm is passed on to
        # evaluate_program.
        lm, eval_lm = configure_dspy(
            provider=args.provider,
            model_name=args.ollama_model if args.provider == "ollama" else None,
            eval_provider=args.eval_provider,
            eval_model_name=args.eval_ollama_model if args.eval_provider == "ollama" else None
        )

        if args.interactive:
            logger.info("Starting interactive mode...")
            print("\nInteractive Grug Translator Mode")
            print("Type 'quit' to exit\n")

            # Read lines until the user types 'quit'; each non-empty line is
            # translated with the (un-optimized) base translator.
            while True:
                text = input("Enter text to translate to Grug speak: ").strip()
                if text.lower() == 'quit':
                    break

                if text:
                    try:
                        result = translate_grug(text=text)
                        print(f"Grug speak: {result.translation}\n")
                    except Exception as e:
                        logger.error(f"Translation failed: {e}")
            # Interactive mode skips the batch pipeline entirely.
            return

        # Fetch data
        reddit_posts = fetch_reddit_posts(args.num_posts)

        if len(reddit_posts) < 10:
            logger.error("Not enough data fetched. Exiting.")
            return

        # Create training examples
        examples = [dspy.Example(text=post).with_inputs('text') for post in reddit_posts]

        # Split data
        train_data, test_data = split_data(examples, args.train_ratio, args.seed)
        logger.info(f"Split data: {len(train_data)} training, {len(test_data)} test examples")

        if not args.skip_optimization:
            # Run optimization
            optimized_program = run_optimization(train_data, test_data, args.num_threads)
        else:
            logger.info("Skipping optimization, using base program")
            optimized_program = translate_grug

        # Evaluate (only when the split left any test examples)
        if test_data:
            evaluation_score = evaluate_program(optimized_program, test_data, eval_lm)
            logger.info(f"Final evaluation score: {evaluation_score}")

        # Test with a sample
        sample_text = "Why do we need to sleep every night?"
        logger.info(f"Testing with sample: '{sample_text}'")

        result = optimized_program(text=sample_text)
        print(f"\nOriginal: {sample_text}")
        print(f"Grug speak: {result.translation}")

        logger.info("Process completed successfully!")

    except KeyboardInterrupt:
        logger.info("Process interrupted by user")
    except Exception as e:
        # Log and re-raise so the process exits non-zero on failure.
        logger.error(f"An error occurred: {e}")
        raise
|
|
|
# Run the CLI only when executed as a script, not on import.
if __name__ == "__main__":
    main()