Created June 28, 2023 18:03
Save sobolevn/967ac61b8b2575e26a33f5a6bf5df201 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ast | |
import argparse | |
import os | |
import sys | |
import tokenize | |
from dataclasses import dataclass | |
from typing import Final, TypeAlias | |
_AnyFunction: TypeAlias = ast.FunctionDef | ast.AsyncFunctionDef | |
_EXCLUDE_FILES: Final = frozenset({ | |
'Lib/test/badsyntax_3131.py', | |
'Lib/test/badsyntax_pep3120.py', | |
'Lib/test/bad_coding2.py', | |
}) | |
@dataclass(frozen=True, slots=True)
class MethodDuplicate:
    """A method definition that reuses a name already seen in its class."""

    # The function definition that repeats an earlier name.
    node: _AnyFunction
    # Name of the class whose body contains the repeated definition.
    class_name: str

    def report_error(self) -> str:
        """Return a one-line, human-readable description of this duplicate."""
        method_name = self.node.name
        owner = self.class_name
        return f'Found duplicate method {method_name} in {owner} class'
@dataclass(frozen=True, slots=True) | |
class ClassDuplicate: | |
node: ast.ClassDef | |
def report_error(self) -> str: | |
return f'Found duplicate {self.node.name} class' | |
class ClassSpec:
    """Registry of the method names seen so far in a single class body."""

    def __init__(self, class_name: str) -> None:
        # Name of the class; copied into any MethodDuplicate reported.
        self.name = class_name
        # Names passed to ``add_method`` so far.  These are strings: the
        # previous ``set[set]`` annotation was wrong (and sets are not even
        # hashable, so a set of sets is impossible).
        self.methods: set[str] = set()

    def add_method(self, method: _AnyFunction) -> MethodDuplicate | None:
        """Register *method* and report whether its name was already taken.

        Args:
            method: A function definition found in this class body.

        Returns:
            A ``MethodDuplicate`` wrapping *method* if a method with the same
            name was registered earlier (the name set is left unchanged);
            otherwise records the name and returns ``None``.
        """
        if method.name in self.methods:
            return MethodDuplicate(method, self.name)
        self.methods.add(method.name)
        return None
class TestMethodDuplicatesVisitor(ast.NodeVisitor):
    """Collect duplicated module-level classes and duplicated test methods.

    Every visited node must carry a ``_parent`` attribute that references its
    parent node; the visitor reads it to decide whether a class is
    module-level and whether a function is defined directly in a class body.
    """

    # Only methods whose names start with this prefix are checked.
    _METHOD_PREFIX: Final = 'test_'

    def __init__(self) -> None:
        # Method registry of the innermost class currently being visited.
        self.current_class: ClassSpec | None = None
        # Names of the module-level classes seen so far.
        self.classes: set[str] = set()
        # Every duplicate found, in visit order.
        self.duplicates: list[ClassDuplicate | MethodDuplicate] = []

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        """Check *node* for a duplicate name, then visit its body.

        Only classes whose parent is the module are checked for duplicate
        names. Each class body, nested or not, gets its own method registry.
        """
        if isinstance(node._parent, ast.Module):
            if node.name in self.classes:
                self.duplicates.append(ClassDuplicate(node))
            self.classes.add(node.name)

        # Remember the enclosing class's registry and restore it afterwards.
        # Without this, methods that follow a nested class would be recorded
        # in the nested class's registry, hiding real duplicates in the outer
        # class and mixing names across the two classes.
        enclosing = self.current_class
        self.current_class = ClassSpec(node.name)
        self.generic_visit(node)
        self.current_class = enclosing

    def visit_FunctionDef(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> None:
        """Register *node* if it is a test method defined directly in a class.

        Functions nested under ``if``/``try``/other statements inside a class
        body are not direct children of the ``ClassDef`` and are not checked.
        """
        if (
            self.current_class is not None
            and isinstance(node._parent, ast.ClassDef)
            and node.name.startswith(self._METHOD_PREFIX)
        ):
            result = self.current_class.add_method(node)
            if result is not None:
                self.duplicates.append(result)
        self.generic_visit(node)

    # Async methods are checked exactly like regular ones.
    visit_AsyncFunctionDef = visit_FunctionDef
def _read_file(filename: str) -> str:
    """Return the full text of *filename*.

    The file is first decoded with ``tokenize.open``, which picks the
    encoding from a BOM or PEP 263 coding cookie (UTF-8 otherwise).  If that
    detection fails, or the detected codec cannot decode the bytes, the file
    is re-read as latin-1.  Latin-1 maps every byte to a character, so that
    second read does not fail on decoding.
    """
    # Approach borrowed from flake8: https://github.com/PyCQA/flake8/
    try:
        with tokenize.open(filename) as stream:
            text = stream.read()
    except (SyntaxError, UnicodeError):
        # Bad or unknown coding cookie (SyntaxError) or undecodable bytes
        # (UnicodeError): fall back to a lossless byte-to-char decoding.
        with open(filename, encoding='latin-1') as stream:
            text = stream.read()
    return text
def _set_parents(tree: ast.AST) -> ast.AST:
    """Annotate every node under *tree* with a ``_parent`` back-reference.

    The root itself gets no ``_parent``.  The same *tree* object is returned
    (mutated in place) so the call can be chained.
    """
    for parent in ast.walk(tree):
        for child in ast.iter_child_nodes(parent):
            child._parent = parent
    return tree
def _is_excluded(path: str) -> bool:
    """Return True if *path* refers to one of the entries in ``_EXCLUDE_FILES``.

    The entries are repo-relative POSIX paths, so *path* is normalized
    (``./`` prefixes removed, OS separators turned into ``/``) and matched
    either exactly or as a whole-component suffix.  This lets the scan be
    started from ``.``, ``./Lib``, an absolute path, or on Windows.
    """
    normalized = os.path.normpath(path).replace(os.sep, '/')
    return any(
        normalized == excluded or normalized.endswith('/' + excluded)
        for excluded in _EXCLUDE_FILES
    )


def main() -> None:
    """Scan a directory tree for duplicated test classes and test methods.

    Every ``*.py`` file under the directory given on the command line is
    parsed (except the excluded ones). Each duplicate is printed as
    ``path:lineno message``. The process exits with status 1 if any
    duplicate was found and 0 otherwise.
    """
    # The text is passed as ``description``: the first positional argument
    # of ArgumentParser is ``prog``, which would replace the program name in
    # the usage line.
    parser = argparse.ArgumentParser(description='Find duplicate test methods')
    parser.add_argument('dir', type=str, help='root directory to scan')
    args = parser.parse_args()

    found_any_duplicates = False
    for dirpath, dirnames, filenames in os.walk(args.dir):
        # Sorting ``dirnames`` in place makes os.walk descend in a stable
        # order, so the report is reproducible across filesystems.
        dirnames.sort()
        for filename in sorted(filenames):
            if not filename.endswith('.py'):
                continue
            full_path = os.path.join(dirpath, filename)
            if _is_excluded(full_path):
                continue
            source = ast.parse(_read_file(full_path), filename=full_path)
            visitor = TestMethodDuplicatesVisitor()
            visitor.visit(_set_parents(source))
            for dup in visitor.duplicates:
                found_any_duplicates = True
                print(f'{full_path}:{dup.node.lineno} {dup.report_error()}')
    sys.exit(int(found_any_duplicates))


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.