Skip to content

Instantly share code, notes, and snippets.

@thrau
Last active June 14, 2022 04:02
Show Gist options
  • Save thrau/ace5f9b538cdba4808f6bf652647e8a6 to your computer and use it in GitHub Desktop.
Save thrau/ace5f9b538cdba4808f6bf652647e8a6 to your computer and use it in GitHub Desktop.
Script to convert unittest asserts to plain asserts
"""
Script to convert unittest asserts into plain asserts.
Either reads the file from the path passed as first parameter, or reads from stdin if no parameter is given.
"""
import functools
import sys
import libcst as cst
# Expression types that may not appear unparenthesized as an ``assert`` test or message
# (``assert x := 1`` and ``assert x for x in y`` are syntax errors).
_NEEDS_PARENS_IN_ASSERT = (cst.NamedExpr, cst.GeneratorExp)
# ...plus the types that bind more loosely than ``not`` and would be split by it
# (``not a or b`` means ``(not a) or b``).
_NEEDS_PARENS_UNDER_NOT = _NEEDS_PARENS_IN_ASSERT + (cst.Lambda, cst.IfExp, cst.BooleanOperation)
# ...plus the types that bind no tighter than a comparison and would chain with it or be
# split by it (``a == b == c`` chains; ``not a == b`` means ``not (a == b)``).
_NEEDS_PARENS_IN_COMPARISON = _NEEDS_PARENS_UNDER_NOT + (cst.Comparison,)


def _wrap(node: cst.BaseExpression, loose_types: tuple, *, wrap_not: bool = False) -> cst.BaseExpression:
    """Return *node* wrapped in parentheses if it is a bare instance of *loose_types*.

    :param node: the expression to place into a new context.
    :param loose_types: expression classes that need parentheses in that context.
    :param wrap_not: also parenthesize a bare ``not x`` unary operation.
    :return: *node* unchanged if it already has parentheses or does not need them,
        otherwise a copy of *node* with one pair of parentheses added.
    """
    if node.lpar:
        return node
    is_not = isinstance(node, cst.UnaryOperation) and isinstance(node.operator, cst.Not)
    if isinstance(node, loose_types) or (wrap_not and is_not):
        return node.with_changes(lpar=[cst.LeftParen()], rpar=[cst.RightParen()])
    return node


def to_assert(left, op=None, right=None, msg=None) -> cst.Assert:
    """Build a libcst ``assert`` statement node.

    The result takes one of three forms:

    - ``to_assert(x)`` gives ``assert x``.
    - ``to_assert(x, op)`` gives ``assert <op> x``, where ``op`` is a unary operator
      such as ``cst.Not()``.
    - ``to_assert(x, op, y)`` gives ``assert x <op> y``, where ``op`` is a comparison
      operator such as ``cst.Equal()``.

    Operands are parenthesized where needed so the generated code keeps the meaning
    the arguments had inside the original call. For example, ``assertFalse(a or b)``
    yields ``assert not (a or b)`` rather than ``assert not a or b``.

    :param left: the (first) operand expression.
    :param op: optional unary operator (when ``right`` is None) or comparison operator.
    :param right: optional second operand; when given, ``op`` must be a comparison operator.
    :param msg: optional expression used as the assertion message (``assert test, msg``).
    :return: the new ``cst.Assert`` node.
    """
    if msg is not None:
        msg = _wrap(msg, _NEEDS_PARENS_IN_ASSERT)

    if right is None:
        if op is None:
            return cst.Assert(test=_wrap(left, _NEEDS_PARENS_IN_ASSERT), msg=msg)
        if isinstance(op, cst.Not):
            operand = _wrap(left, _NEEDS_PARENS_UNDER_NOT)
        else:
            # Other unary operators (``-``, ``+``, ``~``) bind tighter than any binary
            # or comparison operator, so those operands need parentheses as well.
            operand = _wrap(left, _NEEDS_PARENS_IN_COMPARISON + (cst.BinaryOperation,), wrap_not=True)
        return cst.Assert(test=cst.UnaryOperation(operator=op, expression=operand), msg=msg)

    return cst.Assert(
        test=cst.Comparison(
            left=_wrap(left, _NEEDS_PARENS_IN_COMPARISON, wrap_not=True),
            comparisons=[
                cst.ComparisonTarget(
                    operator=op,
                    comparator=_wrap(right, _NEEDS_PARENS_IN_COMPARISON, wrap_not=True),
                ),
            ],
        ),
        msg=msg,
    )
