Created
February 14, 2024 23:01
-
-
Save seanchatmangpt/1e9db2263bbf7e6d5e77a23871a38e70 to your computer and use it in GitHub Desktop.
Generate Python primitives with tests.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ast | |
from dspy import Assert | |
from rdddy.generators.gen_module import GenModule | |
def is_primitive_type(data_type):
    """Return True if *data_type* is one of the supported built-in primitive types.

    Supported types are int, float, str, bool, list, tuple, dict and set.
    Membership is tested by hashing, so *data_type* must be hashable.
    """
    supported = frozenset((int, float, str, bool, list, tuple, dict, set))
    return data_type in supported
class GenPythonPrimitive(GenModule):
    """Generate a value of a single Python primitive type with an LM.

    Model output is validated by ``validate_output``, which requires the text
    to ``ast.literal_eval`` to an instance of ``primitive_type``. It is
    presumably invoked from ``GenModule.forward``, which is not shown here.
    """

    def __init__(self, primitive_type, lm=None):
        """Create a generator for ``primitive_type``.

        Args:
            primitive_type: One of int, float, str, bool, list, tuple, dict, set.
            lm: Optional language model, forwarded unchanged to ``GenModule``.

        Raises:
            ValueError: If ``primitive_type`` is not a supported primitive type.
        """
        if not is_primitive_type(primitive_type):
            # A non-type argument (e.g. the string "int") has no __name__;
            # fall back to repr so the intended ValueError is not masked by
            # an AttributeError while building the message.
            type_name = getattr(primitive_type, "__name__", repr(primitive_type))
            raise ValueError(
                f"primitive type {type_name} must be a Python primitive type"
            )
        # The first argument names the GenModule signature/output field
        # (presumably the output key; defined in GenModule, not shown).
        super().__init__(f"{primitive_type.__name__}_str_for_ast_literal_eval", lm)
        self.primitive_type = primitive_type

    def validate_primitive(self, output) -> bool:
        """Return True if ``output`` literal-evaluates to ``self.primitive_type``.

        Unparseable output is reported as False rather than raised, so the
        ``Assert`` in ``validate_output`` can request a corrected answer.
        """
        try:
            value = ast.literal_eval(output)
        except (SyntaxError, ValueError, TypeError, MemoryError, RecursionError):
            # literal_eval raises ValueError for non-literal expressions such
            # as a bare name, not only SyntaxError; all of these mean "invalid".
            return False
        # NOTE(review): bool subclasses int, so an int generator also accepts
        # "True"; and a float generator rejects "3". Confirm this is intended.
        return isinstance(value, self.primitive_type)

    def validate_output(self, output):
        """Assert that ``output`` is a valid literal and return the parsed value.

        Args:
            output: Text produced by the model.

        Returns:
            ``ast.literal_eval(output)``; for ``set`` generators the result
            is additionally passed through ``set()``.
        """
        # self.output_key is provided by GenModule (not shown here).
        Assert(
            self.validate_primitive(output),
            f"You need to create a valid python {self.primitive_type.__name__} "
            f"primitive type for \n{self.output_key}\n"
            f"You will be penalized for not returning only a {self.primitive_type.__name__} for "
            f"{self.output_key}",
        )
        data = ast.literal_eval(output)
        if self.primitive_type is set:
            data = set(data)
        return data

    def __call__(self, prompt):
        """Shorthand for ``self.forward(prompt=prompt)``."""
        return self.forward(prompt=prompt)
class GenDict(GenPythonPrimitive):
    """Generate a Python ``dict`` literal with an LM."""

    def __init__(self, lm=None):
        """Create a dict generator; ``lm`` is forwarded to the base class."""
        super().__init__(primitive_type=dict, lm=lm)
class GenList(GenPythonPrimitive):
    """Generate a Python ``list`` literal with an LM."""

    def __init__(self, lm=None):
        """Create a list generator; ``lm`` is forwarded to the base class."""
        super().__init__(primitive_type=list, lm=lm)
class GenBool(GenPythonPrimitive):
    """Generate a Python ``bool`` literal (``True``/``False``) with an LM."""

    def __init__(self, lm=None):
        """Create a bool generator; ``lm`` is forwarded to the base class."""
        super().__init__(primitive_type=bool, lm=lm)
class GenInt(GenPythonPrimitive):
    """Generate a Python ``int`` literal with an LM.

    NOTE(review): validation uses ``isinstance(..., int)`` and bool subclasses
    int, so ``True``/``False`` outputs also pass. Confirm this is acceptable.
    """

    def __init__(self, lm=None):
        """Create an int generator; ``lm`` is forwarded to the base class."""
        super().__init__(primitive_type=int, lm=lm)
class GenFloat(GenPythonPrimitive):
    """Generate a Python ``float`` literal with an LM.

    Validation uses ``isinstance(..., float)``, so an integer literal such
    as ``3`` is rejected; the model must write ``3.0``.
    """

    def __init__(self, lm=None):
        """Create a float generator; ``lm`` is forwarded to the base class."""
        super().__init__(primitive_type=float, lm=lm)
class GenTuple(GenPythonPrimitive):
    """Generate a Python ``tuple`` literal with an LM."""

    def __init__(self, lm=None):
        """Create a tuple generator; ``lm`` is forwarded to the base class."""
        super().__init__(primitive_type=tuple, lm=lm)
class GenSet(GenPythonPrimitive):
    """Generate a Python ``set`` literal with an LM."""

    def __init__(self, lm=None):
        """Create a set generator; ``lm`` is forwarded to the base class."""
        super().__init__(primitive_type=set, lm=lm)
class GenStr(GenPythonPrimitive):
    """Generate a Python ``str`` literal (a quoted string) with an LM."""

    def __init__(self, lm=None):
        """Create a str generator; ``lm`` is forwarded to the base class."""
        super().__init__(primitive_type=str, lm=lm)
def main():
    """Demo: ask the LM for the planets as a tuple and check the answer."""
    # Ask for a *tuple*: GenTuple only accepts output that literal-evaluates
    # to a tuple, so asking for a "list" invites an answer like "[...]"
    # that fails validation on every attempt.
    result = GenTuple()(
        "Create a tuple of planets in our solar system sorted by largest to smallest"
    )
    # Planets by diameter, largest first. The exact-match check depends on
    # the model's nondeterministic answer; this is a demo, not a unit test.
    expected = (
        "Jupiter",
        "Saturn",
        "Uranus",
        "Neptune",
        "Earth",
        "Venus",
        "Mars",
        "Mercury",
    )
    assert result == expected
    print(f"The planets of the solar system are {result}")


if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from rdddy.generators.gen_python_primitive import ( | |
GenPythonPrimitive, | |
) # replace with the actual import | |
import pytest | |
from unittest.mock import patch, MagicMock | |
from dspy import ChainOfThought, settings, OpenAI, DSPyAssertionError | |
@pytest.fixture
def gen_python_primitive():
    """Yield a list-typed GenPythonPrimitive with dspy setup stubbed out.

    ``settings.configure`` and ``OpenAI.__init__`` stay patched for the
    whole test, because the ``yield`` happens inside both patch contexts.
    """
    with patch.object(settings, "configure"):
        with patch.object(OpenAI, "__init__", return_value=None):
            generator = GenPythonPrimitive(list)
            yield generator
@patch("dspy.predict.Predict.forward")
@patch("rdddy.generators.gen_module.ChainOfThought")
@patch("ast.literal_eval")
def test_forward_success(
    literal_eval_mock, chain_of_thought_mock, predict_forward_mock, gen_python_primitive
):
    """forward() returns the parsed list when the model answer is a list."""
    # patch decorators apply bottom-up: the first mock is ast.literal_eval.
    raw_answer = "['Jupiter', 'Saturn']"

    predict_forward_mock.return_value.get.return_value = raw_answer
    chain_of_thought_mock.return_value.get.return_value = raw_answer
    literal_eval_mock.return_value = ["Jupiter", "Saturn"]

    outcome = gen_python_primitive.forward(prompt="Create a list of planets")

    assert outcome == ["Jupiter", "Saturn"]
@patch("dspy.predict.Predict.forward")
@patch("rdddy.generators.gen_module.ChainOfThought")
@patch("ast.literal_eval", side_effect=SyntaxError)
def test_forward_syntax_error(
    literal_eval_mock, chain_of_thought_mock, predict_forward_mock, gen_python_primitive
):
    """forward() raises DSPyAssertionError when the answer never parses.

    ``ast.literal_eval`` is patched to always raise SyntaxError, so the
    answer below is rejected regardless of its content.
    """
    bad_answer = "{'Jupiter', 'Saturn'}"
    predict_forward_mock.return_value.get.return_value = bad_answer

    # Two stubs, consumed by the first and second calls to the patched
    # ChainOfThought (the initial call and the correction call).
    chain_of_thought_mock.side_effect = [
        MagicMock(get=MagicMock(return_value=bad_answer)) for _ in range(2)
    ]

    with pytest.raises(DSPyAssertionError):
        gen_python_primitive.forward(prompt="Create a list with syntax error")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment