Last active
February 14, 2024 22:31
-
-
Save seanchatmangpt/5f4d6288ed23f73eeaf4320ca6ea0574 to your computer and use it in GitHub Desktop.
Pydantic instance generation with unit tests
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ast | |
import logging | |
import inspect | |
from typing import Type, TypeVar | |
from dspy import Assert, Module, ChainOfThought, Signature, InputField, OutputField | |
from pydantic import BaseModel, ValidationError | |
# Module-level logger for the generator.
# NOTE(review): setting an explicit level on this library logger takes precedence
# over levels an application sets on parent/root loggers, and at ERROR the
# logger.debug(...) call in validate_root_model is always filtered out —
# confirm this suppression is intended.
logger: logging.Logger = logging.getLogger(__name__)
logger.setLevel(logging.ERROR)
def eval_dict_str(dict_str: str) -> dict:
    """Parse an LLM-produced dict literal without executing any code.

    Python literal syntax is tried first with ``ast.literal_eval``. If that
    fails, the text is retried with ``json.loads``. The fallback is needed
    because the generation prompt asks for "json valid values", and JSON
    tokens such as ``true``, ``false`` and ``null`` are rejected by
    ``literal_eval``.

    Args:
        dict_str: Text of a Python or JSON dict literal.

    Returns:
        The parsed object. This is normally a dict; the caller (Pydantic
        validation) is responsible for rejecting any other shape.

    Raises:
        ValueError, SyntaxError: the error from ``ast.literal_eval`` when
            neither parser accepts the input. Exception types are therefore
            unchanged for callers.
        TypeError: propagated from ``ast.literal_eval`` (e.g. unhashable keys).
    """
    import json  # local import: keeps the module's import block untouched

    try:
        return ast.literal_eval(dict_str)
    except (ValueError, SyntaxError) as literal_error:
        try:
            return json.loads(dict_str)
        except (TypeError, ValueError):
            # json.JSONDecodeError is a ValueError. Surface the original error so
            # callers see the same exception types as before the JSON fallback.
            raise literal_error from None
class PromptToPydanticInstanceSignature(Signature):
    """
    Synthesize the prompt into the kwargs fit the model.
    Do not duplicate the field descriptions
    """

    # First-pass DSPy signature: prompt + model source in, kwargs dict text out.
    # NOTE(review): for a dspy Signature the docstring, the field order and the
    # desc/prefix strings are presumably rendered into the LLM prompt, so
    # wording changes here change model behavior — edit deliberately.

    # Name of the Pydantic class the generated kwargs must validate against.
    root_pydantic_model_class_name = InputField(
        desc="The class name of the pydantic model to receive the kwargs"
    )
    # Source code of the root model and any child models, joined as one string.
    pydantic_model_definitions = InputField(
        desc="Pydantic model class definitions as a string"
    )
    # Free-text description of the data to generate.
    prompt = InputField(
        desc="The prompt to be synthesized into data. Do not duplicate descriptions"
    )
    # The LLM's answer: the text of a dict literal, later parsed by eval_dict_str.
    root_model_kwargs_dict = OutputField(
        prefix="kwargs_dict: dict = ",
        desc="Generate a Python dictionary as a string with minimized whitespace that only contains json valid values.",
    )
class PromptToPydanticInstanceErrorSignature(Signature):
    """Synthesize the prompt into the kwargs fit the model"""

    # Correction-round DSPy signature: the same inputs as
    # PromptToPydanticInstanceSignature plus the error from the failed attempt.
    # NOTE(review): the docstring, field order and desc/prefix strings are
    # presumably part of the LLM prompt — edit deliberately.
    # NOTE(review): the output prefix here is "kwargs_dict = " while
    # PromptToPydanticInstanceSignature uses "kwargs_dict: dict = "; confirm the
    # difference is intended.

    # Error text describing why the previous kwargs dict was rejected.
    error = InputField(desc="Error message to fix the kwargs")
    # Name of the Pydantic class the generated kwargs must validate against.
    root_pydantic_model_class_name = InputField(
        desc="The class name of the pydantic model to receive the kwargs"
    )
    # Source code of the root model and any child models, joined as one string.
    pydantic_model_definitions = InputField(
        desc="Pydantic model class definitions as a string"
    )
    # Free-text description of the data to generate.
    prompt = InputField(desc="The prompt to be synthesized into data")
    # The LLM's corrected answer: the text of a dict literal.
    root_model_kwargs_dict = OutputField(
        prefix="kwargs_dict = ",
        desc="Generate a Python dictionary as a string with minimized whitespace that only contains json valid values.",
    )
