Last active
September 21, 2024 10:34
-
-
Save seanchatmangpt/7e25b66ebffdedba7310d9c90f377463 to your computer and use it in GitHub Desktop.
Create DSPy Signatures from Pydantic Models
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ast | |
import logging | |
import inspect | |
from typing import Type, TypeVar | |
from dspy import Assert, Module, ChainOfThought, Signature, InputField, OutputField | |
from pydantic import BaseModel, ValidationError | |
# Module-level logger. Its level is pinned to ERROR here, so the debug-level
# messages emitted by this module (e.g. validation errors) are suppressed unless
# a caller lowers this logger's level again.
logger = logging.getLogger(__name__)
logger.setLevel(level=logging.ERROR)
def eval_dict_str(dict_str: str) -> dict:
    """Safely parse a string containing a dict literal into a ``dict`` (no code is executed).

    Python literal syntax is tried first via ``ast.literal_eval``. If that fails and the
    input is a string, JSON is tried next. LM output that uses the JSON literals
    ``true``/``false``/``null`` (which ``ast.literal_eval`` rejects) is therefore still
    accepted.

    :param dict_str: Text containing a dict literal, e.g. ``"{'a': 1}"`` or ``'{"a": true}'``.
    :return: The parsed dictionary.
    :raises ValueError: if the text is neither a Python nor a JSON literal (``SyntaxError``
        may be raised instead, as by ``ast.literal_eval``).
    :raises TypeError: if the text parses to a literal that is not a dict.
    """
    import json

    text = dict_str.strip() if isinstance(dict_str, str) else dict_str
    try:
        value = ast.literal_eval(text)
    except (ValueError, SyntaxError) as literal_error:
        if not isinstance(text, str):
            raise
        try:
            value = json.loads(text)
        except json.JSONDecodeError:
            # Report the Python-literal error: a Python dict is the primary format requested.
            raise literal_error from None
    if not isinstance(value, dict):
        raise TypeError(f"Expected a dict literal, got {type(value).__name__}")
    return value
class PromptToPydanticInstanceSignature(Signature):
    # Signature for the first generation attempt made by GenPydanticInstance.forward.
    # NOTE(review): DSPy presents a Signature's docstring to the LM as the task
    # instructions, so the docstring text below is prompt content; edit it as such.
    """
    Synthesize the prompt into the kwargs fit the model.
    Do not duplicate the field descriptions
    """

    # Name of the root model class; GenPydanticInstance fills it from root_model.__name__.
    root_pydantic_model_class_name = InputField(
        desc="The class name of the pydantic model to receive the kwargs"
    )
    # Source code of the root model and any child models, joined with newlines.
    pydantic_model_definitions = InputField(
        desc="Pydantic model class definitions as a string"
    )
    # The caller's text that should be turned into model data.
    prompt = InputField(
        desc="The prompt to be synthesized into data. Do not duplicate descriptions"
    )
    # The LM's answer: a dict literal. GenPydanticInstance parses it with eval_dict_str
    # and validates the result against the root model. The prefix presumably primes the
    # LM to continue a Python assignment so that only the literal is emitted.
    root_model_kwargs_dict = OutputField(
        prefix="kwargs_dict: dict = ",
        desc="Generate a Python dictionary as a string with minimized whitespace that only contains json valid values.",
    )
class PromptToPydanticInstanceErrorSignature(Signature):
    # Signature for the single correction attempt made by GenPydanticInstance.forward
    # after the first output failed validation. It mirrors the first-pass signature and
    # adds an `error` input.
    # NOTE(review): the docstring is used by DSPy as the LM instructions (prompt text).
    """Synthesize the prompt into the kwargs fit the model"""

    # str() of the exception raised while validating the first attempt.
    error = InputField(desc="Error message to fix the kwargs")
    # Name of the root model class; filled from root_model.__name__.
    root_pydantic_model_class_name = InputField(
        desc="The class name of the pydantic model to receive the kwargs"
    )
    # Source code of the root model and any child models, joined with newlines.
    pydantic_model_definitions = InputField(
        desc="Pydantic model class definitions as a string"
    )
    # The same caller text that was given to the first attempt.
    prompt = InputField(desc="The prompt to be synthesized into data")
    # The corrected dict literal, parsed with eval_dict_str and validated again.
    root_model_kwargs_dict = OutputField(
        prefix="kwargs_dict = ",
        desc="Generate a Python dictionary as a string with minimized whitespace that only contains json valid values.",
    )
