DSPy example with data from CSV and the MIPROv2 optimizer; also, simple viewer code to inspect DSPy prompt history; dspy-ai = "2.4.13"
import json

import dspy
from dspy.teleprompt import BootstrapFewShot, LabeledFewShot
from dspy.evaluate import Evaluate, answer_exact_match
from dspy.datasets import DataLoader

# PURPOSE ----
# Example with a dataset read from CSV, using an LLM judge to evaluate the answer
class BasicQASignature(dspy.Signature):
    """Answer questions with short answers."""

    question = dspy.InputField(desc="Question to be answered")
    #context = dspy.InputField(desc="Context for the prediction")
    answer = dspy.OutputField(desc="Answer with 1 to 5 words")


class BasicQA(dspy.Module):
    def __init__(self):
        super().__init__()
        self.generate_answer = dspy.Predict(BasicQASignature)

    def forward(self, question):
        return self.generate_answer(question=question)
class FactLangJudge(dspy.Signature):
    """Judge if the answer is in Spanish and correct."""

    context = dspy.InputField(desc="Context for the prediction")
    question = dspy.InputField(desc="Question to be answered")
    answer = dspy.InputField(desc="Answer for the question")
    #judge_answer = dspy.OutputField(desc="Is the answer correct and in Spanish?", prefix="Factual[Yes/No]:")
    judge_answer = dspy.OutputField(desc="Is the answer correct and in Spanish? Answer with only Yes or No")
def llm_metric(example, pred, trace=None):
    # DSPy calls metrics with trace=None during evaluation and a non-None
    # trace while bootstrapping demos; a boolean return works for both.
    judge = dspy.ChainOfThought(FactLangJudge)
    output = judge(context=example.context, question=example.question, answer=pred.answer)
    # judge_answer comes from the FactLangJudge signature
    output_str = output.judge_answer.lower()
    print(output_str)
    return "yes" in output_str
    #return int(output.judge_answer.lower() == "yes")
# LOAD DATA ----
print("LOAD DATA ----")
#filename = "dolly_subset_100_rows_question.csv"
filename = "us_state_capitals_3.csv"

dl = DataLoader()
dataset = dl.from_csv(
    filename,
    fields=["question", "context", "answer"],
    input_keys=["question"]
)
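# The loader expects a CSV with a header row naming the three fields; a
# hypothetical sample row for us_state_capitals_3.csv (not from the source):
#   question,context,answer
#   "What is the capital of Ohio?","Ohio is a state in the Midwest...","Columbus"
# Only "question" is marked as an input; "context" and "answer" become labels.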
splits = dl.train_test_split(dataset=dataset, train_size=0.8)
train_dataset = splits['train']
test_dataset = splits['test']

# Set up the LM
print("SETUP LLM ----")
model = 'gemma2:2b'
#model = 'qwen2:7b'
lm = dspy.OllamaLocal(model=model)
dspy.settings.configure(lm=lm)
print("SETUP OPTIMIZATION ----") | |
# Set up the optimizer: we want to "bootstrap" (i.e., self-generate) 4-shot examples of our CoT program. | |
#config = dict(max_bootstrapped_demos=4, max_labeled_demos=4) | |
print("SETUP METRIC ----") | |
# Use the `answer_exact_match` here. In general, the metric is going to tell the optimizer how well it's doing. | |
#metric = answer_exact_match | |
metric = llm_metric | |
# Optimize! | |
#teleprompter = BootstrapFewShot(metric=answer_exact_match, max_bootstrapped_demos=4, max_labeled_demos=4) | |
teleprompter = BootstrapFewShot(metric=metric, max_bootstrapped_demos=4, max_labeled_demos=4) | |
# Skip LabeledFewShot: https://dspy-docs.vercel.app/docs/building-blocks/optimizers | |
#teleprompter = LabeledFewShot(k=3) | |
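# BootstrapFewShot runs the student program over the trainset and keeps up to
# max_bootstrapped_demos traces that pass the metric as few-shot demos.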
lm.inspect_history(n=1)

# Ask questions
print("ASK QUESTIONS ----")
print("dspy.Predict")
predictor_basic = dspy.Predict("question -> answer")
output = predictor_basic(question="What colors are on the American flag?")
print(output)

print("dspy.Predict(BasicQASignature)")
predictor_sig = dspy.Predict(BasicQASignature)
print(predictor_sig(question="What colors are on the American flag?"))

print("BasicQA()")
predictor_module = BasicQA()
print(predictor_module(question="What colors are on the American flag?"))
print("COMPILE PROGRAM ----") | |
optimized_program = teleprompter.compile(predictor_module, trainset=train_dataset) | |
print("EVALUATE ----") | |
evaluater = Evaluate( | |
devset=train_dataset, | |
metric=metric, | |
num_threads=1, | |
display_progress=True, | |
display_table=5, | |
return_outputs=True) | |
score, outputs = evaluater(optimized_program) | |
print("INSPECT ----") | |
lm.inspect_history(n=1) | |
print("USE ----") | |
optimized_program(question='What is the capital of France?') | |
print("SAVE ----") | |
optimized_program.save("v1.json") | |
# outputs[0][0].toDict()
# With return_outputs=True, `outputs` is a list of (Example, prediction, score)
# tuples; convert each Example to a dict with toDict()
serializable_outputs = []
for obj in outputs:
    serializable_outputs.append(obj[0].toDict())

# Write serializable_outputs to a file
with open("outputs.json", "w") as f:
    f.write(json.dumps(serializable_outputs, indent=2))

# LOAD PROGRAM ----
#loaded_program = BasicQA()
#loaded_program.load(path="v1.json")
import pickle

import dspy
from dspy.datasets.gsm8k import GSM8K, gsm8k_metric
from dspy.teleprompt import MIPRO, MIPROv2

# PURPOSE ----
# Example with a built-in dataset using the more advanced MIPROv2 optimizer
class CoT(dspy.Module):
    def __init__(self):
        super().__init__()
        self.prog = dspy.ChainOfThought("question -> answer")

    def forward(self, question):
        return self.prog(question=question)


print("SETUP LLM ----")
model = 'gemma2:2b'
#model = 'llama3.1:70b-instruct-q4_0'
lm = dspy.OllamaLocal(model=model)
dspy.settings.configure(lm=lm)

gsm8k = GSM8K()
gsm8k_trainset, gsm8k_devset = gsm8k.train[:20], gsm8k.dev[:20]
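# GSM8K() downloads the dataset (via Hugging Face) on first use; the 20-item
# slices keep compile time manageable with a small local model.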
# MIPRO
# teleprompter_optimizer = MIPRO(metric=gsm8k_metric, num_candidates=4)
# kwargs = dict(num_threads=1, display_progress=True, display_table=0)
# optimized_cot = teleprompter_optimizer.compile(
#     CoT(),
#     trainset=gsm8k_trainset,
#     num_trials=3,
#     max_bootstrapped_demos=3,
#     max_labeled_demos=5,
#     eval_kwargs=kwargs,
#     requires_permission_to_run=False)

# MIPROv2
teleprompter_optimizer = MIPROv2(metric=gsm8k_metric, num_candidates=4, prompt_model=lm, task_model=lm)
kwargs = dict(lm=lm, num_threads=1, display_progress=True, display_table=0)
optimized_cot = teleprompter_optimizer.compile(
    CoT(),
    trainset=gsm8k_trainset,
    max_bootstrapped_demos=3,
    max_labeled_demos=5,
    eval_kwargs=kwargs,
    requires_permission_to_run=False)
print(optimized_cot)
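# A minimal evaluation sketch over the held-out slice (gsm8k_devset is
# otherwise unused here); reuses the Evaluate harness from the CSV example.
from dspy.evaluate import Evaluate
evaluate = Evaluate(devset=gsm8k_devset, metric=gsm8k_metric, num_threads=1, display_progress=True)
evaluate(optimized_cot)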
# SAVE TO PICKLE ---- | |
f = open('history.pkl', 'wb') | |
with open('history.pkl', 'wb') as f: | |
pickle.dump(lm.history, f) | |
# with open('history.pkl', 'rb') as f: | |
# lm_history = pickle.load(f) |
import os
import json
import pickle

from git.repo import Repo

# PURPOSE ----
# Inspect dspy prompt history using git commits and https://github.com/pomber/git-history/tree/master/cli
# Function to split long lines into smaller ones
def split_long_lines(text, max_length=100):
    lines = text.splitlines()  # Split the string into lines
    result = []
    for line in lines:
        while len(line) > max_length:
            # Find the last space before max_length
            break_point = line.rfind(' ', 0, max_length)
            if break_point == -1:  # No space found, force a break
                break_point = max_length
            result.append(line[:break_point])
            line = line[break_point:].lstrip()
        result.append(line)
    return result
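# The script assumes `git_tmp` already exists as an initialized git
# repository; a one-time setup sketch:
# os.makedirs('git_tmp', exist_ok=True)
# Repo.init('git_tmp')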
repo_path = 'git_tmp'
repo = Repo(repo_path)

with open('history.pkl', 'rb') as f:
    lm_history = pickle.load(f)

for i in range(len(lm_history)):
    print("GET PROMPT ----")
    tmp = lm_history[i]['prompt']
    tmp = split_long_lines(tmp, max_length=100)
    final_str = '\n'.join(tmp)

    file = 'prompt.json'
    file_path = os.path.join(repo_path, file)
    #with open(file_path, 'w') as f:
    #    json.dump(tmp, f, indent=2)
    with open(file_path, 'w') as f:
        f.write(final_str)

    print("COMMIT PROMPT ----")
    repo.index.add([file])
    commit_msg = f"Prompt: {i}"
    repo.index.commit(commit_msg)
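With one commit per prompt in git_tmp, the git-history CLI linked in the script's purpose comment can then be pointed at prompt.json to step through how the optimizer's prompts evolved, commit by commit.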