T = TypeVar("T", bound=BaseModel)


class GenPydanticInstance(Module):
    """
    DSPy module that asks an LLM for a kwargs dict and validates it into a Pydantic model instance.

    The source code of the root model and any child models is concatenated and sent to
    the LLM together with the prompt. The reply is parsed with ``eval_dict_str`` and
    validated with ``root_model.model_validate``. If the first reply fails validation,
    one correction round is attempted with the error message fed back to the LLM.

    Usage:
        gen = GenPydanticInstance(root_model=MyModel, child_models=[Child])
        instance = gen(prompt="...")      # or gen("...")
    """

    def __init__(
        self,
        root_model: Type[T],
        child_models: "list[Type[BaseModel]] | None" = None,
        generate_sig=PromptToPydanticInstanceSignature,
        correct_generate_sig=PromptToPydanticInstanceErrorSignature,
    ):
        """
        Args:
            root_model: The Pydantic model the generated kwargs must validate against.
            child_models: Additional models whose source the LLM should see (e.g. the
                types of nested fields). ``None`` (the default) means none. Models
                already in the list are not repeated.
            generate_sig: DSPy signature used for the first generation attempt.
            correct_generate_sig: DSPy signature used for the correction attempt; it
                receives an extra ``error`` input.
        """
        super().__init__()
        # The root model always comes first, so its definition leads the prompt.
        self.models = [root_model]
        # ``None`` is the default, so it must not reach list.extend(). Duplicates are
        # skipped to avoid sending the same class source twice.
        for model in child_models or ():
            if model not in self.models:
                self.models.append(model)
        self.output_key = "root_model_kwargs_dict"
        self.root_model = root_model
        # Concatenated model source, sent to the LLM in both generation and correction.
        self.model_sources = "\n".join(
            inspect.getsource(model) for model in self.models
        )
        # DSPy ChainOfThought modules for the first attempt and the correction round.
        self.generate = ChainOfThought(generate_sig)
        self.correct_generate = ChainOfThought(correct_generate_sig)
        # Most recent exception raised while validating LLM output (None if the last
        # validation succeeded); it is embedded in the Assert and correction messages.
        self.validation_error = None

    def validate_root_model(self, output: str) -> bool:
        """Return True when ``output`` parses and validates as the root model.

        On failure the exception is stored in ``self.validation_error`` and False is
        returned; no exception escapes for parse or validation errors.
        """
        # Clear any error left over from an earlier call so messages are never stale.
        self.validation_error = None
        try:
            model_inst = self.root_model.model_validate(eval_dict_str(output))
        except (ValidationError, ValueError, TypeError, SyntaxError) as error:
            self.validation_error = error
            logger.debug("Validation error: %s", error)
            return False
        return isinstance(model_inst, self.root_model)

    def validate_output(self, output) -> T:
        """Validate ``output`` and return the root model instance.

        Raises:
            The exception raised by dspy ``Assert`` when validation fails; its message
            includes the stored validation error.
        """
        Assert(
            self.validate_root_model(output),
            f"You need to create a kwargs dict for {self.root_model.__name__}\n\n"
            f"Validation error:\n{self.validation_error}",
        )
        return self.root_model.model_validate(eval_dict_str(output))

    def forward(self, prompt) -> T:
        """Generate a root-model instance from ``prompt``, with one correction attempt.

        Args:
            prompt: Free-text description of the data to generate.

        Returns:
            A validated instance of the root model.

        Raises:
            Whatever ``validate_output`` raises if the corrected output is still invalid.
        """
        output = self.generate(
            prompt=prompt,
            root_pydantic_model_class_name=self.root_model.__name__,
            pydantic_model_definitions=self.model_sources,
        )[self.output_key]
        try:
            return self.validate_output(output)
        except (AssertionError, ValueError, TypeError) as error:
            logger.error("Error %s\nOutput:\n%s", error, output)
            # Correction attempt. The previous code sent the literal text "str(error)"
            # here because the f-string was missing its braces.
            corrected_output = self.correct_generate(
                prompt=prompt,
                root_pydantic_model_class_name=self.root_model.__name__,
                pydantic_model_definitions=self.model_sources,
                error=f"{error}\n{self.validation_error}",
            )[self.output_key]
            return self.validate_output(corrected_output)

    def __call__(self, *args, **kwargs):
        """Run :meth:`forward`.

        The prompt may be given as ``prompt=...`` or as the first positional argument.
        If neither is given, ``None`` is forwarded, as before.
        """
        if "prompt" in kwargs:
            prompt = kwargs["prompt"]
        elif args:
            prompt = args[0]
        else:
            prompt = None
        return self.forward(prompt)
def main():
    """Demo: generate an EventStormModel (plus Event/Command/Query children) from a prompt.

    Configures dspy with a GPT-4 ``dspy.OpenAI`` client (presumably this requires
    OpenAI credentials in the environment — confirm) and prints the resulting
    instance.
    """
    import dspy
    from rdddy.messages import EventStormModel, Event, Command, Query

    lm = dspy.OpenAI(max_tokens=3000, model="gpt-4")
    dspy.settings.configure(lm=lm)
    prompt = """
```prompt
Automated Hygen template full stack system for NextJS.
Express
Express.js is arguably the most popular web framework for Node.js
A typical app structure for express celebrates the notion of routes and handlers, while views and data are left for interpretation (probably because the rise of microservices and client-side apps).
So an app structure may look like this:
app/
routes.js
handlers/
health.js
shazam.js
While routes.js glues everything together:
// ... some code ...
const health = require('./handlers/health')
const shazam = require('./handlers/shazam')
app.get('/health', health)
app.post('/shazam', shazam)
module.exports = app
Unlike React Native, you could dynamically load modules here. However, there's still a need for judgement when constructing the routes (app.get/post part).
Using hygen let's see how we could build something like this:
$ hygen route new --method post --name auth
Since we've been through a few templates as with previous use cases, let's jump straight to the interesting part, the inject part.
So let's say our generator is structured like this:
_templates/
route/
new/
handler.ejs.t
inject_handler.ejs.t
Then inject_handler looks like this:
---
inject: true
to: app/routes.js
skip_if: <%= name %>
before: "module.exports = app"
---
app.<%= method %>('/<%= name %>', <%= name %>)
Note how we're anchoring this inject to before: "module.exports = app". If in previous occasions we appended content to a given line, we're now prepending it.
```
You are a Event Storm assistant that comes up with Events, Commands, and Queries for Reactive Domain Driven Design based on the ```prompt```
"""
    model_module = GenPydanticInstance(
        root_model=EventStormModel, child_models=[Event, Command, Query]
    )
    model_inst = model_module(prompt=prompt)
    print(model_inst)
if __name__ == "__main__":
    # Run the GPT-4 demo only when executed as a script, never on import.
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from rdddy.generators.gen_pydantic_instance import ( | |
GenPydanticInstance, | |
) | |
import pytest | |
from unittest.mock import patch, MagicMock | |
from dspy import settings, OpenAI, DSPyAssertionError | |
from typing import Dict, Any, Optional | |
from pydantic import BaseModel, Field, ValidationError | |
class APIEndpoint(BaseModel):
    # Sample root model used by the GenPydanticInstance tests below.
    # NOTE(review): GenPydanticInstance sends this class's source text to the LLM,
    # so field order and descriptions are part of that prompt.

    # HTTP verb, e.g. "GET" (required).
    method: str = Field(..., description="HTTP method of the API endpoint")
    # Endpoint path or URL (required).
    url: str = Field(..., description="URL of the API endpoint")
    # Human-readable summary of the endpoint (required).
    description: str = Field(
        ..., description="Description of what the API endpoint does"
    )
    # Description of the response (required; a plain string, not a schema).
    response: str = Field(..., description="Response from the API endpoint")
    # Optional mapping of query-parameter names to values; defaults to None.
    query_params: Optional[Dict[str, Any]] = Field(None, description="Query parameters")
