Pytest guide for unit testing in Python

pytest is a Python module for writing and running unit tests, which can be installed with the commands pip install --upgrade pip and pip install pytest. This Gist demonstrates various features of pytest, along with related tools and techniques which are useful when writing unit tests, including:

  • How to check that a function returns the values you expect it to return using an assert statement
  • How to check that a function raises the errors you expect it to raise using a with pytest.raises context manager
  • How to automate combinations of different input arguments using the @pytest.mark.parametrize decorator
  • How to define a decorator to automatically repeat a unit test multiple times with different random seeds
  • How to import source code from a parent directory
  • How to check how many times a function is called using unittest.mock.Mock
  • How to print the durations of the slowest unit tests to the command line

Testing errors and assert statements with pytest

Shown at the bottom of this Gist are 2 unit test modules: test_good, which demonstrates tests which are expected to pass, and test_bad, which demonstrates tests which are expected to fail, along with the console output from running each module. The test_good module demonstrates how to check that a function returns the values you expect it to return using an assert statement, and how to check that a function raises the errors you expect it to raise using a with pytest.raises context manager.

If these test modules are saved in the same directory and there is a command prompt open in that directory, then these tests can be run using the commands pytest test_good.py and pytest test_bad.py respectively, or all of the tests can be run using the single command pytest; alternatively, the tests can be run from a development environment such as VS Code.

The naming conventions for modules and functions in this Gist follow the pytest conventions for Python test discovery, specifically filenames which match test_*.py or *_test.py, and function-names (outside of classes) which are prefixed with test.
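
As a minimal sketch of both techniques (the complete test_good and test_bad modules appear at the end of this Gist), a test module for the inc function used throughout this Gist might look like the following; the test names here are purely illustrative:

import pytest

def inc(x):
    """ Increment x, raising a ValueError for non-numeric inputs """
    if type(x) not in [int, float]:
        raise ValueError("input must be int or float")
    return x + 1

def test_inc_returns_expected_value():
    # Check the return value with a plain assert statement
    assert inc(3) == 4

def test_inc_raises_expected_error():
    # Check that the expected error is raised, using a context manager
    with pytest.raises(ValueError):
        inc("three")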

Automate combinations of different input arguments using the @pytest.mark.parametrize decorator

The first argument to pytest.mark.parametrize should be a string containing the comma-separated parameter names (whitespace after the commas is optional). When the test function has only a single parameter, the second argument should be an iterable of values to use for that parameter. When the test function has multiple parameters, the second argument should be an iterable of iterables, where each inner iterable contains the values for one call of the test function, in the same order as the names in the parameter string. @pytest.mark.parametrize decorators can also be stacked, in which case the test is run once for every combination of parameter values. All of these features are demonstrated below (see the end of this Gist for the console output from this test):

import pytest

@pytest.mark.parametrize("single_arg", [2, 4, 6, 7])
def test_single_arg_even(single_arg):
    assert single_arg % 2 == 0

@pytest.mark.parametrize(
    "arg1, arg2, arg3",
    [(1, 2, 3), (4, 5, 9), (10, 11, 12)]
)
def test_multiple_args_sum(arg1, arg2, arg3):
    assert arg1 + arg2 == arg3

@pytest.mark.parametrize("x", [1, 2, 3, 4])
@pytest.mark.parametrize("y", [0, 2, 4, 10])
def test_stacked_parameters(x, y):
    assert x * x + y * y < 10 * 10

The following code can be used to print multiple @pytest.mark.parametrize expressions, each of which parametrizes a test with three randomly-chosen seeds:

import numpy as np

for _ in range(10):
    print("@pytest.mark.parametrize(\"seed\", [{}, {}, {}])".format(
        *np.random.randint(0, 10000, size=[3])
    ))

Example usage of one such decorator:

@pytest.mark.parametrize("seed", [9989, 6595, 7792])
def test_random(seed):
    np.random.seed(seed)
    assert np.random.randint(2, 10) > 1

Automatically repeat a unit test multiple times with different random seeds

The following Python function returns a decorator which can be used to automatically repeat a unit test multiple times with different random seeds (the Primer on Python Decorators from realpython.com is a useful source of information on decorators which can be configured with arguments):

import numpy as np

def iterate_random_seeds(*seeds):
    """
    This function can be used to return a decorator, which will automatically
    repeat a test function multiple times with different random seeds (the seeds
    are provided as arguments to this function). It is assumed that the function
    being decorated accepts no arguments, and returns no values (minor
    modifications would be needed if these assumptions were untrue). The
    decorator can be used as follows:

    ```
    @iterate_random_seeds(5920, 2788, 235)
    def function_name():
        do_function_body()
    ```
    """
    # decorator_func is the decorator which is returned, given the seeds
    def decorator_func(func):
        # func_wrapper is called when the decorated function is called
        def func_wrapper():
            # Call decorated function once with each random seed
            for s in seeds:
                np.random.seed(s)
                func()

        # Calling the decorator returns the decorated function wrapper
        return func_wrapper

    # When this function is called, the decorator is returned
    return decorator_func

Example usage:

@iterate_random_seeds(3, 4, 5)
def print_random():
    print(np.random.normal(size=[2, 2]))
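
The decorator can be applied to a pytest test function in exactly the same way; a hypothetical example (the test name and body here are purely illustrative, and assume numpy has been imported as np):

@iterate_random_seeds(5920, 2788, 235)
def test_random_matrix_is_finite():
    # The body below is run once for each seed passed to the decorator above
    x = np.random.normal(size=[3, 3])
    assert np.all(np.isfinite(x))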

The following function can be used to print multiple expressions for this decorator, with different input random seeds:

def generate_decorator_expression(num_expressions=10):
    """
    This function can be used to print multiple decorator expressions for the
    decorator above, with different input random seeds
    """
    for _ in range(num_expressions):
        print("@iterate_random_seeds({}, {}, {})".format(
            *np.random.randint(0, 10000, size=[3])
        ))

Importing from a parent directory for pytest

When writing pytest unit tests for a small project, it might be desirable to keep the unit tests in a subdirectory of the repository's top-level directory. The file-structure might look something like this:

src/
    main.py
    util.py
    test_dir/
        test_a.py
        test_b.py

This creates a problem, because the test modules test_a and test_b may not easily be able to import the necessary source modules (e.g. main, util) from their parent directory.

A solution is to add the parent directory to the sys.path variable, which will allow modules from the parent directory to be easily imported, without needing to use any relative imports. This can be achieved using the following 4 lines of Python code (as noted in this Stack Overflow answer):

import os, sys
current_dir = os.path.dirname(os.path.abspath(__file__))
source_dir = os.path.abspath(os.path.join(current_dir, ".."))
sys.path.append(source_dir)

These lines of code can be placed at the top of each test file (e.g. test_a, test_b) before importing the source modules. An alternative solution is to create a file called __init__.py in the test_dir folder, and place the above 4 lines in __init__.py; this __init__.py module will be implicitly imported before any other modules in the test_dir folder, which means that by the time pytest imports the test modules, the parent directory has already been added to sys.path, and the test modules can import the source modules from the parent directory simply using the statement import main, util.
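
For example, a hypothetical test_dir/test_a.py might look like the following (the assumption that main defines a callable named run is purely for illustration):

import os, sys

# Add the parent directory (src/) to sys.path before importing source modules
current_dir = os.path.dirname(os.path.abspath(__file__))
source_dir = os.path.abspath(os.path.join(current_dir, ".."))
sys.path.append(source_dir)

import main
import util

def test_source_modules_are_importable():
    # Hypothetical sanity check; assumes main defines a callable named run
    assert callable(main.run)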

Check if a function is called using unittest.mock.Mock

The unittest.mock.Mock class from the standard-library unittest.mock module can be used to check whether (and how many times) a function or method is called during a unit test.

The side_effect argument to the Mock() constructor specifies "a function to be called whenever the Mock is called".

The Mock.called attribute is "a boolean representing whether or not the mock object has been called".

The following example demonstrates using Mock to test if a function is called:

from unittest.mock import Mock

def f(x): return x + 1

print(f(3))
# >> 4
f_mock = Mock(side_effect=f)
print(f_mock.called)
# >> False
print(f_mock(3))
# >> 4
print(f_mock.called)
# >> True

The following example demonstrates using Mock to test if methods from the C class have been called:

from unittest.mock import Mock

class C:
    def __init__(self, data): self.data = data
    def __repr__(self): return f"C({self.data})"
    def inc(self, amount=1):
        self.data += amount
        if self.data > 10: self.reduce_data()
    def reduce_data(self): self.data = self.data / 2


c = C(3)
print(c)
# >> C(3)
c.inc = Mock(side_effect=c.inc)
c.reduce_data = Mock(side_effect=c.reduce_data)
print(c, c.inc.called, c.reduce_data.called)
# >> C(3) False False
c.inc(4)
print(c, c.inc.called, c.reduce_data.called)
# >> C(7) True False
c.inc(10)
print(c, c.inc.called, c.reduce_data.called)
# >> C(8.5) True True
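
A Mock object also records how many times it has been called in its call_count attribute. As a small sketch of how this might be used in a pytest test (reusing the C class defined above; the test name is illustrative):

from unittest.mock import Mock

def test_inc_calls_reduce_data_once():
    c = C(9)
    c.reduce_data = Mock(side_effect=c.reduce_data)
    c.inc(5)
    # data goes from 9 to 14, which exceeds 10, so reduce_data is called once
    assert c.reduce_data.call_count == 1
    assert c.data == 7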

A note about running pytest in VS Code

When running pytest in VS Code, it is possible that a warning will be raised, starting with PytestDeprecationWarning: The 'junit_family' default value will change to 'xunit2' in pytest 6.0.. A solution is to add "-o", "junit_family=xunit1" as terms in the "python.testing.pytestArgs" list in .vscode/settings.json (as described in this GitHub issue):

    "python.testing.pytestArgs": [
        "-o", "junit_family=xunit1"
    ],

Printing durations of slowest tests

Sometimes it is useful to know which tests are taking the longest amount of time, and how much time they are taking. To print the names and durations of the 5 slowest pytest tests, include the following command-line argument in the pytest command:

--durations=5

In VS Code, this can be automated by adding --durations=5 as a term in the "python.testing.pytestArgs" list in .vscode/settings.json as follows:

    "python.testing.pytestArgs": [
        "--durations=5"
    ],

Running specific tests

Specific tests can be run from the command line by using the -k flag, followed by either the name of a specific test, or a more complex pattern enclosed in quotes, as described in this Stack Overflow answer:

pytest ./path/to/test_file.py -k test_specific_test_case
pytest ./path/to/test_file.py -k 'test_specific_test_case or test_different_test_case'

If no path is specified, pytest will search recursively through the current directory and all of its subdirectories for any valid test cases which match the given pattern; this is useful if it is known that there is only one test case with the given name, as it requires a shorter command, for example:

pytest -k test_with_unique_name
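
The -k flag also accepts boolean expressions using and, or and not, which can be useful for deselecting groups of parametrized tests; for example, using the parametrize_test.py module shown at the end of this Gist:

pytest parametrize_test.py -k 'not stacked'
pytest parametrize_test.py -k 'single or multiple'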

Viewing stdout

By default, pytest will capture anything that is printed to stdout, and only display the captured output if the test fails. All print statements can be displayed (regardless of whether the test passes or fails) by including the -s or --capture=no command-line argument in the call to pytest (or in the .vscode/settings.json "python.testing.pytestArgs" list as described above), as described in the official documentation and this Stack Overflow answer.
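
As a small sketch, the print statement in a test such as the following (in a hypothetical file called test_printing.py) will only be displayed when output capturing is disabled:

def test_with_print_output():
    # This output is captured by pytest unless -s / --capture=no is used
    print("This message only appears when output capturing is disabled")
    assert True

pytest -s test_printing.py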

Summary of my .vscode/settings.json

The -s flag can be uncommented in order to view stdout output:

{
    "python.testing.pytestArgs": [
        "bridgetests",
        "-o", "junit_family=xunit1",
        "--durations=5",
        // "-s"
    ],
    "python.testing.unittestEnabled": false,
    "python.testing.nosetestsEnabled": false,
    "python.testing.pytestEnabled": true
}
"""
This module demonstrates tests which are expected to fail. Run on the command
line and redirect the output to a text file with the following command:
pytest test_bad.py > z_test_bad_output.txt
"""
import pytest
def inc(x):
""" This is the function we will test """
if type(x) not in [int, float]:
raise ValueError("input must be int or float")
return x + 1
def test_wrong_output():
""" This test fails because the assert statement is False """
assert inc(3) == 5
def test_wrong_input():
""" This test fails because an uncaught error is raised """
assert inc("three") == 4
def test_wrong_error():
""" This test fails because the wrong error is caught (effecetively this is
the same as the previous example: the test fails because an uncaught error
is raised) """
with pytest.raises(RuntimeError):
inc("three")
def test_error_wrong_input():
""" This test fails because it tries to catch an error which hasn't been
raised """
with pytest.raises(ValueError):
inc(3)
"""
This module demonstrates tests which are expected to pass. Run on the command
line and redirect the output to a text file with the following command:
pytest test_good.py > z_test_good_output.txt
"""
import pytest
def inc(x):
""" This is the function we will test """
if type(x) not in [int, float]:
raise ValueError("input must be int or float")
return x + 1
def test_inc():
""" Test the function works the way we expect using an assert statement """
assert inc(3) == 4
def test_error():
""" Test the function raises the correct error using a "with pytest.raises"
context manager """
with pytest.raises(ValueError):
inc("three")
parametrize_test.py:

import pytest

@pytest.mark.parametrize("single_arg", [2, 4, 6, 7])
def test_single_arg_even(single_arg):
    assert single_arg % 2 == 0

@pytest.mark.parametrize(
    "arg1, arg2, arg3",
    [(1, 2, 3), (4, 5, 9), (10, 11, 12)]
)
def test_multiple_args_sum(arg1, arg2, arg3):
    assert arg1 + arg2 == arg3

@pytest.mark.parametrize("x", [1, 2, 3, 4])
@pytest.mark.parametrize("y", [0, 2, 4, 10])
def test_stacked_parameters(x, y):
    assert x * x + y * y < 10 * 10
$ pytest
========================================================================== test session starts ===========================================================================
platform win32 -- Python 3.7.6, pytest-5.4.1, py-1.8.1, pluggy-0.13.1
rootdir: C:\Users\Jake\Documents\Programming\Python
collected 23 items
parametrize_test.py ...F..F............FFFF [100%]
================================================================================ FAILURES ================================================================================
________________________________________________________________________ test_single_arg_even[7] _________________________________________________________________________
single_arg = 7
@pytest.mark.parametrize("single_arg", [2, 4, 6, 7])
def test_single_arg_even(single_arg):
> assert single_arg % 2 == 0
E assert (7 % 2) == 0
parametrize_test.py:5: AssertionError
____________________________________________________________________ test_multiple_args_sum[10-11-12] ____________________________________________________________________
arg1 = 10, arg2 = 11, arg3 = 12
@pytest.mark.parametrize(
"arg1, arg2, arg3",
[(1, 2, 3), (4, 5, 9), (10, 11, 12)]
)
def test_multiple_args_sum(arg1, arg2, arg3):
> assert arg1 + arg2 == arg3
E assert (10 + 11) == 12
parametrize_test.py:12: AssertionError
_____________________________________________________________________ test_stacked_parameters[10-1] ______________________________________________________________________
x = 1, y = 10
@pytest.mark.parametrize("x", [1, 2, 3, 4])
@pytest.mark.parametrize("y", [0, 2, 4, 10])
def test_stacked_parameters(x, y):
> assert x * x + y * y < 10 * 10
E assert ((1 * 1) + (10 * 10)) < (10 * 10)
parametrize_test.py:17: AssertionError
_____________________________________________________________________ test_stacked_parameters[10-2] ______________________________________________________________________
x = 2, y = 10
@pytest.mark.parametrize("x", [1, 2, 3, 4])
@pytest.mark.parametrize("y", [0, 2, 4, 10])
def test_stacked_parameters(x, y):
> assert x * x + y * y < 10 * 10
E assert ((2 * 2) + (10 * 10)) < (10 * 10)
parametrize_test.py:17: AssertionError
_____________________________________________________________________ test_stacked_parameters[10-3] ______________________________________________________________________
x = 3, y = 10
@pytest.mark.parametrize("x", [1, 2, 3, 4])
@pytest.mark.parametrize("y", [0, 2, 4, 10])
def test_stacked_parameters(x, y):
> assert x * x + y * y < 10 * 10
E assert ((3 * 3) + (10 * 10)) < (10 * 10)
parametrize_test.py:17: AssertionError
_____________________________________________________________________ test_stacked_parameters[10-4] ______________________________________________________________________
x = 4, y = 10
@pytest.mark.parametrize("x", [1, 2, 3, 4])
@pytest.mark.parametrize("y", [0, 2, 4, 10])
def test_stacked_parameters(x, y):
> assert x * x + y * y < 10 * 10
E assert ((4 * 4) + (10 * 10)) < (10 * 10)
parametrize_test.py:17: AssertionError
======================================================================== short test summary info =========================================================================
FAILED parametrize_test.py::test_single_arg_even[7] - assert (7 % 2) == 0
FAILED parametrize_test.py::test_multiple_args_sum[10-11-12] - assert (10 + 11) == 12
FAILED parametrize_test.py::test_stacked_parameters[10-1] - assert ((1 * 1) + (10 * 10)) < (10 * 10)
FAILED parametrize_test.py::test_stacked_parameters[10-2] - assert ((2 * 2) + (10 * 10)) < (10 * 10)
FAILED parametrize_test.py::test_stacked_parameters[10-3] - assert ((3 * 3) + (10 * 10)) < (10 * 10)
FAILED parametrize_test.py::test_stacked_parameters[10-4] - assert ((4 * 4) + (10 * 10)) < (10 * 10)
====================================================================== 6 failed, 17 passed in 0.24s ======================================================================
============================= test session starts =============================
platform win32 -- Python 3.7.6, pytest-5.4.1, py-1.8.1, pluggy-0.13.1
rootdir: C:\Users\Jake\Documents\Programming\Gists\Pytest guide for unit testing in Python\2d249adbbd2e13950852b80cca42ed02
collected 4 items
test_bad.py FFFF [100%]
================================== FAILURES ===================================
______________________________ test_wrong_output ______________________________
def test_wrong_output():
""" This test fails because the assert statement is False """
> assert inc(3) == 5
E assert 4 == 5
E + where 4 = inc(3)
test_bad.py:18: AssertionError
______________________________ test_wrong_input _______________________________
def test_wrong_input():
""" This test fails because an uncaught error is raised """
> assert inc("three") == 4
test_bad.py:22:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
x = 'three'
def inc(x):
""" This is the function we will test """
if type(x) not in [int, float]:
> raise ValueError("input must be int or float")
E ValueError: input must be int or float
test_bad.py:13: ValueError
______________________________ test_wrong_error _______________________________
def test_wrong_error():
""" This test fails because the wrong error is caught (effecetively this is
the same as the previous example: the test fails because an uncaught error
is raised) """
with pytest.raises(RuntimeError):
> inc("three")
test_bad.py:29:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
x = 'three'
def inc(x):
""" This is the function we will test """
if type(x) not in [int, float]:
> raise ValueError("input must be int or float")
E ValueError: input must be int or float
test_bad.py:13: ValueError
___________________________ test_error_wrong_input ____________________________
def test_error_wrong_input():
""" This test fails because it tries to catch an error which hasn't been
raised """
with pytest.raises(ValueError):
> inc(3)
E Failed: DID NOT RAISE <class 'ValueError'>
test_bad.py:35: Failed
=========================== short test summary info ===========================
FAILED test_bad.py::test_wrong_output - assert 4 == 5
FAILED test_bad.py::test_wrong_input - ValueError: input must be int or float
FAILED test_bad.py::test_wrong_error - ValueError: input must be int or float
FAILED test_bad.py::test_error_wrong_input - Failed: DID NOT RAISE <class 'Va...
============================== 4 failed in 0.04s ==============================
============================= test session starts =============================
platform win32 -- Python 3.7.6, pytest-5.4.1, py-1.8.1, pluggy-0.13.1
rootdir: C:\Users\Jake\Documents\Programming\Gists\Pytest guide for unit testing in Python\2d249adbbd2e13950852b80cca42ed02
collected 2 items
test_good.py .. [100%]
============================== 2 passed in 0.02s ==============================