@cannin
Created August 16, 2024 15:37
DSPy example with data from CSV and the MIPROv2 optimizer; also, simple viewer code to inspect DSPy prompt history; dspy-ai = "2.4.13"
import json
import dspy
from dspy.teleprompt import BootstrapFewShot, LabeledFewShot
from dspy.evaluate import Evaluate, answer_exact_match
from dspy.datasets import DataLoader
# PURPOSE ----
# Example with a dataset read from CSV, using an LLM judge to evaluate the answer
class BasicQASignature(dspy.Signature):
    """Answer questions with short answers."""

    question = dspy.InputField(desc="Question to be answered")
    #context = dspy.InputField(desc="Context for the prediction")
    answer = dspy.OutputField(desc="Answer with 1 to 5 words")
class BasicQA(dspy.Module):
    def __init__(self):
        super().__init__()
        self.generate_answer = dspy.Predict(BasicQASignature)

    def forward(self, question):
        return self.generate_answer(question=question)
class FactLangJudge(dspy.Signature):
    """Judge if the answer is in Spanish and correct."""

    context = dspy.InputField(desc="Context for the prediction")
    question = dspy.InputField(desc="Question to be answered")
    answer = dspy.InputField(desc="Answer for the question")
    #judge_answer = dspy.OutputField(desc="Is the answer correct and in Spanish?", prefix="Factual[Yes/No]:")
    judge_answer = dspy.OutputField(desc="Is the answer correct and in Spanish? Answer with only Yes or No")
def llm_metric(example, pred, trace=True):
    judge = dspy.ChainOfThought(FactLangJudge)
    output = judge(context=example.context, question=example.question, answer=pred.answer)

    # judge_answer comes from the FactLangJudge signature
    output_str = output.judge_answer.lower()
    print(output_str)
    #print("yes" in output_str)

    return "yes" in output_str
    #return int(output.judge_answer.lower() == "yes")
    #return int(output == "Yes")
    #return True
# LOAD DATA ----
print("LOAD DATA ----")
#filename = "dolly_subset_100_rows_question.csv",
filename = "us_state_capitals_3.csv"
dl = DataLoader()
dataset = dl.from_csv(
    filename,
    fields=["question", "context", "answer"],
    input_keys=["question"]
)
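# Hedged illustration of the expected CSV layout (columns inferred from `fields` above;
# the sample row is made up and not taken from us_state_capitals_3.csv):
#
#   question,context,answer
#   "What is the capital of Texas?","Texas is a state in the south-central United States.","Austin"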
splits = dl.train_test_split(dataset=dataset, train_size=0.8)
train_dataset = splits['train']
test_dataset = splits['test']
# Set up the LM
print("SETUP LLM ----")
model = 'gemma2:2b'
#model = 'qwen2:7b'
lm = dspy.OllamaLocal(model=model)
dspy.settings.configure(lm=lm)
print("SETUP OPTIMIZATION ----")
# Set up the optimizer: we want to "bootstrap" (i.e., self-generate) 4-shot examples of our program.
#config = dict(max_bootstrapped_demos=4, max_labeled_demos=4)
print("SETUP METRIC ----")
# Use the `answer_exact_match` here. In general, the metric is going to tell the optimizer how well it's doing.
#metric = answer_exact_match
metric = llm_metric
# Optimize!
#teleprompter = BootstrapFewShot(metric=answer_exact_match, max_bootstrapped_demos=4, max_labeled_demos=4)
teleprompter = BootstrapFewShot(metric=metric, max_bootstrapped_demos=4, max_labeled_demos=4)
# Skip LabeledFewShot: https://dspy-docs.vercel.app/docs/building-blocks/optimizers
#teleprompter = LabeledFewShot(k=3)
lm.inspect_history(n=1)
# Ask questions
print("ASK QUESTIONS ----")
print("dspy.Predict")
predictor_basic = dspy.Predict("question -> answer")
output = predictor_basic(question="What colors are on the American flag?")
print(output)
print("dspy.Predict(BasicQASignature)")
predictor_sig = dspy.Predict(BasicQASignature)
output = predictor_sig(question="What colors are on the American flag?")
print(output)
print("BasicQA()")
predictor_module = BasicQA()
output = predictor_module(question="What colors are on the American flag?")
print(output)
print("COMPILE PROGRAM ----")
optimized_program = teleprompter.compile(predictor_module, trainset=train_dataset)
print("EVALUATE ----")
evaluator = Evaluate(
    devset=train_dataset,
    metric=metric,
    num_threads=1,
    display_progress=True,
    display_table=5,
    return_outputs=True)

score, outputs = evaluator(optimized_program)
print("INSPECT ----")
lm.inspect_history(n=1)
print("USE ----")
print(optimized_program(question='What is the capital of France?'))
print("SAVE ----")
optimized_program.save("v1.json")
# outputs[0][0].toDict()
# Iterate through the list of (Example, prediction, score) tuples, converting each Example to a dict with toDict()
serializable_outputs = []
for obj in outputs:
    serializable_outputs.append(obj[0].toDict())

# Write serializable_outputs to a file
with open("outputs.json", "w") as f:
    f.write(json.dumps(serializable_outputs, indent=2))
# LOAD PROGRAM ----
#loaded_program = YOUR_PROGRAM_CLASS()
#loaded_program.load(path=YOUR_SAVE_PATH)
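# A minimal loading sketch, assuming the BasicQA module and the "v1.json" file saved above:
loaded_program = BasicQA()
loaded_program.load(path="v1.json")
print(loaded_program(question='What is the capital of France?'))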
import dspy
import dsp
from dspy.datasets.gsm8k import GSM8K, gsm8k_metric
from dspy.teleprompt import MIPRO, MIPROv2
import pickle
# PURPOSE ----
# Example with a built-in dataset (GSM8K) using the more advanced MIPROv2 optimizer
class CoT(dspy.Module):
    def __init__(self):
        super().__init__()
        self.prog = dspy.ChainOfThought("question -> answer")

    def forward(self, question):
        return self.prog(question=question)
print("SETUP LLM ----")
model = 'gemma2:2b'
#model = 'llama3.1:70b-instruct-q4_0'
lm = dspy.OllamaLocal(model=model)
dspy.settings.configure(lm=lm)
gsm8k = GSM8K()
gsm8k_trainset, gsm8k_devset = gsm8k.train[:20], gsm8k.dev[:20]
# MIPRO
# teleprompter_optimizer = MIPRO(metric=gsm8k_metric, num_candidates=4)
# kwargs = dict(num_threads=1, display_progress=True, display_table=0)
# optimized_cot = teleprompter_optimizer.compile(
#     CoT(),
#     trainset=gsm8k_trainset,
#     num_trials=3,
#     max_bootstrapped_demos=3,
#     max_labeled_demos=5,
#     eval_kwargs=kwargs,
#     requires_permission_to_run=False)
# MIPROv2
teleprompter_optimizer = MIPROv2(metric=gsm8k_metric, num_candidates=4, prompt_model=lm, task_model=lm)
kwargs = dict(lm=lm, num_threads=1, display_progress=True, display_table=0)

optimized_cot = teleprompter_optimizer.compile(
    CoT(),
    trainset=gsm8k_trainset,
    max_bootstrapped_demos=3,
    max_labeled_demos=5,
    eval_kwargs=kwargs,
    requires_permission_to_run=False)
print(optimized_cot)
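# Hedged usage sketch: the compiled program is called like any other dspy.Module and can be
# saved the same way as the BasicQA example above (the question and file name are illustrative).
print(optimized_cot(question="If a train travels 60 miles per hour for 3 hours, how far does it travel?"))
optimized_cot.save("optimized_cot_v1.json")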
# SAVE TO PICKLE ----
with open('history.pkl', 'wb') as f:
    pickle.dump(lm.history, f)

# with open('history.pkl', 'rb') as f:
#     lm_history = pickle.load(f)
import os
import json
import pickle
from git.repo import Repo
# PURPOSE ----
# Inspect dspy prompt history using git commits and https://github.com/pomber/git-history/tree/master/cli
# Function to split long lines into smaller ones
def split_long_lines(text, max_length=100):
    lines = text.splitlines()  # Split the string into lines
    result = []
    for line in lines:
        while len(line) > max_length:
            # Find the last space before the max_length
            break_point = line.rfind(' ', 0, max_length)
            if break_point == -1:  # No spaces found, force break
                break_point = max_length
            result.append(line[:break_point])
            line = line[break_point:].lstrip()
        result.append(line)
    return result
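# Hedged usage example: the input string below is made up, purely to illustrate the wrapping behavior
wrapped = split_long_lines("lorem ipsum dolor sit amet " * 10, max_length=100)
print('\n'.join(wrapped))  # each output line is at most ~100 characters, broken at spaces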
repo_path = 'git_tmp'
repo = Repo(repo_path)
with open('history.pkl', 'rb') as f:
    lm_history = pickle.load(f)
for i in range(len(lm_history)):
    print("GET PROMPT ----")
    tmp = lm_history[i]['prompt']
    tmp = split_long_lines(tmp, max_length=100)
    final_str = '\n'.join(tmp)

    file = 'prompt.json'
    file_path = os.path.join(repo_path, file)

    #with open(file_path, 'w') as f:
    #    json.dump(tmp, f, indent=2)
    with open(file_path, 'w') as f:
        f.write(final_str)

    print("COMMIT PROMPT ----")
    repo.index.add([file])
    commit_msg = f"Prompt: {i}"
    repo.index.commit(commit_msg)