Created
January 27, 2025 21:31
-
-
Save dbreunig/74a586f359d8e5111c559d6b7003903a to your computer and use it in GitHub Desktop.
Using DSPy to optimize the prompt for the extraction task described in this post: https://thescoop.org/archives/2025/01/27/llm-extraction-challenge-fundraising-emails/index.html
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import dspy | |
from dspy.evaluate import Evaluate | |
from dspy.teleprompt import * | |
import csv | |
# Load the emails
# newline='' hands line-ending handling to the csv module, so quoted fields
# that contain embedded newlines (likely for email bodies) are parsed intact.
with open('training.csv', 'r', newline='') as file:
    reader = csv.DictReader(file)
    # Materialize every row as a dict while the file is still open.
    emails = list(reader)
# Define our DSPy Signature
# NOTE: the class docstring and the field descriptions are passed to DSPy at
# runtime (they are prompt text, not ordinary documentation), so their wording
# is kept verbatim.
class WhoPaid(dspy.Signature):
    """Return the committee which paid for the email."""

    # Input field: the text of one campaign email.
    email: str = dspy.InputField(
        description="The campaign email body to analyze.",
    )
    # Output field: the committee name the model is asked to produce.
    committee: str = dspy.OutputField(
        description="The committee which paid for the email.",
    )


# Wrap the signature in a basic predictor; this is the program that gets
# evaluated and then optimized further down in the script.
who_paid = dspy.Predict(WhoPaid)
# Define our validation function
def validate_paid(example, prediction, trace=None):
    """Return whether the predicted committee equals the labeled committee.

    Args:
        example: Labeled record, indexed by ``'committee'``.
        prediction: Model output, indexed by ``'committee'``.
        trace: Accepted for the metric call signature; not used here.

    Returns:
        The result of ``==`` between the two values. The comparison is
        strict: no case or whitespace normalization is applied.
    """
    expected = example['committee']
    predicted = prediction['committee']
    return expected == predicted
# Import our trainset
# Only the first 100 rows are used for optimization, so slice the rows first
# instead of building an Example for every row and discarding most of them.
trainset = [
    # "email" is marked as the input; "committee" stays as the label.
    dspy.Example(email=email['body'], committee=email['committee']).with_inputs("email")
    for email in emails[:100]
]
# Evaluate our out-of-the-box function
# Score the unoptimized predictor on the current trainset, using a llama3.2
# model served from a local OpenAI-compatible endpoint for this block only.
with dspy.context(
    lm=dspy.LM(
        "openai/llama3.2:latest",
        api_base='http://localhost:11434/v1',
        api_key='ollama',
    )
):
    evaluator = Evaluate(
        devset=trainset,
        num_threads=1,
        display_progress=True,
        display_table=5,
    )
    evaluator(who_paid, metric=validate_paid)
# Load our task and prompt gen models
# Both models are reached through the same local OpenAI-compatible endpoint;
# they differ only in the model name.
lm = dspy.LM(
    "openai/llama3.2:latest",
    api_base='http://localhost:11434/v1',
    api_key='ollama',
)
prompt_gen_lm = dspy.LM(
    "openai/llama3.3:latest",
    api_base='http://localhost:11434/v1',
    api_key='ollama',
)
# Make the task model the default LM for subsequent DSPy calls.
dspy.configure(lm=lm)
# Optimize
# The optimizer scores candidates with validate_paid; prompt_gen_lm is passed
# as the prompt model and lm as the task model.
tp = dspy.MIPROv2(
    metric=validate_paid,
    auto="light",
    prompt_model=prompt_gen_lm,
    task_model=lm,
)
# Both demo counts are zero, i.e. no labeled or bootstrapped demos are
# requested (presumably so that only the instructions are tuned).
optimized_who_paid = tp.compile(
    who_paid,
    trainset=trainset,
    max_labeled_demos=0,
    max_bootstrapped_demos=0,
)
# Save our optimization with:
optimized_who_paid.save("optimized_paid_for.json")
# Reload the trainset with all 1k records
# NOTE(review): this full set also contains the rows used for optimization
# earlier in the script; confirm that scoring on them is intended.
trainset = [
    dspy.Example(email=row['body'], committee=row['committee']).with_inputs("email")
    for row in emails
]
# Evaluate our optimized function
# Score the optimized predictor on the current trainset with the same metric,
# using a llama3.2 model from the local endpoint for this block only.
with dspy.context(
    lm=dspy.LM(
        "openai/llama3.2:latest",
        api_base='http://localhost:11434/v1',
        api_key='ollama',
    )
):
    evaluator = Evaluate(
        devset=trainset,
        num_threads=1,
        display_progress=True,
        display_table=5,
    )
    evaluator(optimized_who_paid, metric=validate_paid)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment