Created
November 3, 2022 23:22
-
-
Save harupy/e7656c4c2b8bb5018c7674ddac73a312 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import libcst as cst | |
import pathlib | |
import argparse | |
import difflib | |
class AssertMethodTransformer(cst.CSTTransformer):
    """Rewrite "Yoda-style" equality asserts so the constant ends up on the right.

    ``assert 1 == x`` becomes ``assert x == 1``. An assert is rewritten only
    when its test is a single ``==`` comparison whose left operand is a
    literal (string, integer, float, list, set, tuple, dict) or one of the
    names ``None``, ``True`` or ``False``. Everything else is left untouched.
    """

    # Node types treated as constants when they appear on the left of ``==``.
    _LITERAL_NODE_TYPES = (
        cst.SimpleString,
        cst.Integer,
        cst.Float,
        cst.List,
        cst.Set,
        cst.Tuple,
        cst.Dict,
    )
    # Names that are parsed as ``cst.Name`` but behave like constants.
    _CONSTANT_NAMES = frozenset({"None", "False", "True"})

    @staticmethod
    def is_unittest_assert_method(node: cst.Call, name: str) -> bool:
        """Return True if *node* is a call of the form ``self.<name>(...)``."""
        func = node.func
        return (
            isinstance(func, cst.Attribute)
            and isinstance(func.value, cst.Name)
            and func.value.value == "self"
            and isinstance(func.attr, cst.Name)
            and func.attr.value == name
        )

    @classmethod
    def _is_constant_like(cls, node: cst.BaseExpression) -> bool:
        """Return True if *node* is a literal or one of None/True/False."""
        if isinstance(node, cls._LITERAL_NODE_TYPES):
            return True
        return isinstance(node, cst.Name) and node.value in cls._CONSTANT_NAMES

    def leave_Assert(self, original_node: cst.Assert, updated_node: cst.Assert) -> cst.Assert:
        """Swap the operands of ``assert <constant> == <expr>``.

        Works on ``updated_node`` (the node that already carries any changes
        made to its children) and returns it unchanged when the assert does
        not match the pattern.
        """
        test = updated_node.test
        if not isinstance(test, cst.Comparison) or len(test.comparisons) != 1:
            return updated_node

        target = test.comparisons[0]
        if not isinstance(target.operator, cst.Equal):
            return updated_node
        if not self._is_constant_like(test.left):
            return updated_node

        # Swap the operands in place with ``with_changes`` instead of building
        # fresh nodes, so parentheses around the comparison and the whitespace
        # (and comments) owned by the operator are preserved.
        swapped = test.with_changes(
            left=target.comparator,
            comparisons=[target.with_changes(comparator=test.left)],
        )
        changes = {"test": swapped}
        # ``assert(1) == x`` has no space after ``assert``; once the
        # parenthesized operand moves to the right, a space is required for
        # the node to stay valid.
        if updated_node.whitespace_after_assert.empty:
            changes["whitespace_after_assert"] = cst.SimpleWhitespace(" ")
        return updated_node.with_changes(**changes)
def transform_file(path: pathlib.Path) -> None:
    """Apply ``AssertMethodTransformer`` to *path*, rewriting the file in place.

    When the transformation changes the module, a unified diff is printed,
    the file is overwritten, and a confirmation line is printed. Unchanged
    files are not written.

    The file is read and written as bytes so that libcst detects the source
    encoding and newline style itself and the rewritten file keeps both,
    instead of depending on the platform's default text encoding and
    newline translation.
    """
    source_tree = cst.parse_module(path.read_bytes())
    modified_tree = source_tree.visit(AssertMethodTransformer())
    if modified_tree.deep_equals(source_tree):
        return

    diff_lines = difflib.unified_diff(
        source_tree.code.splitlines(keepends=True),
        modified_tree.code.splitlines(keepends=True),
        fromfile=str(path),
        tofile=str(path),
    )
    print("".join(diff_lines))
    # ``Module.bytes`` encodes the generated code with the encoding that was
    # detected when the module was parsed.
    path.write_bytes(modified_tree.bytes)
    print(f"Transformed {path}")
if __name__ == "__main__":
    # Example: python a.py $(git ls-files "tests/**/*.py")
    parser = argparse.ArgumentParser(
        description=(
            "Rewrite 'assert <constant> == <expr>' as "
            "'assert <expr> == <constant>' in the given Python files (in place)."
        )
    )
    parser.add_argument(
        "files",
        metavar="FILE",
        nargs="+",
        help="path of a Python source file to rewrite",
    )
    args = parser.parse_args()
    for f in args.files:
        print("Processing", f)
        transform_file(pathlib.Path(f))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment