Last active
May 10, 2023 17:45
-
-
Save samuelcolvin/625a1655c2a73e02469fc3c27285ca42 to your computer and use it in GitHub Desktop.
auto-generate assert statements in pytest
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
License: MIT | |
Copyright (c) 2022 Samuel Colvin. | |
See https://twitter.com/adriangb01/status/1573708407479189505 | |
## Usage | |
Once installed just add | |
```py | |
insert_assert(the_value) | |
# or | |
insert_assert(function_call()) | |
``` | |
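to a test and run pytest: this code will collect the source text of the argument and its value, format them with black,
and substitute `assert <argument-code> == <value>` into the test file when pytest finishes.
Set the `TRY_RUN` environment variable to print the generated code instead of modifying any files.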
## Installation | |
To use this (until it becomes a proper package and pytest plugin),
add this file to `tests`, exclude it from git, and add the following to your `conftest.py`:
```py | |
try: | |
from .insert_assert import * | |
except ImportError: | |
pass | |
``` | |
## Example usage | |
```py | |
def test_string(): | |
thing = 'foobar' | |
insert_assert(thing) | |
def test_list_callable(): | |
def foobar(): | |
return ['foo', 1, b'bytes'] | |
insert_assert(foobar()) | |
def test_comprehension(): | |
insert_assert([f'x{i}' for i in range(10)]) | |
``` | |
""" | |
import ast | |
import os | |
import sys | |
import textwrap | |
from dataclasses import dataclass | |
from enum import Enum | |
from itertools import groupby | |
from pathlib import Path | |
from types import FrameType | |
from typing import Any | |
import pytest | |
from black import InvalidInput, Mode, TargetVersion, format_file_contents | |
# requires pip install executing black | |
from executing import Source | |
# Public names picked up by `from .insert_assert import *` in conftest.py:
# the autouse session fixture, the pytest terminal-summary hook and the helper itself.
__all__ = (
    'add_insert_assert_to_builtins',
    'pytest_terminal_summary',
    'insert_assert',
)
@dataclass
class ToReplace:
    """A pending source rewrite: lines `start_line`..`end_line` (1-based, inclusive) of `file` become `code`."""

    file: Path  # source file containing the insert_assert() call
    start_line: int  # first line of the call (ast `lineno`, 1-based)
    end_line: int  # last line of the call (ast `end_lineno`, inclusive)
    code: str  # replacement code, already indented to the call's column


# Rewrites collected while the tests run; applied to the files by `pytest_terminal_summary()`.
# The annotation is quoted because module-level annotations are evaluated at import time and the builtin
# `list[...]` is not subscriptable before Python 3.9, while this module otherwise works on Python 3.8.
to_replace: 'list[ToReplace]' = []
@pytest.fixture(scope='session', autouse=True)
def add_insert_assert_to_builtins():
    """Make `insert_assert` callable from every test without importing it.

    Autouse, session-scoped fixture: it installs `insert_assert` as a builtin name once per pytest session.
    """
    import builtins

    # `__builtins__` is a CPython implementation detail (a dict in imported modules, but the `builtins`
    # module itself in `__main__`); setting an attribute on the `builtins` module is the portable way.
    builtins.insert_assert = insert_assert
def pytest_terminal_summary():
    """Rewrite each recorded `insert_assert()` call in its source file with the generated assert.

    Runs at the end of the pytest session as the `pytest_terminal_summary` hook and consumes `to_replace`.
    If the `TRY_RUN` environment variable is non-empty, the generated code is printed instead and no file
    is modified.
    """
    if not to_replace:
        return
    # TODO replace with a pytest argument
    try_run = bool(os.getenv('TRY_RUN'))

    # Group explicitly by file rather than with `itertools.groupby`, which only merges *consecutive* items:
    # non-contiguous records for one file would make it be rewritten twice, the second pass using line
    # numbers that no longer match the already-modified content.
    # The same call can also be recorded several times (parametrized tests, loops); only the first record
    # per line is kept, since re-applying a replacement over already-rewritten lines corrupts the file.
    by_file = {}
    for tr in to_replace:
        by_file.setdefault(tr.file, {}).setdefault(tr.start_line, tr)

    file_count = 0
    call_count = 0
    for file, by_line in by_file.items():
        # Python source files are UTF-8 by default (PEP 3120), whatever the locale encoding is.
        original = file.read_text(encoding='utf-8')
        lines = original.splitlines()
        # we have to substitute lines in reverse order to avoid messing up line numbers
        for tr in sorted(by_line.values(), key=lambda x: x.start_line, reverse=True):
            call_count += 1
            if try_run:
                hr = '-' * 80
                print(f'{file} - {tr.start_line}:{tr.end_line}:\n{hr}\n{tr.code}{hr}\n')
            else:
                lines[tr.start_line - 1 : tr.end_line] = tr.code.splitlines()
        if not try_run:
            new_content = '\n'.join(lines)
            # `splitlines()` discards the final newline; restore it so the file keeps its trailing newline
            if original.endswith('\n'):
                new_content += '\n'
            file.write_text(new_content, encoding='utf-8')
            file_count += 1
    print(f'replaced {call_count} insert_assert() calls in {file_count} files')
def insert_assert(value):
    """Record that this call should be replaced by `assert <argument-code> == <value>`.

    Nothing is modified while the tests run: the black-formatted replacement is appended to `to_replace`
    and written back to the calling file by `pytest_terminal_summary()` at the end of the session.

    Args:
        value: the value to assert on; the source text of the argument expression is recovered from the
            caller's frame with `executing`.

    Raises:
        RuntimeError: if the calling `insert_assert(...)` expression cannot be located in the source.
    """
    call_frame: FrameType = sys._getframe(1)
    source = Source.for_frame(call_frame)
    ex = source.executing(call_frame)
    call = ex.node
    # `executing` leaves `node` as None when it cannot identify the expression being evaluated; without the
    # call's AST there is nothing to rewrite, so fail with a clear message instead of an AttributeError.
    if not isinstance(call, ast.Call):
        raise RuntimeError('insert_assert() was unable to find the source code of its call, is the source available?')
    # support `insert_assert(value=...)` as well as the usual positional form
    ast_arg = call.args[0] if call.args else call.keywords[0].value
    if isinstance(ast_arg, ast.Name):
        arg = ast_arg.id
    else:
        # collapse a multi-line argument expression onto one line; black re-wraps it below
        arg = ' '.join(map(str.strip, ex.source.asttokens().get_text(ast_arg).splitlines()))
    # no trailing newline here: black always adds one, so `format_file_contents` never raises NothingChanged
    python_code = f'# insert_assert({arg})\nassert {arg} == {custom_repr(value)}'
    mode = Mode(
        line_length=120,
        string_normalization=False,
        magic_trailing_comma=False,
        target_versions={TargetVersion.PY37, TargetVersion.PY38, TargetVersion.PY39, TargetVersion.PY310},
    )
    try:
        python_code = format_file_contents(python_code, fast=False, mode=mode)
    except InvalidInput as exc:
        # best-effort: keep the unformatted code so the user can fix it and run black themselves;
        # add the trailing newline black would have added so both paths produce the same shape
        print(f'black error: {exc}')
        python_code += '\n'
    python_code = textwrap.indent(python_code, call.col_offset * ' ')
    to_replace.append(ToReplace(Path(call_frame.f_code.co_filename), call.lineno, call.end_lineno, python_code))
def custom_repr(value):
    """Return an object whose `repr()` is Python source that reproduces *value*.

    Containers are rebuilt with their items converted recursively. Enum members render as
    `ClassName.MEMBER` instead of the non-code `<ClassName.MEMBER: 1>`. Any other value is wrapped in a
    `PlainRepr` holding its normal `repr()`.
    """
    if isinstance(value, (list, tuple, set, frozenset)):
        items = [custom_repr(v) for v in value]
        if isinstance(value, tuple) and hasattr(value, '_fields'):
            # namedtuples take their fields as separate positional arguments, not as one iterable
            return value.__class__(*items)
        return _rebuild_container(value, items)
    elif isinstance(value, dict):
        # pass a mapping (not an iterable of pairs) so subclasses like Counter keep the values as counts
        items = {custom_repr(k): custom_repr(v) for k, v in value.items()}
        return _rebuild_container(value, items)
    elif isinstance(value, Enum):
        return PlainRepr(f'{value.__class__.__name__}.{value.name}')
    else:
        return PlainRepr(repr(value))


def _rebuild_container(value, items):
    """Build a container of *value*'s type from *items*, falling back to its builtin base type."""
    try:
        return value.__class__(items)
    except TypeError:
        # subclasses whose constructor does not accept the items alone (e.g. `defaultdict`, whose first
        # argument is the default factory) fall back to the builtin base, which still compares equal
        for base in (list, tuple, set, frozenset, dict):
            if isinstance(value, base):
                return base(items)
        raise


class PlainRepr:
    """Wrapper whose `repr()` is exactly the given string, used to embed pre-rendered code in containers."""

    __slots__ = ('s',)

    def __init__(self, s: str):
        self.s = s

    def __repr__(self):
        return self.s
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment