Skip to content

Instantly share code, notes, and snippets.

@seanchatmangpt
Created February 14, 2024 23:01
Show Gist options
  • Save seanchatmangpt/1e9db2263bbf7e6d5e77a23871a38e70 to your computer and use it in GitHub Desktop.
Save seanchatmangpt/1e9db2263bbf7e6d5e77a23871a38e70 to your computer and use it in GitHub Desktop.
Generate Python primitives, with tests.
import ast
from dspy import Assert
from rdddy.generators.gen_module import GenModule
def is_primitive_type(data_type):
    """Report whether *data_type* is one of the supported built-in primitive types.

    Supported types: bool, dict, float, int, list, set, str, tuple. Membership is
    exact, so subclasses of these builtins (and any other object) return False.
    """
    allowed = frozenset((bool, dict, float, int, list, set, str, tuple))
    return data_type in allowed
class GenPythonPrimitive(GenModule):
    """Generate a value of one built-in Python primitive type with a language model.

    The model's text output is parsed with ``ast.literal_eval`` and checked
    against ``primitive_type``. A mismatch is reported through dspy's
    ``Assert`` together with a corrective message.

    Args:
        primitive_type: One of the types accepted by ``is_primitive_type``
            (int, float, str, bool, list, tuple, dict, set).
        lm: Optional language model, passed through to ``GenModule``.

    Raises:
        ValueError: If ``primitive_type`` is not a supported primitive type.
    """

    # Every exception ast.literal_eval is documented to raise on malformed
    # input. Catching only SyntaxError would let, e.g., a bare word such as
    # "Jupiter" (parses, but is not a literal -> ValueError) escape instead of
    # simply failing validation.
    _LITERAL_EVAL_ERRORS = (
        SyntaxError,
        ValueError,
        TypeError,
        MemoryError,
        RecursionError,
    )

    def __init__(self, primitive_type, lm=None):
        if not is_primitive_type(primitive_type):
            # getattr: a non-type argument (e.g. the string "int") has no
            # __name__; without this the AttributeError would mask the
            # intended ValueError.
            type_name = getattr(primitive_type, "__name__", repr(primitive_type))
            raise ValueError(
                f"primitive type {type_name} must be a Python primitive type"
            )
        # NOTE(review): presumably GenModule uses this string as the output
        # field name, steering the model toward a literal of this type — confirm.
        super().__init__(f"{primitive_type.__name__}_str_for_ast_literal_eval", lm)
        self.primitive_type = primitive_type

    def _matches_type(self, value) -> bool:
        """Return True if *value* is an instance of ``self.primitive_type``.

        ``bool`` subclasses ``int``, so ``isinstance(True, int)`` is True. A
        boolean is therefore only accepted when ``bool`` itself was requested.
        """
        if isinstance(value, bool) and self.primitive_type is not bool:
            return False
        return isinstance(value, self.primitive_type)

    def validate_primitive(self, output) -> bool:
        """Return True if *output* parses as a literal of ``self.primitive_type``.

        Malformed input returns False instead of raising, so the caller can
        report the problem through ``Assert``.
        """
        try:
            value = ast.literal_eval(output)
        except self._LITERAL_EVAL_ERRORS:
            return False
        return self._matches_type(value)

    def validate_output(self, output):
        """Validate *output* and return it parsed into a Python value.

        Args:
            output: The model's text answer, expected to be a Python literal.

        Returns:
            The parsed value (a ``set`` when ``primitive_type`` is ``set``).

        Raises:
            Whatever dspy's ``Assert`` raises on a failed check (the
            accompanying tests expect ``DSPyAssertionError``).
        """
        Assert(
            self.validate_primitive(output),
            f"You need to create a valid python {self.primitive_type.__name__} "
            f"primitive type for \n{self.output_key}\n"
            f"You will be penalized for not returning only a {self.primitive_type.__name__} for "
            f"{self.output_key}",
        )
        data = ast.literal_eval(output)
        # Already a set whenever the Assert above passed. Kept as a safeguard
        # in case Assert does not raise on failure (NOTE(review): depends on
        # dspy assertion settings — confirm).
        if self.primitive_type is set:
            data = set(data)
        return data

    def __call__(self, prompt):
        """Generate a value for *prompt*; shorthand for ``self.forward(prompt=prompt)``."""
        return self.forward(prompt=prompt)
class GenDict(GenPythonPrimitive):
    """Generate a ``dict`` value from a prompt.

    Args:
        lm: Optional language model, forwarded to ``GenPythonPrimitive``.
    """

    def __init__(self, lm=None):
        super().__init__(primitive_type=dict, lm=lm)
class GenList(GenPythonPrimitive):
    """Generate a ``list`` value from a prompt.

    Args:
        lm: Optional language model, forwarded to ``GenPythonPrimitive``.
    """

    def __init__(self, lm=None):
        super().__init__(primitive_type=list, lm=lm)
class GenBool(GenPythonPrimitive):
    """Generate a ``bool`` value from a prompt.

    Args:
        lm: Optional language model, forwarded to ``GenPythonPrimitive``.
    """

    def __init__(self, lm=None):
        super().__init__(primitive_type=bool, lm=lm)
class GenInt(GenPythonPrimitive):
    """Generate an ``int`` value from a prompt.

    Args:
        lm: Optional language model, forwarded to ``GenPythonPrimitive``.
    """

    def __init__(self, lm=None):
        super().__init__(primitive_type=int, lm=lm)
class GenFloat(GenPythonPrimitive):
    """Generate a ``float`` value from a prompt.

    Args:
        lm: Optional language model, forwarded to ``GenPythonPrimitive``.
    """

    def __init__(self, lm=None):
        super().__init__(primitive_type=float, lm=lm)
class GenTuple(GenPythonPrimitive):
    """Generate a ``tuple`` value from a prompt.

    Args:
        lm: Optional language model, forwarded to ``GenPythonPrimitive``.
    """

    def __init__(self, lm=None):
        super().__init__(primitive_type=tuple, lm=lm)
class GenSet(GenPythonPrimitive):
    """Generate a ``set`` value from a prompt.

    Args:
        lm: Optional language model, forwarded to ``GenPythonPrimitive``.
    """

    def __init__(self, lm=None):
        super().__init__(primitive_type=set, lm=lm)
class GenStr(GenPythonPrimitive):
    """Generate a ``str`` value from a prompt.

    Args:
        lm: Optional language model, forwarded to ``GenPythonPrimitive``.
    """

    def __init__(self, lm=None):
        super().__init__(primitive_type=str, lm=lm)
def main():
    """Demo: ask ``GenTuple`` for the planets from largest to smallest and check the result.

    The check is an exact ``assert`` against a fixed tuple, so any differently
    ordered or worded model answer makes the demo fail.
    """
    # Ask for a *tuple*: GenTuple only accepts output that literal-evals to a
    # tuple, so asking for a "list" invites a "[...]" answer that fails validation.
    result = GenTuple()(
        "Create a tuple of planets in our solar system sorted by largest to smallest"
    )
    assert result == (
        "Jupiter",
        "Saturn",
        "Uranus",
        "Neptune",
        "Earth",
        "Venus",
        "Mars",
        "Mercury",
    )
    print(f"The planets of the solar system are {result}")


if __name__ == "__main__":
    main()
from rdddy.generators.gen_python_primitive import (
GenPythonPrimitive,
) # replace with the actual import
import pytest
from unittest.mock import patch, MagicMock
from dspy import ChainOfThought, settings, OpenAI, DSPyAssertionError
@pytest.fixture
def gen_python_primitive():
    """Yield a list-typed ``GenPythonPrimitive`` with dspy/OpenAI setup stubbed.

    ``settings.configure`` is replaced by a mock and ``OpenAI.__init__`` by a
    mock returning None. Both patches stay active for the test's duration
    because the instance is yielded from inside them.
    """
    configure_stub = patch.object(settings, "configure")
    openai_init_stub = patch.object(OpenAI, "__init__", return_value=None)
    with configure_stub:
        with openai_init_stub:
            generator = GenPythonPrimitive(list)
            yield generator
@patch("dspy.predict.Predict.forward")
@patch("rdddy.generators.gen_module.ChainOfThought")
@patch("ast.literal_eval")
def test_forward_success(
    mock_literal_eval, mock_chain_of_thought, mock_predict, gen_python_primitive
):
    """``forward`` returns the parsed list when every stubbed step yields a valid list."""
    raw_answer = "['Jupiter', 'Saturn']"
    expected = ["Jupiter", "Saturn"]

    # Parsing is stubbed to a separate list object equal to ``expected``.
    mock_literal_eval.return_value = list(expected)
    # Both model-facing stubs answer with the same list literal text.
    mock_chain_of_thought.return_value.get.return_value = raw_answer
    mock_predict.return_value.get.return_value = raw_answer

    outcome = gen_python_primitive.forward(prompt="Create a list of planets")
    assert outcome == expected
@patch("dspy.predict.Predict.forward")
@patch("rdddy.generators.gen_module.ChainOfThought")
@patch("ast.literal_eval", side_effect=SyntaxError)
def test_forward_syntax_error(
    mock_literal_eval, mock_chain_of_thought, mock_predict, gen_python_primitive
):
    """``forward`` raises ``DSPyAssertionError`` when the output never parses.

    ``ast.literal_eval`` is patched to always raise ``SyntaxError``. The two
    ChainOfThought stubs (initial call, correction call) therefore both
    produce unparseable output.
    """
    raw_answer = "{'Jupiter', 'Saturn'}"

    def make_step():
        # A fresh ChainOfThought stand-in whose .get() returns the raw answer.
        return MagicMock(get=MagicMock(return_value=raw_answer))

    mock_predict.return_value.get.return_value = raw_answer
    # One stub per attempt: the initial call, then the correction call.
    mock_chain_of_thought.side_effect = [make_step() for _attempt in ("initial", "correction")]

    with pytest.raises(DSPyAssertionError):
        gen_python_primitive.forward(prompt="Create a list with syntax error")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment