Last active: June 14, 2024 02:16
Save karljuhlpep/cd7031dfbb43f0c1eb26ed6601ebae03 to your computer and use it in GitHub Desktop.
DSPy Module - CodeGen + Debugging
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import subprocess
import sys

import dspy
# NOTE(review): the original author marked this module as untested. The
# signatures below were repaired so their field names match the keyword
# arguments used in ``forward``; end-to-end quality still depends on the
# configured language model and should be validated.


class _CorrectnessCheck(dspy.Signature):
    """Decide whether the code is correct, judging by its tests and the output of running them. Answer "yes" or "no"."""

    code = dspy.InputField()
    tests = dspy.InputField()
    code_execution_output = dspy.InputField(desc="stdout followed by stderr of running the code and tests")
    answer = dspy.OutputField(desc='exactly "yes" or "no"')


class IterativeCodeRefinement(dspy.Module):
    """Generate code for a task, test it, and iteratively refine it.

    Pipeline: task -> pseudocode -> code -> tests. The code and tests are
    executed together in a fresh Python subprocess, and an LM judge decides
    whether the run is correct. While the judge does not answer "yes", and
    for at most ``max_iterations`` rounds, the pseudocode, code and tests are
    regenerated from the previous attempt and its execution output.

    WARNING: ``execute_code`` runs LM-generated code on the host with the
    current user's privileges. Only run this module inside a sandbox.
    """

    def __init__(self, max_iterations=5, execution_timeout=30):
        """Build the DSPy sub-modules.

        Args:
            max_iterations: Maximum number of refinement rounds after the
                first attempt. This bounds the loop when the judge never
                answers "yes".
            execution_timeout: Seconds a single execution may run before it
                is killed and reported as a failure. ``None`` disables the
                limit.
        """
        super().__init__()
        self.max_iterations = max_iterations
        self.execution_timeout = execution_timeout
        self.generate_pseudocode = dspy.ChainOfThought("task -> pseudocode")
        self.pseudocode_to_code = dspy.ChainOfThought("task, pseudocode -> code")
        self.generate_code_tests = dspy.ChainOfThought("task, code -> tests")
        self.check_code_correctness = dspy.ChainOfThought(_CorrectnessCheck)
        # Field names must be valid identifiers and must match the keyword
        # arguments used in forward().
        self.refine_pseudocode = dspy.ChainOfThought(
            "task, previous_code, stdout, stderr, exit_code -> pseudocode"
        )
        self.refine_code_with_previous_context = dspy.ChainOfThought(
            "task, pseudocode, previous_code, stderr -> code"
        )
        self.refine_tests_with_previous_context = dspy.ChainOfThought(
            "task, code, previous_tests, stderr -> tests"
        )

    def execute_code(self, code, timeout=None):
        """Run ``code`` in a fresh Python interpreter and capture its output.

        Args:
            code: Python source to execute. ``forward`` passes the code and
                its tests joined by a newline.
            timeout: Seconds before the child process is killed. ``None``
                waits indefinitely.

        Returns:
            A ``(stdout, stderr, returncode)`` tuple. On timeout, returns an
            empty stdout, a diagnostic message as stderr, and return code -1.
        """
        # Pass an argument list with shell=False. With shell=True, the extra
        # list items become arguments to the shell itself, so "-c <code>"
        # never reached Python. sys.executable avoids relying on a "python"
        # binary being on PATH. stdin is closed so that input() calls fail
        # fast instead of blocking.
        try:
            result = subprocess.run(
                [sys.executable, "-c", code],
                capture_output=True,
                text=True,
                stdin=subprocess.DEVNULL,
                timeout=timeout,
            )
        except subprocess.TimeoutExpired:
            return "", f"Execution timed out after {timeout} seconds.", -1
        return result.stdout, result.stderr, result.returncode

    def _run(self, code, tests):
        """Execute ``code`` followed by ``tests`` under the configured timeout."""
        return self.execute_code(code + "\n" + tests, timeout=self.execution_timeout)

    def _judge(self, code, tests, stdout, stderr):
        """Ask the LM judge whether the run is correct; return a bool."""
        answer = self.check_code_correctness(
            code=code, tests=tests, code_execution_output=stdout + stderr
        ).answer
        # Accept replies such as "Yes." or "yes, all tests pass".
        return str(answer).strip().lower().startswith("yes")

    def forward(self, task):
        """Generate, test and refine code for ``task``.

        Returns:
            A dict with ``final_code`` and ``final_tests`` (the last attempt),
            ``output``/``errors`` (stdout/stderr of its execution) and
            ``is_correct`` (the judge's final verdict). ``is_correct`` is
            False when ``max_iterations`` ran out first.
        """
        pseudocode = self.generate_pseudocode(task=task).pseudocode
        code = self.pseudocode_to_code(task=task, pseudocode=pseudocode).code
        tests = self.generate_code_tests(task=task, code=code).tests
        stdout, stderr, returncode = self._run(code, tests)
        is_correct = self._judge(code, tests, stdout, stderr)

        rounds = 0
        while not is_correct and rounds < self.max_iterations:
            rounds += 1
            pseudocode = self.refine_pseudocode(
                task=task,
                previous_code=code,
                stdout=stdout,
                stderr=stderr,
                exit_code=str(returncode),
            ).pseudocode
            # Both refiners see the previous attempt and its stderr, which is
            # still the output of that attempt at this point.
            code = self.refine_code_with_previous_context(
                task=task, pseudocode=pseudocode, previous_code=code, stderr=stderr
            ).code
            tests = self.refine_tests_with_previous_context(
                task=task, code=code, previous_tests=tests, stderr=stderr
            ).tests
            stdout, stderr, returncode = self._run(code, tests)
            is_correct = self._judge(code, tests, stdout, stderr)

        # code/tests are always bound, even when the first attempt passes.
        return {
            "final_code": code,
            "final_tests": tests,
            "output": stdout,
            "errors": stderr,
            "is_correct": is_correct,
        }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.