gentle-intro-to-unit-tests.ipynb
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"provenance": [], | |
"authorship_tag": "ABX9TyMx299HZMDXEMr2fOPzAm1q", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/sgbaird/46f29843ac1634653a2f9a530eb78c74/gentle-intro-to-unit-tests.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# A Gentle Introduction to Unit Tests\n", | |
"\n", | |
"> By all means, write buggy code! But try not to write buggy code without quickly realizing that it's buggy 😊\n", | |
"\n", | |
"This is where unit tests come in. Unit tests are small, self-contained tests that verify the correctness of individual components or units of code. Let's start with a silly example, but please bear with me until the more interesting bits.\n", | |
"\n", | |
"## Adding Numbers\n", | |
"\n", | |
"The function will be called `add_numbers` which will add two numbers together. We'll put this in a file `addition.py` using special Google Colab syntax.\n", | |
"\n", | |
"> NOTE: The file will show up in the file browser on the left panel" | |
], | |
"metadata": { | |
"id": "Y8t2EK8EO5PG" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "s8HYQ91A9KDN", | |
"outputId": "6ab462de-21df-4b10-d526-2ce9e04b2e8f" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Writing addition.py\n" | |
] | |
} | |
], | |
"source": [ | |
"%%writefile addition.py\n", | |
"def add_numbers(a, b):\n", | |
" return a + b" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"You want to verify that `add_numbers` works as expected, so you write the following unit test using the [pytest](https://docs.pytest.org/en/7.3.x/) format. First, we import our `add_numbers` function from the `addition.py` module. Then, we verify the known case of `2+3=5`. It's important that you prefix both the *filename* the *function name* with `test_` or postfix it with `_test` so that `pytest` can recognize it automatically." | |
], | |
"metadata": { | |
"id": "FALbsyBUPVo3" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"%%writefile test_addition.py\n", | |
"from addition import add_numbers\n", | |
"\n", | |
"def test_add_numbers():\n", | |
" result = add_numbers(2, 3)\n", | |
" assert result == 5" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "mQwvfa0C9sqj", | |
"outputId": "ec231227-a14b-4c6f-a9cd-4c0a8c7e2d55" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Writing test_addition.py\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"Finally, we run pytest! Note that this uses a \"shebang\" (`!`) to run it in the shell rather than as Python code. Incidentally, pytest is already installed on Google Colab, but otherwise you could install it with `pip install pytest`." | |
], | |
"metadata": { | |
"id": "tc_5TgiOP5z3" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"!pytest" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "1e6acpbB9w6m", | |
"outputId": "d5a3f9b1-dfb9-4833-991c-25ca71b88122" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\u001b[1m============================= test session starts ==============================\u001b[0m\n", | |
"platform linux -- Python 3.10.12, pytest-7.2.2, pluggy-1.0.0\n", | |
"rootdir: /content\n", | |
"plugins: anyio-3.6.2\n", | |
"\u001b[1mcollecting ... \u001b[0m\u001b[1m\rcollected 1 item \u001b[0m\n", | |
"\n", | |
"test_addition.py \u001b[32m.\u001b[0m\u001b[32m [100%]\u001b[0m\n", | |
"\n", | |
"\u001b[32m============================== \u001b[32m\u001b[1m1 passed\u001b[0m\u001b[32m in 0.02s\u001b[0m\u001b[32m ===============================\u001b[0m\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"Your test passed! Typically, you won't be able to address all possible uses of the function, but you can verify enough to feel confident the code is working as intended. More on that later. This example is simple but not particularly interesting. Let's move to a more interesting case where we verify that a machine learning model is performing better than a dummy baseline." | |
], | |
"metadata": { | |
"id": "z2CGNBkwQJIA" | |
} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Sanity Check for Property Prediction Model" | |
], | |
"metadata": { | |
"id": "h4UqYzeNOv6X" | |
} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"Let's say you come up with a new optimization scheme or design a new property prediction model, generative model, etc. A good starting point is to write \"smoke tests\" to verify that code runs without error, but how do you know that the code is doing what it should? Well, you break the large task into little tasks (\"units\") and test those individual pieces. That's going to be highly specific to your problem. Maybe GitHub Copilot can suggest some reasonable tests, but chances are you'll need to think strategically about it.\n", | |
"\n", | |
"On a higher level, a property prediction/generative/optimization model should perform better than some dummy baseline. For property prediction, it should generally perform better than a dummy model that always predicts the mean of the training data. What is the bulk modulus of $\\mathrm{Al}_2\\mathrm{O}_3$? What is the bulk modulus of $\\mathrm{Mo}$? It doesn't matter what the material is, the answer is always just the mean of the training data. For crystal structure generative modeling, it should perform better than choosing random atom locations and assigning random elements to those locations. For optimization models, it should perform better than random search.\n", | |
"\n", | |
"Here's a sanity check that a ML model works better than a dummy baseline. `sklearn` has a [DummyRegressor](https://scikit-learn.org/stable/modules/generated/sklearn.dummy.DummyRegressor.html) class, so we'll use that for the baseline. We'll compare against [RandomForestRegressor](https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestRegressor.html), but you can pretend this is a fancy materials informatics model like one from [Matbench](https://matbench.materialsproject.org/). You can also pretend `X` represents molecules, chemical formulas, crystal structures, etc. We'll use a small, contrived dataset with an easy relationship between the inputs and outputs, but for you, this could be a portion of [Materials Project](https://next-gen.materialsproject.org/) data, your own experiments, etc. For the sake of completeness, we'll keep the package code (`regressor.py`) and test code (`test_regressor.py`) as separate files." | |
], | |
"metadata": { | |
"id": "hOAdVUSJQMGH" | |
} | |
}, | |
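{ | |
"cell_type": "markdown", | |
"source": [ | |
"*A quick, hedged aside (separate from the `regressor.py`/`test_regressor.py` files below):* the sketch that follows uses made-up numbers purely to illustrate what `DummyRegressor`'s default \"mean\" strategy does: it ignores the inputs entirely and predicts the training-set mean every time." | |
], | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Illustrative sketch (hypothetical numbers): DummyRegressor ignores the inputs\n", | |
"# and always predicts the mean of the training targets.\n", | |
"import numpy as np\n", | |
"from sklearn.dummy import DummyRegressor\n", | |
"\n", | |
"X_demo = np.array([[1.0], [2.0], [3.0]])\n", | |
"y_demo = np.array([10.0, 20.0, 60.0])  # mean = 30.0\n", | |
"\n", | |
"dummy = DummyRegressor()  # strategy=\"mean\" is the default\n", | |
"dummy.fit(X_demo, y_demo)\n", | |
"print(dummy.predict([[1.0], [100.0], [-5.0]]))  # predicts 30.0 for every input" | |
], | |
"metadata": {}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |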
{ | |
"cell_type": "code", | |
"source": [ | |
"%%writefile regressor.py\n", | |
"from sklearn.ensemble import RandomForestRegressor as MyModel" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "PRzjJN2reY57", | |
"outputId": "75d033ba-8cd0-45dc-e056-98b120b51a19" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Writing regressor.py\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"%%writefile test_regressor.py\n", | |
"import numpy as np\n", | |
"from sklearn.dummy import DummyRegressor\n", | |
"from sklearn.model_selection import train_test_split\n", | |
"from sklearn.metrics import mean_absolute_error\n", | |
"from regressor import MyModel\n", | |
"\n", | |
"def test_rf_better_than_dummy():\n", | |
" X = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).reshape(-1, 1)\n", | |
" y = np.array([2.0, 3.0, 5.0, 10.0, 15.0, 20.0])\n", | |
" X_train, X_test, y_train, y_test = train_test_split(X, y)\n", | |
" dr = DummyRegressor()\n", | |
" rfr = MyModel()\n", | |
" dr.fit(X_train, y_train)\n", | |
" rfr.fit(X_train, y_train)\n", | |
" dr_mae = mean_absolute_error(y_test, dr.predict(X_test))\n", | |
" rfr_mae = mean_absolute_error(y_test, rfr.predict(X_test))\n", | |
" assert rfr_mae < dr_mae" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "j05v3LziDGdJ", | |
"outputId": "0c7e5977-4583-43fd-8160-18104ba9953d" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Overwriting test_regressor.py\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"!pytest" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "Cyy9NVbiR0nE", | |
"outputId": "ed30ba63-89a6-4f1b-b986-5fb04fa0caa4" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\u001b[1m============================= test session starts ==============================\u001b[0m\n", | |
"platform linux -- Python 3.10.12, pytest-7.2.2, pluggy-1.0.0\n", | |
"rootdir: /content\n", | |
"plugins: anyio-3.6.2\n", | |
"collected 2 items \u001b[0m\n", | |
"\n", | |
"test_addition.py \u001b[32m.\u001b[0m\u001b[32m [ 50%]\u001b[0m\n", | |
"test_regressor.py \u001b[32m.\u001b[0m\u001b[32m [100%]\u001b[0m\n", | |
"\n", | |
"\u001b[32m============================== \u001b[32m\u001b[1m2 passed\u001b[0m\u001b[32m in 0.85s\u001b[0m\u001b[32m ===============================\u001b[0m\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"It worked! Not only that, but it checked both the `test_addition.py` and `test_regressor.py` files. To illustrate what happens when your code doesn't pass a test, let's add a bug to the `add_numbers` function and run pytest again." | |
], | |
"metadata": { | |
"id": "vAeZdjFFUCpQ" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"%%writefile addition.py\n", | |
"def add_numbers(a, b):\n", | |
" return a * b" | |
], | |
"metadata": { | |
"id": "rW09H8L1e__F", | |
"outputId": "63c5953a-f66b-400d-d029-64e18b33c7d8", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
} | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Overwriting addition.py\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"!pytest" | |
], | |
"metadata": { | |
"id": "m07hhApZfTs9", | |
"outputId": "5a009c75-8e29-4659-e20f-6ca8a5193c8b", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
} | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\u001b[1m============================= test session starts ==============================\u001b[0m\n", | |
"platform linux -- Python 3.10.12, pytest-7.2.2, pluggy-1.0.0\n", | |
"rootdir: /content\n", | |
"plugins: anyio-3.6.2\n", | |
"collected 2 items \u001b[0m\n", | |
"\n", | |
"test_addition.py \u001b[31mF\u001b[0m\u001b[31m [ 50%]\u001b[0m\n", | |
"test_regressor.py \u001b[32m.\u001b[0m\u001b[31m [100%]\u001b[0m\n", | |
"\n", | |
"=================================== FAILURES ===================================\n", | |
"\u001b[31m\u001b[1m_______________________________ test_add_numbers _______________________________\u001b[0m\n", | |
"\n", | |
" \u001b[94mdef\u001b[39;49;00m \u001b[92mtest_add_numbers\u001b[39;49;00m():\u001b[90m\u001b[39;49;00m\n", | |
" result = add_numbers(\u001b[94m2\u001b[39;49;00m, \u001b[94m3\u001b[39;49;00m)\u001b[90m\u001b[39;49;00m\n", | |
"> \u001b[94massert\u001b[39;49;00m result == \u001b[94m5\u001b[39;49;00m\u001b[90m\u001b[39;49;00m\n", | |
"\u001b[1m\u001b[31mE assert 6 == 5\u001b[0m\n", | |
"\n", | |
"\u001b[1m\u001b[31mtest_addition.py\u001b[0m:5: AssertionError\n", | |
"\u001b[36m\u001b[1m=========================== short test summary info ============================\u001b[0m\n", | |
"\u001b[31mFAILED\u001b[0m test_addition.py::\u001b[1mtest_add_numbers\u001b[0m - assert 6 == 5\n", | |
"\u001b[31m========================= \u001b[31m\u001b[1m1 failed\u001b[0m, \u001b[32m1 passed\u001b[0m\u001b[31m in 1.37s\u001b[0m\u001b[31m ==========================\u001b[0m\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"Oops! A test failed. Upon inspection, the code is multiplying the two numbers instead of adding them. It told you exactly which function went wrong and how it went wrong. If multiple tests fail, it informs you of each case. However, unit tests are only as robust as you make them; if we had chosen `2+2=4` as the test case, we would have had a \"silent bug\" (one that doesn't throw an error) since `2*2=4`. Meanwhile, it's usually unreasonable to address every case possible ([property-based testing](https://hypothesis.readthedocs.io/en/latest/) can help, but that's a tutorial for another day). So, implement your unit tests judiciously." | |
], | |
"metadata": { | |
"id": "UaWg85xpfXRk" | |
} | |
}, | |
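{ | |
"cell_type": "markdown", | |
"source": [ | |
"*A hedged aside:* one cheap way to cover several cases at once is pytest's built-in `@pytest.mark.parametrize` decorator (see the [pytest docs](https://docs.pytest.org/en/7.3.x/)). The sketch below is illustrative and not executed here; the input/output pairs are arbitrary, but a mix of cases like these would catch the `a * b` bug, whereas `2+2=4` alone would not." | |
], | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"%%writefile test_addition_parametrized.py\n", | |
"# Minimal sketch of a parametrized test (assumes addition.py defines add_numbers).\n", | |
"import pytest\n", | |
"\n", | |
"from addition import add_numbers\n", | |
"\n", | |
"@pytest.mark.parametrize(\n", | |
"    \"a, b, expected\",\n", | |
"    [(2, 3, 5), (2, 2, 4), (-1, 1, 0), (0.5, 0.25, 0.75)],\n", | |
")\n", | |
"def test_add_numbers_parametrized(a, b, expected):\n", | |
"    assert add_numbers(a, b) == expected" | |
], | |
"metadata": {}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |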
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Other advantages\n", | |
"\n", | |
"In addition to helping with **early bug detection**, I'll highlight two more advantages to using unit tests: maintenance and code quality.\n", | |
"\n", | |
"Unit tests act as safety nets when making changes to your code. When you refactor or optimize code, you can rerun the unit tests to ensure that the behavior remains consistent. This makes it easier to **maintain** your code.\n", | |
"\n", | |
"Writing unit tests encourages high-level thinking about what the code is for and what the expected behavior should be (see e.g., [test-driven development](https://en.wikipedia.org/wiki/Test-driven_development)). While writing unit tests, you're also more likely to become aware of design issues with your code such as poor modularity, unecessary duplication, and overly complex code. In other words, unit tests help **improve code quality**.\n", | |
"\n", | |
"Unit tests save you headache in the long run by:\n", | |
"1. helping to catch bugs early\n", | |
"2. making it easier to maintain code\n", | |
"3. helping improve code quality\n", | |
"\n", | |
"## Next Steps\n", | |
"\n", | |
"The point of this tutorial is give a brief introduction to unit tests, provide a teaser that shows how they're relevant to chemistry and materials informatics, and give you an initial sense of reward. It's not to teach you all there is to them. Here are three action items for you to take:\n", | |
"1. Write at least one unit test for a project you're working on\n", | |
"2. Start going through [the pytest documentation](https://docs.pytest.org/en/7.3.x/)\n", | |
"3. Move the code for one of your projects into a [PyScaffold](https://pyscaffold.org/en/stable/) template and add the unit test you wrote from (1)\n", | |
"\n", | |
"PyScaffold is a high-quality Python package template generator, and it uses `pytest` to run all automated tests. Start with [the homepage](https://pyscaffold.org/en/stable/) and work your way towards (or skip to) [\"automation, tests, and coverage\"](https://pyscaffold.org/en/stable/features.html#automation-tests-coverage) and [\"migrating to PyScaffold\"](https://pyscaffold.org/en/stable/migration.html).\n", | |
"\n", | |
"If you want to get in contact:\n", | |
"- [GitHub](https://github.com/sgbaird)\n", | |
"- [LinkedIn](https://www.linkedin.com/in/sterling-baird)\n", | |
"- [Twitter](https://twitter.com/SterlingBaird1)\n", | |
"\n", | |
"For email, see the mail icon at bottom-left of my GitHub profile." | |
], | |
"metadata": { | |
"id": "Ui74bXufU1Vd" | |
} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"<!-- Unit tests are used in virtually every common Python package such as [numpy](https://github.com/numpy/numpy/tree/main/numpy/core/tests), [scikit-learn](https://github.com/scikit-learn/scikit-learn/tree/main/sklearn/metrics/tests), and [scipy](https://github.com/scipy/scipy/tree/main/scipy/interpolate/tests). They're also implemented in many chemistry and materials informatics repositories such as [rdkit](https://github.com/rdkit/rdkit/blob/master/rdkit/ML/UnitTestScreenComposite.py) and [pymatgen](https://github.com/materialsproject/pymatgen/tree/master/pymatgen/core/tests). However, implementing unit tests for your own projects doesn't need to be as daunting as these previous examples might imply -->" | |
], | |
"metadata": { | |
"id": "R_q699oXk18X" | |
} | |
} | |
] | |
} |