# Expected parsed form of VALID_PYDANTIC_MODEL_STRING; it matches APIEndpoint's fields.
VALID_PYDANTIC_MODEL_DICT = dict(
    method="GET",
    url="/forecast/today",
    description="API endpoint for retrieving meteorological conditions",
    response="Structured summary of weather conditions",
    query_params={"geographical_area": "string"},
)

# Raw LLM-style output that should validate as an APIEndpoint.
VALID_PYDANTIC_MODEL_STRING = """{
"method": "GET",
"url": "/forecast/today",
"description": "API endpoint for retrieving meteorological conditions",
"response": "Structured summary of weather conditions",
"query_params": {"geographical_area": "string"}
}"""

# Deliberately vague prompt describing the '/forecast/today' endpoint.
VALID_PROMPT = """
Imagine a digital portal where users can inquire about meteorological conditions.
This portal is accessible through a web interface that interacts with a backend service.
The service is invoked by sending a request to a specific endpoint.
This request is crafted using a standard protocol for web communication.
The endpoint's location is a mystery, hidden within the path '/forecast/today'.
Users pose their inquiries by specifying a geographical area of interest,
though the exact format of this specification is left to the user's imagination.
Upon successful request processing, the service responds with a structured
summary of the weather, encapsulating details such as temperature, humidity,
and wind speed. However, the structure of this response and the means of
accessing the weather summary are not explicitly defined.
"""

# A well-formed dict literal whose keys do NOT match APIEndpoint, so it fails
# model validation (it is not a Python syntax error by itself).
INVALID_STR = "{ 'name': 'Alice', 'age': 30, 'city': 'Wonderland' }"
@pytest.fixture
def gen_pydantic_model():
    """Yield a GenPydanticInstance for APIEndpoint with dspy LM setup stubbed out.

    ``settings.configure`` and ``OpenAI.__init__`` stay patched for the lifetime of
    the fixture, so no real LM client is created. ``child_models=[]`` is passed
    explicitly so construction does not depend on how the ``None`` default is
    handled. The constructor as originally written calls
    ``self.models.extend(child_models)``, which raises TypeError for ``None``.
    """
    with patch.object(settings, "configure"), patch.object(
        OpenAI, "__init__", return_value=None
    ):
        yield GenPydanticInstance(APIEndpoint, child_models=[])
@patch("dspy.predict.Predict.forward")
def test_forward_success(mock_predict, gen_pydantic_model):
    """A valid kwargs string from the LLM is parsed and validated into APIEndpoint.

    GenPydanticInstance reads the prediction with ``output[key]``, so the stub
    must configure ``__getitem__``. The old ``.get`` configuration was never used.
    ``ast.literal_eval`` is no longer mocked, so the string is genuinely parsed.
    """
    mock_predict.return_value.__getitem__.return_value = VALID_PYDANTIC_MODEL_STRING

    result = gen_pydantic_model.forward(prompt=VALID_PROMPT)

    assert isinstance(result, APIEndpoint)
    # Every field round-trips, so the returned instance really came from the string.
    assert result.model_dump() == VALID_PYDANTIC_MODEL_DICT
@patch("dspy.predict.Predict.forward")
@patch("ast.literal_eval", side_effect=SyntaxError)
def test_forward_syntax_error(mock_literal_eval, mock_predict, gen_pydantic_model):
    """Unparseable output on both attempts ends in a DSPy assertion failure.

    Both the first generation and the correction round return INVALID_STR, and
    parsing is forced to raise SyntaxError. The prediction is read with
    ``output[key]``, so ``__getitem__`` is stubbed rather than ``.get``.
    """
    mock_predict.return_value.__getitem__.return_value = INVALID_STR

    with pytest.raises(DSPyAssertionError):
        gen_pydantic_model.forward(prompt="///")

    # The parser was actually exercised on the LLM output.
    mock_literal_eval.assert_called()
    # One prediction for the initial attempt plus one for the correction round.
    assert mock_predict.call_count == 2
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment