Created January 11, 2024 18:56
Save heathermiller/5b44377d4a0fdbd36ef1e96ba1f683de to your computer and use it in GitHub Desktop.
Compiling in DSPy
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import dspy
from dspy.teleprompt import BootstrapFewShot
class GenerateAnswer(dspy.Signature):
    """Answer questions with short factoid answers."""

    # NOTE(review): DSPy signatures typically use the class docstring above (and
    # each field's ``desc``) as prompt text for the LM, so edit them as prompts,
    # not just as documentation.

    # Input: passages supplied by the caller (RAG.forward passes the passages
    # returned by its retriever here).
    context = dspy.InputField(desc="may contain relevant facts")
    # Input: the question to answer.
    question = dspy.InputField()
    # Output: the short answer produced by the LM.
    answer = dspy.OutputField(desc="often between 1 and 5 words")
class RAG(dspy.Module):
    """Retrieval-augmented QA: retrieve passages, then answer with chain-of-thought.

    ``num_passages`` (constructor argument, default 3) is passed to
    ``dspy.Retrieve`` as ``k``, the number of passages fetched per question.
    """

    def __init__(self, num_passages=3):
        super().__init__()
        self.retrieve = dspy.Retrieve(k=num_passages)
        # Wrap the GenerateAnswer signature in a ChainOfThought predictor.
        self.generate_answer = dspy.ChainOfThought(GenerateAnswer)

    def forward(self, question):
        """Answer ``question`` using retrieved passages as context.

        Returns a ``dspy.Prediction`` with two fields: ``context``, the passages
        that were retrieved, and ``answer``, the generated answer.
        """
        # DSPy modules are define-by-run, so forward() may contain arbitrary
        # Python code.
        context = self.retrieve(question).passages
        prediction = self.generate_answer(context=context, question=question)
        # Return the context alongside the answer so a metric such as
        # validate_context_and_answer can check the retrieved context too.
        return dspy.Prediction(context=context, answer=prediction.answer)
# Validation logic: check that the predicted answer is correct.
# Also check that the retrieved context actually contains that answer.
def validate_context_and_answer(example, pred, trace=None):
    """Metric that passes only when the answer matches and is found in the context.

    Args:
        example: the gold training example being checked against.
        pred: the program's prediction (for RAG, a Prediction carrying
            ``answer`` and ``context``).
        trace: not used here. NOTE(review): presumably accepted because
            BootstrapFewShot calls metrics with a trace argument — confirm.

    Returns:
        The exact-match result combined with ``and`` with the passage-match
        result, so the metric is truthy only when both checks succeed.
    """
    answer_em = dspy.evaluate.answer_exact_match(example, pred)
    answer_pm = dspy.evaluate.answer_passage_match(example, pred)
    return answer_em and answer_pm
# Set up a basic teleprompter, which will compile our RAG program.
teleprompter = BootstrapFewShot(metric=validate_context_and_answer)

# Compile!
# NOTE(review): ``trainset`` is not defined in this snippet; it is presumably a
# collection of training examples built earlier (e.g. dspy.Example objects) —
# confirm before running.
compiled_rag = teleprompter.compile(RAG(), trainset=trainset)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.