# Type of the root Pydantic model; the generation methods return instances of it.
T = TypeVar("T", bound=BaseModel)


class GenPydanticInstance(Module):
    """
    A module for generating and validating Pydantic model instances based on prompts.

    Usage:
        Instantiate GenPydanticInstance with the root Pydantic model and, optionally,
        the child models it refers to. Then call the instance (``module(prompt)`` or
        ``module(prompt=...)``) or its ``forward`` method with a prompt. The LM is asked
        for a kwargs dict for the root model. If that dict does not validate, one
        correction attempt is made with the error message fed back to the LM.
    """

    def __init__(
        self,
        root_model: Type[T],
        child_models: list[Type[BaseModel]] = None,
        generate_sig=PromptToPydanticInstanceSignature,
        correct_generate_sig=PromptToPydanticInstanceErrorSignature,
    ):
        """
        :param root_model: Pydantic model class the generated kwargs must validate against.
        :param child_models: Optional extra model classes (or None) whose definitions are
            shown to the LM alongside the root model's.
        :param generate_sig: Signature used for the first generation attempt.
        :param correct_generate_sig: Signature used for the correction attempt; it is
            called with an extra ``error`` input.
        :raises TypeError: if ``root_model`` or any child model is not a BaseModel subclass.
        """
        super().__init__()

        if not issubclass(root_model, BaseModel):
            raise TypeError("root_model must inherit from pydantic.BaseModel")
        self.models = [root_model]  # Always include root_model in models list
        if child_models:
            # Validate every child before adding any of them.
            for model in child_models:
                if not issubclass(model, BaseModel):
                    raise TypeError(
                        "All child_models must inherit from pydantic.BaseModel"
                    )
            self.models.extend(child_models)

        self.output_key = "root_model_kwargs_dict"
        self.root_model = root_model
        # The model definitions are handed to the LM verbatim as prompt context.
        self.model_sources = "\n".join(
            self._model_source(model) for model in self.models
        )

        # Initialize DSPy ChainOfThought modules for generation and correction
        self.generate = ChainOfThought(generate_sig)
        self.correct_generate = ChainOfThought(correct_generate_sig)
        # Most recent exception raised while parsing/validating LM output.
        self.validation_error = None

    @staticmethod
    def _model_source(model: Type[BaseModel]) -> str:
        """Return the source code of ``model``, or its JSON schema when no source is available."""
        try:
            return inspect.getsource(model)
        except (OSError, TypeError):
            # Models built at runtime (e.g. pydantic.create_model) or defined in an
            # interactive session have no retrievable source; describe them by schema.
            import json

            return json.dumps(model.model_json_schema(), indent=2)

    def _try_validate(self, output: str):
        """Parse ``output`` and validate it against the root model.

        :return: The validated root-model instance, or ``None`` on failure. On failure
            the exception is stored in ``self.validation_error``.
        """
        try:
            return self.root_model.model_validate(eval_dict_str(output))
        except (ValidationError, ValueError, TypeError, SyntaxError) as error:
            self.validation_error = error
            logger.debug("Validation error: %s", error)
            return None

    def validate_root_model(self, output: str) -> bool:
        """Return True if ``output`` parses to kwargs that validate against the root model."""
        return self._try_validate(output) is not None

    def validate_output(self, output) -> T:
        """Validate ``output`` and return the resulting root-model instance.

        :raises AssertionError: (via dspy ``Assert``) when validation fails; the message
            includes the validation error so it can be fed back to the LM.
        """
        instance = self._try_validate(output)
        Assert(
            instance is not None,
            f"You need to create a kwargs dict for {self.root_model.__name__}\n\n"
            f"Validation error:\n{self.validation_error}",
        )
        if instance is None:
            # Defensive: if Assert returned without raising (NOTE(review): dspy can be
            # configured to bypass assertions), surface the real validation error
            # instead of returning None.
            raise self.validation_error
        return instance

    def forward(self, prompt) -> T:
        """
        Generate a Python dict for the root Pydantic model from ``prompt`` and return the
        validated instance. On failure, one correction attempt is made with the error fed
        back to the LM. If that attempt also fails, its error propagates.
        """
        output = self.generate(
            prompt=prompt,
            root_pydantic_model_class_name=self.root_model.__name__,
            pydantic_model_definitions=self.model_sources,
        )[self.output_key]

        try:
            return self.validate_output(output)
        except (AssertionError, ValueError, TypeError, SyntaxError) as error:
            logger.error("Error %s\nOutput:\n%s", error, output)
            # Correction attempt
            corrected_output = self.correct_generate(
                prompt=prompt,
                root_pydantic_model_class_name=self.root_model.__name__,
                pydantic_model_definitions=self.model_sources,
                error=str(error),
            )[self.output_key]
            return self.validate_output(corrected_output)

    def __call__(self, *args, **kwargs):
        """Run ``forward`` with the prompt given by keyword or as the first positional argument."""
        if "prompt" in kwargs:
            return self.forward(kwargs["prompt"])
        return self.forward(args[0] if args else None)
def main():
    """Demo: build an EventStormModel from a prose description of a Hygen/Express workflow."""
    import dspy
    from rdddy.messages import EventStormModel, Event, Command, Query

    dspy.settings.configure(lm=dspy.OpenAI(max_tokens=3000, model="gpt-4"))

    prompt = """Automated Hygen template full stack system for NextJS.
Express
Express.js is arguably the most popular web framework for Node.js
A typical app structure for express celebrates the notion of routes and handlers, while views and data are left for interpretation (probably because the rise of microservices and client-side apps).
So an app structure may look like this:
app/
routes.js
handlers/
health.js
shazam.js
While routes.js glues everything together:
// ... some code ...
const health = require('./handlers/health')
const shazam = require('./handlers/shazam')
app.get('/health', health)
app.post('/shazam', shazam)
module.exports = app
Unlike React Native, you could dynamically load modules here. However, there's still a need for judgement when constructing the routes (app.get/post part).
Using hygen let's see how we could build something like this:
$ hygen route new --method post --name auth
Since we've been through a few templates as with previous use cases, let's jump straight to the interesting part, the inject part.
So let's say our generator is structured like this:
_templates/
route/
new/
handler.ejs.t
inject_handler.ejs.t
Then inject_handler looks like this:
---
inject: true
to: app/routes.js
skip_if: <%= name %>
before: "module.exports = app"
---
app.<%= method %>('/<%= name %>', <%= name %>)
Note how we're anchoring this inject to before: "module.exports = app". If in previous occasions we appended content to a given line, we're now prepending it.
"""

    generator = GenPydanticInstance(
        root_model=EventStormModel,
        child_models=[Event, Command, Query],
    )
    print(generator(prompt=prompt))
if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import dspy | |
from dspy import Signature | |
from dspy.signatures.field import InputField, OutputField | |
from pydantic import BaseModel, Field | |
from rdddy.generators.gen_pydantic_instance import GenPydanticInstance | |
class InputFieldTemplateSpecModel(BaseModel):
    """Defines an input field for a DSPy Signature."""

    # create_signature_class_from_model uses `name` as the Signature attribute and
    # forwards `prefix`/`desc` to dspy's InputField. A Field without a default is required.
    name: str = Field(
        description="The key used to access and pass the input within the Signature.",
    )
    prefix: str | None = Field(
        default=None,
        description="Optional additional context or labeling for the input field.",
    )
    desc: str = Field(
        description="Description of the input field's purpose or the nature of content it should contain.",
    )
class OutputFieldTemplateSpecModel(BaseModel):
    """Defines an output field for a DSPy Signature."""

    # create_signature_class_from_model uses `name` as the Signature attribute and
    # forwards `prefix`/`desc` to dspy's OutputField.
    name: str = Field(
        ...,
        description="The key used to access the output produced by the Signature.",
    )
    prefix: str | None = Field(
        None,
        description="Optional additional context or labeling for the output field.",
    )
    desc: str = Field(
        ...,
        description="Description of the output field's purpose or the nature of content it should contain.",
    )