class AssertTransformer(cst.CSTTransformer):
    """Rewrite statement-level ``<obj>.assertXxx(...)`` calls into plain ``assert`` statements.

    Only a call that forms a whole expression statement is rewritten, for example
    ``self.assertEqual(a, b)`` on its own line. A call used inside a larger expression
    (a lambda body, an assignment, another call's argument) is left untouched, because
    an ``assert`` statement cannot appear there.

    Calls whose arguments cannot be mapped unambiguously are also kept as they are.
    This covers ``*args``/``**kwargs``, keyword arguments other than ``msg``, and the
    wrong number of arguments.
    """

    # Methods taking one operand plus an optional message; each maps to a builder(arg, msg).
    unary_comps = {
        "assertIsNone": lambda arg, msg: to_assert(arg, cst.Is(), cst.Name("None"), msg=msg),
        "assertIsNotNone": lambda arg, msg: to_assert(arg, cst.IsNot(), cst.Name("None"), msg=msg),
        "assertTrue": lambda arg, msg: to_assert(arg, msg=msg),
        "assertFalse": lambda arg, msg: to_assert(arg, cst.Not(), msg=msg),
    }
    # Methods taking two operands plus an optional message; each maps to a builder(left, right, msg).
    binary_comps = {
        "assertEqual": functools.partial(to_assert, op=cst.Equal()),
        "assertNotEqual": functools.partial(to_assert, op=cst.NotEqual()),
        "assertIn": functools.partial(to_assert, op=cst.In()),
        "assertNotIn": functools.partial(to_assert, op=cst.NotIn()),
        "assertIs": functools.partial(to_assert, op=cst.Is()),
        "assertIsNot": functools.partial(to_assert, op=cst.IsNot()),
        "assertLess": functools.partial(to_assert, op=cst.LessThan()),
        "assertLessEqual": functools.partial(to_assert, op=cst.LessThanEqual()),
        "assertGreater": functools.partial(to_assert, op=cst.GreaterThan()),
        "assertGreaterEqual": functools.partial(to_assert, op=cst.GreaterThanEqual()),
    }

    @staticmethod
    def _split_args(call: cst.Call, arity: int):
        """Map *call*'s arguments to ``(operands, msg)``, or return None if that is unsafe.

        Accepts exactly *arity* positional operands. These may be followed by the
        message, given either as one extra positional argument or as a ``msg=`` keyword.

        :param call: the ``assertXxx(...)`` call node.
        :param arity: number of operands the assert method takes (1 or 2).
        :return: ``(list_of_operand_expressions, msg_expression_or_None)`` or None.
        """
        positional = []
        keyword_msg = None
        for arg in call.args:
            if arg.star:
                # ``*args`` / ``**kwargs``: the real argument count is unknown statically.
                return None
            if arg.keyword is None:
                positional.append(arg.value)
            elif arg.keyword.value == "msg":
                keyword_msg = arg.value
            else:
                # e.g. ``first=``/``second=``: mapping by name is not attempted.
                return None

        if keyword_msg is not None:
            return (positional, keyword_msg) if len(positional) == arity else None
        if len(positional) == arity:
            return positional, None
        if len(positional) == arity + 1:
            return positional[:arity], positional[arity]
        return None

    def leave_Expr(self, original_node: cst.Expr, updated_node: cst.Expr) -> cst.BaseSmallStatement:
        """Replace an expression statement that is a known assert call with an ``assert``."""
        call = updated_node.value
        if not isinstance(call, cst.Call) or not isinstance(call.func, cst.Attribute):
            return updated_node
        comp = call.func.attr.value

        if comp in self.unary_comps:
            parsed = self._split_args(call, 1)
            if parsed is None:
                return updated_node
            (arg,), msg = parsed
            new_node = self.unary_comps[comp](arg=arg, msg=msg)
        elif comp in self.binary_comps:
            parsed = self._split_args(call, 2)
            if parsed is None:
                return updated_node
            (left, right), msg = parsed
            new_node = self.binary_comps[comp](left=left, right=right, msg=msg)
        else:
            return updated_node

        # Keep a trailing ``;`` (e.g. ``self.assertTrue(x); other()``) on the new statement.
        return new_node.with_changes(semicolon=updated_node.semicolon)
def main():
    """Convert unittest-style asserts in Python source and write the result to stdout.

    The input is the file named by the first command-line argument, or standard input
    when no argument is given.

    The source is read as raw bytes so that libcst can honor the file's own encoding
    declaration (PEP 263) instead of the platform's locale default. The converted
    module is written back in that same encoding, byte-for-byte apart from the
    rewritten asserts.
    """
    if len(sys.argv) > 1:
        with open(sys.argv[1], "rb") as fd:
            code = fd.read()
    else:
        code = sys.stdin.buffer.read()
    tree = cst.parse_module(code)
    new_tree = tree.visit(AssertTransformer())
    # Write the module verbatim; ``print`` would append a newline the source never had.
    sys.stdout.buffer.write(new_tree.bytes)
    sys.stdout.buffer.flush()


if __name__ == "__main__":
    main()
@whummer
Copy link

whummer commented May 3, 2022

Nice script, thanks for sharing this @thrau ! Just gave this a try - works like a charm.. 😎

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment