Created
January 31, 2024 07:08
-
-
Save seanchatmangpt/6ad51efc96a125fcd5ca77e539d920fa to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ast | |
import dspy | |
from typing import Dict, Any, Optional | |
# Module-level LM setup: every DSPy module created below (ChainOfThought,
# GenPydanticModel) uses this OpenAI client, capped at 500 generated tokens.
# NOTE(review): this runs at import time, so importing the module requires
# OpenAI credentials to be available — confirm that is intended.
lm = dspy.OpenAI(max_tokens=500)
dspy.settings.configure(lm=lm)
import inspect | |
from dspy import Assert | |
from pydantic import BaseModel, Field, ValidationError | |
from rdddy.generators.gen_module import GenModule | |
def strip_text_before_first_open_brace(input_text):
    """Return ``input_text`` starting at its first ``{`` character.

    Text without any ``{`` is returned unchanged. The caller uses this to
    discard any preamble the LM writes before the dict literal it produced.
    """
    _, brace, remainder = input_text.partition("{")
    if not brace:
        return input_text
    return brace + remainder
from dspy import Signature, InputField, OutputField | |
class GenPydanticModel(GenModule):
    """DSPy generator that asks the LM for a kwargs dict and validates it as ``root_model``.

    The LM is given the source code of ``root_model`` (plus any extra
    ``models``) under the ``inspect_getsource`` input key, and the user text
    under ``prompt``.

    Args:
        root_model: Pydantic model class the generated output must validate against.
        models: Optional extra model classes whose source is also shown to the LM.
            The list is copied, never mutated; ``root_model`` is added to the copy
            when it is missing.
    """

    def __init__(self, root_model, models: Optional[list] = None):
        # Work on a copy: the original appended to the caller's list in place.
        models = [root_model] if models is None else list(models)
        if root_model not in models:
            models.append(root_model)
        super().__init__(
            f"{root_model.__name__.lower()}_model_validate_json_dict",
            input_keys=["inspect_getsource", "prompt"],
        )
        self.root_model = root_model
        self.models = models
        # Concatenated class sources, passed to the LM as the `inspect_getsource` input.
        self.model_sources = "\n".join(inspect.getsource(model) for model in self.models)

    def validate_root_model(self, output) -> bool:
        """Return True if ``output`` is a Python literal that validates as ``root_model``.

        Every parse or validation failure yields False instead of raising, so
        ``validate_output`` can turn it into an ``Assert`` message for the LM.
        """
        try:
            value = ast.literal_eval(output)
            model = self.root_model.parse_obj(value)
        except (
            ValueError,
            SyntaxError,
            TypeError,
            MemoryError,
            RecursionError,
            ValidationError,
        ):
            # ast.literal_eval raises ValueError/SyntaxError/TypeError/MemoryError/
            # RecursionError on malformed input; the original let most of these escape.
            return False
        return isinstance(model, self.root_model)

    def validate_output(self, output):
        """Parse raw LM output into a ``root_model`` instance.

        Anything before the first ``{`` is discarded. When the remainder does
        not validate, a dspy ``Assert`` is triggered with a corrective message.

        Returns:
            An instance of ``root_model`` built from the parsed dict literal.
        """
        output = strip_text_before_first_open_brace(str(output))
        Assert(
            self.validate_root_model(output),
            f"""You need to create a kwargs dict for {self.root_model.__name__}""",
        )
        value = ast.literal_eval(output)
        return self.root_model.parse_obj(value)

    def forward(self, **kwargs):
        """Run generation for ``kwargs["prompt"]``, showing the LM the model sources.

        Raises:
            KeyError: if ``prompt`` is not supplied.
        """
        return super().forward(inspect_getsource=self.model_sources, prompt=kwargs["prompt"])
# Free-form, intentionally underspecified description of a weather API (the
# text itself says the format and response structure are "not explicitly
# defined"). main() passes it to RefineTextForModelSignature to be rewritten
# as kwargs for APIEndpoint — presumably to exercise that refinement step.
# NOTE: this string is sent to the LM verbatim; editing it changes the prompt.
api_description = """
Imagine a digital portal where users can inquire about meteorological conditions.
This portal is accessible through a web interface that interacts with a backend service.
The service is invoked by sending a request to a specific endpoint.
This request is crafted using a standard protocol for web communication.
The endpoint's location is a mystery, hidden within the path '/forecast/today'.
Users pose their inquiries by specifying a geographical area of interest,
though the exact format of this specification is left to the user's imagination.
Upon successful request processing, the service responds with a structured
summary of the weather, encapsulating details such as temperature, humidity,
and wind speed. However, the structure of this response and the means of
accessing the weather summary are not explicitly defined.
"""
# Pydantic schema for one HTTP API endpoint; main() uses it as the target model
# that GenPydanticModel validates the LM's kwargs dict against.
# NOTE: inspect.getsource(APIEndpoint) is sent to the LM (by main() and by
# GenPydanticModel), so any edit inside this class body — including comments
# or a docstring — changes the prompt. Keep documentation above this line.
class APIEndpoint(BaseModel):
    method: str = Field(..., description="HTTP method of the API endpoint")
    url: str = Field(..., description="URL of the API endpoint")
    description: str = Field(..., description="Description of what the API endpoint does")
    response: str = Field(..., description="Response from the API endpoint")
    query_params: Optional[Dict[str, Any]] = Field(None, description="Query parameters")
# DSPy signature for the first stage of main(): given Pydantic model source
# and a free-form description, produce the description as a kwargs dict.
class RefineTextForModelSignature(Signature):
    """Change the description to fit the model"""
    # NOTE(review): per the DSPy Signature docs, the class docstring above is
    # used as the task instruction and each field's `desc` appears in the
    # prompt, in declaration order. Treat them as runtime strings, not docs.
    model_definitions = InputField(desc="Pydantic model class definitions as a string")
    raw_description = InputField(desc="The raw text description to be refined")
    refined_description = OutputField(
        desc="Description as kwargs dict for the Pydantic model that only contains json valid values")
def main():
    """Run the two-stage demo and print the result.

    Stage 1 asks the LM (chain-of-thought) to rewrite ``api_description`` as a
    kwargs dict for ``APIEndpoint``, given that model's source code. Stage 2
    passes the rewritten text to ``GenPydanticModel`` and prints whatever its
    ``forward`` returns.
    """
    refiner = dspy.ChainOfThought(RefineTextForModelSignature)
    schema_source = inspect.getsource(APIEndpoint)
    refinement = refiner.forward(
        model_definitions=schema_source,
        raw_description=api_description,
    )
    kwargs_text = refinement.refined_description

    generator = GenPydanticModel(root_model=APIEndpoint)
    endpoint = generator.forward(prompt=kwargs_text)
    print(endpoint)


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
method='GET' url='/forecast/today' description='Provides a summary of the weather for a specified geographical area' response='Structured summary of the weather, including temperature, humidity, and wind speed' query_params={'geographical_area': 'str'}