class SignatureTemplateSpecModel(BaseModel):
    """
    SignatureTemplateSpecModel encapsulates the specifications for input/output behavior of a task in the DSPy framework.
    It provides a structured approach to define how data should be inputted into and outputted from a language model (LM),
    facilitating the creation and integration of complex LM pipelines.
    """

    # create_signature_class_from_model turns this spec into a Signature class:
    # `name` is the class name and `instructions` becomes the class docstring.
    name: str = Field(
        ...,
        description="Name of the generated Signature class (a valid Python class name).",
    )
    instructions: str = Field(
        ..., description="Documentation of the task's expected LM function and output."
    )
    input_fields: list[InputFieldTemplateSpecModel] = Field(
        ..., description="Input fields of the Signature."
    )
    output_fields: list[OutputFieldTemplateSpecModel] = Field(
        ..., description="Output fields of the Signature."
    )
def create_signature_class_from_model(model: SignatureTemplateSpecModel) -> type:
    """
    Create a DSPy Signature class from a SignatureTemplateSpecModel.

    The spec's ``instructions`` become the class docstring. Each input/output field
    spec becomes a class attribute built with ``InputField``/``OutputField``, added in
    list order with inputs first.

    :param model: The signature specification to convert.
    :return: A new subclass of ``Signature`` named ``model.name``.
    :raises ValueError: if a field name is repeated or collides with ``__doc__``/``__annotations__``.
    """
    class_dict = {"__doc__": model.instructions, "__annotations__": {}}

    field_specs = [(spec, InputField) for spec in model.input_fields] + [
        (spec, OutputField) for spec in model.output_fields
    ]
    for spec, field_factory in field_specs:
        if spec.name in class_dict:
            # A repeated name would silently replace an earlier field (or the docstring).
            raise ValueError(
                f"Field name {spec.name!r} is duplicated or reserved in signature {model.name!r}"
            )
        field_kwargs = {"desc": spec.desc}
        # Only forward a prefix the spec actually sets, so DSPy can infer one from the
        # field name otherwise. NOTE(review): some dspy versions keep an explicit
        # prefix=None instead of inferring one.
        if spec.prefix is not None:
            field_kwargs["prefix"] = spec.prefix
        class_dict[spec.name] = field_factory(**field_kwargs)
        # Annotate with the field's value type, not with the field factory function.
        class_dict["__annotations__"][spec.name] = str

    # Dynamically create the Signature class
    return type(model.name, (Signature,), class_dict)
def main():
    """Demo: have the LM design a QuestionAnswering signature, then use it to answer a question."""
    dspy.settings.configure(lm=dspy.OpenAI(max_tokens=500))

    spec_request = "I need a signature called QuestionAnswering that allows input of 'context', 'question', and output 'answer'"
    spec_generator = GenPydanticInstance(
        root_model=SignatureTemplateSpecModel,
        child_models=[InputFieldTemplateSpecModel, OutputFieldTemplateSpecModel],
    )
    signature_spec = spec_generator.forward(spec_request)

    # Turn the generated spec into a real DSPy Signature class.
    QuestionAnswering = create_signature_class_from_model(signature_spec)

    context = """Chaining language model (LM) calls as composable modules is fueling a new powerful
way of programming. However, ensuring that LMs adhere to important constraints remains a key
challenge, one often addressed with heuristic “prompt engineering”. We introduce LM Assertions,
a new programming construct for expressing computational constraints that LMs should satisfy.
We integrate our constructs into the recent DSPy programming model for LMs, and present new
strategies that allow DSPy to compile programs with arbitrary LM Assertions into systems
that are more reliable and more accurate. In DSPy, LM Assertions can be integrated at compile
time, via automatic prompt optimization, and/or at inference time, via automatic self- refinement
and backtracking. We report on two early case studies for complex question answer- ing (QA),
in which the LM program must iteratively retrieve information in multiple hops and synthesize a
long-form answer with citations. We find that LM Assertions improve not only compliance with
imposed rules and guidelines but also enhance downstream task performance, delivering intrinsic
and extrinsic gains up to 35.7% and 13.3%, respectively. Our reference implementation of LM Assertions
is integrated into DSPy at dspy.ai."""
    question = "What strategies can DSPy use?"

    reasoner = dspy.ChainOfThought(QuestionAnswering)
    prediction = reasoner.forward(context=context, question=question)
    print(prediction.answer)
if __name__ == "__main__":
    # Script entry point: run the demo only when executed directly.
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment