Last active
January 25, 2021 18:48
-
-
Save agmond/9240c74f5ff88497eec7f6d2949a38fe to your computer and use it in GitHub Desktop.
Fix `super()` calls after migration to Python3
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
The following script fixes `super()` calls after migration from Python 2 to Python 3. | |
The script edits the code automatically. Upon completion, the following steps are recommended: | |
- Search for the regex `super\([^\)]` and fix manually those places (if needed) | |
- Search for the regex `super\(\s[^\)]` and fix manually those places (if needed) | |
- Run Flake8 and manually fix styling problems | |
""" | |
import ast | |
import linecache | |
from pathlib import Path | |
# Directory tree to process. The path below looks like a placeholder --
# set it to the real project root before running the script.
root_path = Path('/path/to/main/directory/')

# Lazily yields every ``*.py`` file found at any depth below ``root_path``.
PYTHON_FILE_PATTERN = '*.py'
files = root_path.rglob(PYTHON_FILE_PATTERN)
class SuperVisitor(ast.NodeVisitor):
    """Collect Python 2 style ``super(ClassName, self)`` calls in a parsed module.

    After ``visit()`` has been run on a tree, ``lines_to_edit`` holds one
    ``(lineno, start_col, end_col)`` tuple per call to ``super`` whose two
    positional arguments are exactly the name of the innermost enclosing
    class and ``self``.  The column span covers the argument text
    (``ClassName, self``), so removing it from the line leaves ``super()``.

    ``lineno`` is 1-based; the columns are the offsets reported by ``ast``
    (UTF-8 byte offsets into the line, per the ``ast`` documentation).
    """

    def __init__(self, filename, *args, **kwargs):
        """Create a visitor for one file.

        Args:
            filename: Name of the file being analysed; only used in the
                error message raised by ``visit_Call``.
            *args, **kwargs: Forwarded to ``ast.NodeVisitor.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.filename = filename
        # Name of the innermost class currently being visited, or None when
        # the traversal is outside of any class body.
        self.current_class_name = None
        # Accumulated (lineno, start_col, end_col) spans to delete.
        self.lines_to_edit = []

    def visit_ClassDef(self, node):
        """Visit a class body while tracking the class name.

        The previous name is restored afterwards (instead of being reset to
        ``None``) so that code following a nested class definition is still
        matched against the outer class.  Body members are dispatched through
        ``visit`` so that directly nested classes are tracked as well.
        """
        enclosing_class_name = self.current_class_name
        self.current_class_name = node.name
        try:
            for child_node in node.body:
                self.visit(child_node)
        finally:
            self.current_class_name = enclosing_class_name
        return node

    def visit_Call(self, node):
        """Record ``super(<current class>, self)`` calls, then keep descending.

        Raises:
            AssertionError: if the two arguments of a matching call do not
                sit on the same source line (such calls must be fixed by
                hand).
        """
        if (
            self.current_class_name
            and getattr(node.func, 'id', None) == 'super'
            and len(node.args) == 2
        ):
            class_name_arg, self_arg = node.args
            if (
                getattr(class_name_arg, 'id', None) == self.current_class_name
                and getattr(self_arg, 'id', None) == 'self'
            ):
                # Raised explicitly (not via ``assert``) so the check is not
                # stripped when Python runs with -O.
                if class_name_arg.lineno != self_arg.end_lineno:
                    raise AssertionError(f'Not in the same line: {self.filename}:{node.lineno}')
                self.lines_to_edit.append(
                    (class_name_arg.lineno, class_name_arg.col_offset, self_arg.end_col_offset)
                )
        # A single generic_visit reaches every descendant exactly once, so a
        # super() call nested inside other calls is recorded only one time.
        self.generic_visit(node)
        return node
# Rewrite every discovered file in place, deleting the redundant arguments of
# each `super(ClassName, self)` call that SuperVisitor reports.
for filename in map(str, files):
    # Python source is UTF-8 by default, so decode it as such rather than
    # with the platform's locale encoding.
    with open(filename, 'r+', encoding='utf-8') as f:
        source = f.read()
        parsed = ast.parse(source, filename)
        node_visitor = SuperVisitor(filename)
        node_visitor.visit(parsed)
        if not node_visitor.lines_to_edit:
            continue

        # Edit by line number on the text that was actually parsed, instead of
        # searching the whole file for the line's text (which could hit a
        # different, textually identical line).  Text mode has already
        # normalised newlines to '\n', so this split matches ast's numbering.
        lines = source.split('\n')

        # De-duplicate, then apply right-to-left so that earlier column
        # offsets on the same line stay valid after each deletion.
        for lineno, start, end in sorted(set(node_visitor.lines_to_edit), reverse=True):
            # ast column offsets count UTF-8 bytes, not characters.
            raw_line = lines[lineno - 1].encode('utf-8')
            lines[lineno - 1] = (raw_line[:start] + raw_line[end:]).decode('utf-8')

        f.seek(0)
        f.truncate(0)
        f.write('\n'.join(lines))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment