Created
January 30, 2024 18:48
-
-
Save seanchatmangpt/3196c51593c524f4b7cb7a6f66606ffa to your computer and use it in GitHub Desktop.
A module with self-correction
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging # Import the logging module | |
from dspy import Module, OpenAI, settings, ChainOfThought, Assert | |
logger = logging.getLogger(__name__)  # Module-level logger named after this module
logger.setLevel(logging.ERROR)  # This logger's own threshold: records below ERROR are dropped
class GenModule(Module):
    """Generate a single output field and retry once if validation fails.

    The module builds two ``ChainOfThought`` predictors from string
    signatures: one mapping the input keys to ``output_key``, and a
    correction variant that additionally receives an ``error`` input
    describing why the first attempt was rejected.

    Subclasses must override :meth:`validate_output`.
    """

    def __init__(self, output_key, input_keys: list[str] = None, lm=None):
        """Build the generation and correction predictors.

        Args:
            output_key: Name of the field to generate.
            input_keys: Names of the input fields. Defaults to ``["prompt"]``.
            lm: Language model passed to ``settings.configure``. When omitted,
                ``OpenAI(max_tokens=500)`` is created and used.
        """
        super().__init__()

        if lm is None:
            lm = OpenAI(max_tokens=500)
        # NOTE(review): the language model is configured globally whether it
        # was supplied by the caller or created above; otherwise a
        # caller-supplied ``lm`` would be silently ignored. Confirm intent.
        settings.configure(lm=lm)

        # Always assign input_keys: previously it was only set when the
        # argument was None, so passing explicit keys raised AttributeError
        # below. Copy the list so later caller-side mutation has no effect.
        self.input_keys = ["prompt"] if input_keys is None else list(input_keys)
        self.output_key = output_key

        # String signatures for the generation and correction predictors.
        joined_inputs = ', '.join(self.input_keys)
        self.signature = joined_inputs + f" -> {self.output_key}"
        self.correction_signature = joined_inputs + f", error -> {self.output_key}"

        # DSPy modules for generation and correction
        self.generate = ChainOfThought(self.signature)
        self.correct_generate = ChainOfThought(self.correction_signature)

    def forward(self, **kwargs):
        """Generate the output, validate it, and attempt one correction.

        Args:
            **kwargs: Input field values forwarded to the predictors.

        Returns:
            Whatever :meth:`validate_output` returns for the accepted output.

        Raises:
            AssertionError, ValueError: If the corrected output is rejected
                by :meth:`validate_output` as well.
        """
        # Generate the output using provided inputs
        gen_result = self.generate(**kwargs)
        output = gen_result.get(self.output_key)

        # Try validating the output
        try:
            return self.validate_output(output)
        except (AssertionError, ValueError) as error:
            logger.error(error)

            # Single correction attempt: feed the validation error back in.
            # A failure here propagates to the caller (no further retries).
            corrected_result = self.correct_generate(**kwargs, error=str(error))
            corrected_output = corrected_result.get(self.output_key)
            return self.validate_output(corrected_output)

    def validate_output(self, output):
        """Validate a generated output; must be overridden in a subclass.

        ``forward`` treats ``AssertionError`` and ``ValueError`` raised here
        as a rejected output and triggers the correction attempt.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError("Validation logic should be implemented in subclass")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment