Created
December 30, 2024 16:51
-
-
Save skrawcz/3d84b70a9df0e71048b69339d130aa62 to your computer and use it in GitHub Desktop.
gist for pytest blog post
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pytest | |
@pytest.fixture(scope="module")
def database_connection():
    """Provide a database client to tests as a module-scoped fixture.

    Yields:
        The ``SomeDBClient`` instance created by this fixture.
    """
    client = SomeDBClient()
    yield client
    # Code after the yield is the fixture's teardown step.
    print("\nStopped client:\n")
def test_my_function(database_connection):
    """Example test that requests the ``database_connection`` fixture.

    pytest matches the parameter name to the fixture function of the same
    name and passes in the value that fixture provides.
    """
    ...
def test_my_other_function(database_connection):
    """Second example test that requests the ``database_connection`` fixture.

    pytest matches the parameter name to the fixture function of the same
    name and passes in the value that fixture provides.
    """
    ...
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pytest | |
from our_agent_application import agent_builder, agent_runner # some functions that build and run our agent | |
from burr.core import state | |
# the following is required to run file based tests | |
from burr.testing import pytest_generate_tests # noqa: F401 | |
@pytest.mark.file_name("e2e.json")  # fixture file holding the expected inputs and outputs
def test_an_agent_e2e(input_state, expected_state, results_bag):
    """Run the agent end-to-end on one case and compare the result with the expected state.

    Args:
        input_state: serialized starting state; deserialized before use.
        expected_state: serialized state the agent is expected to produce.
        results_bag: per-test container the inputs, outputs and verdict are stored on.
    """
    starting_state = state.State.deserialize(input_state)
    target_state = state.State.deserialize(expected_state)

    # exercise the agent: build it from the starting state, then run it
    # (the builder is e.g. something like some_actions._build_application(...))
    final_state = agent_runner(agent_builder(starting_state))

    # record everything about this case so it can be inspected later
    results_bag.input_state = starting_state
    results_bag.expected_state = target_state
    results_bag.output_state = final_state
    results_bag.foo = "bar"

    # TODO: choose appropriate way to evaluate the output
    # e.g. exact match, fuzzy match, LLM grade, etc.
    # Here: exact match on all values in state.
    is_exact = final_state == target_state
    # For output that varies, you can do something like this instead:
    #   assert 'some value' in output_state["response"]["content"]
    # Or have an LLM grade things -- you need to create the llm_evaluator function:
    #   assert llm_evaluator("are these two equivalent responses. Respond with Y for yes, N for no",
    #       output_state["response"]["content"], expected_state["response"]["content"]) == "Y"

    # store the verdict in the results bag
    results_bag.correct = is_exact
    # asserts go last so the results above are recorded even when the check fails
    assert is_exact
import pytest | |
from our_agent_application import agent_builder, agent_runner # some functions that build and run our agent | |
from burr.core import state | |
# the following is required to run file based tests | |
from burr.testing import pytest_generate_tests # noqa: F401 | |
from burr.tracking import LocalTrackingClient | |
@pytest.fixture
def tracker():
    """Provide a tracking client so test runs are logged for the Burr UI.

    Yields:
        A ``LocalTrackingClient`` constructed with the name "pytest-runs".
    """
    tracking_client = LocalTrackingClient("pytest-runs")
    # optionally turn on opentelemetry tracing
    yield tracking_client
@pytest.mark.file_name("e2e.json")  # fixture file holding the expected inputs and outputs
def test_an_agent_e2e_with_tracker(input_state, expected_state, results_bag, tracker, request):
    """Run the agent end-to-end with tracking enabled and compare against the expected state.

    Fixtures used:
    - results_bag: to log results -- comes from pytest-harvest
    - tracker: to track runs -- comes from the tracker() fixture
    - request: to get the test name -- comes from pytest
    """
    starting_state = state.State.deserialize(input_state)
    target_state = state.State.deserialize(expected_state)

    # create the agent -- using the parametrizable builder
    # (e.g. something like some_actions._build_application(...));
    # the test's node name is passed as the partition key
    agent = agent_builder(
        starting_state,
        partition_key=request.node.name,
        tracker=tracker,
    )
    final_state = agent_runner(agent)

    # record everything about this case so it can be inspected later
    results_bag.input_state = starting_state
    results_bag.expected_state = target_state
    results_bag.output_state = final_state
    results_bag.foo = "bar"

    # TODO: choose appropriate way to evaluate the output
    # e.g. exact match, fuzzy match, LLM grade, etc.
    # Here: exact match on all values in state.
    is_exact = final_state == target_state
    # For output that varies, you can do something like this instead:
    #   assert 'some value' in output_state["response"]["content"]
    # Or have an LLM grade things -- you need to create the llm_evaluator function:
    #   assert llm_evaluator("are these two equivalent responses. Respond with Y for yes, N for no",
    #       output_state["response"]["content"], expected_state["response"]["content"]) == "Y"

    # store the verdict in the results bag
    results_bag.correct = is_exact
    # asserts go last so the results above are recorded even when the check fails
    assert is_exact
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pytest | |
from our_agent_application import prompt_for_more | |
from burr.core import state | |
# the following is required to run file based tests | |
from burr.testing import pytest_generate_tests # noqa: F401 | |
@pytest.mark.file_name("prompt_for_more.json")  # fixture file holding the expected inputs and outputs
def test_an_agent_action(input_state, expected_state, results_bag):
    """Run a single agent action on one case and compare the new state with the expected one.

    Args:
        input_state: serialized starting state; deserialized before use.
        expected_state: serialized state the action is expected to produce.
        results_bag: per-test container the inputs, outputs and verdict are stored on.
    """
    starting_state = state.State.deserialize(input_state)
    target_state = state.State.deserialize(expected_state)

    # exercise one action of our agent; only the second element of the
    # returned pair is checked here
    _ignored, new_state = prompt_for_more(starting_state)

    # record everything about this case so it can be inspected later
    results_bag.input_state = starting_state
    results_bag.expected_state = target_state
    results_bag.output_state = new_state
    results_bag.foo = "bar"

    # TODO: choose appropriate way to evaluate the output
    # e.g. exact match, fuzzy match, LLM grade, etc.
    # Here: exact match on all values in state.
    is_exact = new_state == target_state
    # For output that varies, you can do something like this instead:
    #   assert 'some value' in output_state["response"]["content"]
    # Or have an LLM grade things -- you need to create the llm_evaluator function:
    #   assert llm_evaluator("are these two equivalent responses. Respond with Y for yes, N for no",
    #       output_state["response"]["content"], expected_state["response"]["content"]) == "Y"

    # store the verdict in the results bag
    results_bag.correct = is_exact
    # asserts go last so the results above are recorded even when the check fails
    assert is_exact
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pytest | |
import subprocess | |
def _git_output(*args):
    """Run ``git`` with ``args`` and return its stripped stdout decoded as UTF-8.

    Returns ``None`` when git exits with a non-zero status or cannot be
    started at all (``OSError``, e.g. the executable is not installed).
    """
    try:
        raw = subprocess.check_output(["git", *args])
    except (subprocess.CalledProcessError, OSError):
        return None
    return raw.strip().decode("utf-8")


@pytest.fixture
def git_info():
    """Fixture that returns the git commit, branch and latest tag.

    Each value is ``None`` when the git command that produces it fails or
    git cannot be run. If there are uncommitted changes, the commit has
    '-dirty' appended.

    Returns:
        dict with the keys ``commit``, ``latest_tag`` and ``branch``.
    """
    commit = _git_output("rev-parse", "HEAD")
    porcelain = _git_output("status", "--porcelain") if commit is not None else None
    if porcelain is None:
        # Without a working status call clean and dirty cannot be told apart,
        # so no commit is reported at all.
        commit = None
    elif porcelain:
        # Any porcelain output means there are uncommitted changes.
        commit = f"{commit}-dirty"
    return {
        "commit": commit,
        "latest_tag": _git_output("describe", "--tags", "--abbrev=0"),
        "branch": _git_output("rev-parse", "--abbrev-ref", "HEAD"),
    }
def test_print_results(module_results_df, git_info):
    """Annotate the module's harvested results with git metadata and save them.

    Uses pytest-harvest and our custom git fixture; goes at the end of the
    module to evaluate & save the results.

    Args:
        module_results_df: results table provided by pytest-harvest.
        git_info: dict with the keys ``commit``, ``latest_tag`` and ``branch``.
    """
    ...
    # add the git information -- every value the git fixture provides
    module_results_df["git_commit"] = git_info["commit"]
    module_results_df["git_latest_tag"] = git_info["latest_tag"]
    module_results_df["git_branch"] = git_info["branch"]
    # save results
    module_results_df.to_csv("results.csv")
    ...
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def test_an_actions_stability():
    """Let's run it a few times to see output variability.

    Runs the action five times on the same input and fails when any key of
    the output takes more than one distinct value across the runs.
    """
    audio = ...
    outputs = [run_our_action(State({"audio": audio})) for _ in range(5)]
    # Check for consistency - for each key collect the distinct serialized values
    variances = {}
    for key in outputs[0].keys():
        # sort_keys gives a canonical form, so dicts that differ only in key
        # order are not reported as variance
        distinct_values = {json.dumps(output[key], sort_keys=True) for output in outputs}
        if len(distinct_values) > 1:
            # sorted so the failure message is identical from run to run
            variances[key] = sorted(distinct_values)
    variances_str = json.dumps(variances, indent=2)
    assert not variances, "Outputs vary across iterations:\n" + variances_str
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# test_my_agent.py
def test_my_agent():
    """Check the agent's output for two fixed inputs, one after the other."""
    first_result = my_agent("input1")
    assert first_result == "output1"
    second_result = my_agent("input2")
    assert second_result == "output2"
    # can have multiple asserts here - it'll fail
    # on the first one and not run the rest
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pytest | |
@pytest.mark.parametrize(
    ("input", "expected_output"),
    [
        # the id is the test name for that pair of inputs
        pytest.param("input1", "output1", id="test1"),
        pytest.param("input2", "output2", id="test2"),
    ],
)
def test_my_agent(input, expected_output):
    """Check that the agent returns the expected output for one parametrized case."""
    result = my_agent(input)  # your code to call your agent or part of it here
    # can include static measures / evaluations here
    assert result == expected_output
    # assert some other property of the output...
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def test_my_agent(results_bag):
    """Call the agent once and record the input, output and verdict on ``results_bag``."""
    given = "my_value"
    wanted = "my_expected_output"
    output = my_agent(given)
    # store everything needed to evaluate this case later
    results_bag.input = given
    results_bag.output = output
    results_bag.expected_output = wanted
    results_bag.exact_match = wanted == output
    ...
# place this function at the end of your test module
def test_print_results(module_results_df):
    """This function evaluates / does operations over all results captured.

    Fails when no results were collected, or when the share of exact matches
    is not above 0.9.
    """
    # this will include "input", "output", "expected_output"
    print(module_results_df.columns)
    # this will show the first few rows of the results
    print(module_results_df.head())
    # Guard the accuracy computation: with no collected results it would
    # otherwise fail with an error that does not explain the real problem.
    total = len(module_results_df)
    assert total > 0, "No results were collected to evaluate"
    # Add more evaluation logic here or log the results to a file, etc.
    accuracy = sum(module_results_df.exact_match) / total
    # can save results somewhere (`...` is a placeholder for the destination)
    module_results_df.to_csv(...)
    # assert some threshold of success, etc.
    assert accuracy > 0.9, "Failed overall exact match accuracy threshold"
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pytest | |
@pytest.mark.parametrize(
    ("input", "expected_output"),
    [
        # the id provides the test name
        pytest.param("input1", "output1", id="test1"),
        pytest.param("input2", "output2", id="test2"),
    ],
)
def test_my_agent(input, expected_output, results_bag):
    """Run the agent on one parametrized case and record the outcome on ``results_bag``.

    Nothing is asserted here; the verdict is only stored as ``success``.
    """
    # record the case before calling the agent
    results_bag.input = input
    results_bag.expected_output = expected_output
    agent_output = my_agent(input)  # your code to call the agent here
    results_bag.output = agent_output
    # can include static measures / evaluations here
    results_bag.success = agent_output == expected_output
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment