Last active
January 27, 2021 11:13
-
-
Save gtors/27f878d2d42721d0f69e0ec98810ef80 to your computer and use it in GitHub Desktop.
Python script for sorting class definitions in a file
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
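"""
Sort the top-level class definitions of a Python file by the types used in
their attribute annotations, so that every class appears after the classes it
refers to. This is useful when refactoring generated code, where a class is
often emitted before the classes it depends on.

For example:
```
class C:
    a: Optional[A] = None
    b: Optional[B] = None
class A:
    b: Optional[B] = None
class B:
    c: int
```
will be fixed to:
```
class B:
    c: int
class A:
    b: Optional[B] = None
class C:
    a: Optional[A] = None
    b: Optional[B] = None
```
The result is written to a new file named after the input with a ``fixed_``
prefix. Because the code is regenerated with ``ast.unparse``, comments and the
original formatting are not preserved.
"""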
import ast
import functools
import sys
from collections import defaultdict
from graphlib import CycleError, TopologicalSorter  # py3.9
from pathlib import Path
def ast_filter(it, ty):
    """Return a generator over the items of *it* that are instances of *ty*.

    *it* may be an AST node with a ``body`` attribute (e.g. a module or a class
    definition), in which case its direct body statements are filtered, or any
    plain iterable of nodes. *ty* is a type or tuple of types, as for
    ``isinstance``.
    """
    return (x for x in getattr(it, "body", it) if isinstance(x, ty))


def populate_class_graph(tree):
    """Build the annotation dependency graph of the top-level classes in *tree*.

    For every top-level ``class`` statement in *tree*, the returned
    ``defaultdict(list)`` maps the class name to the names of the *other*
    top-level classes referenced in the annotations of its annotated class
    attributes (``name: Annotation`` or ``name: Annotation = value``). The
    mapping is in the predecessor format accepted by
    ``graphlib.TopologicalSorter``, so a class's dependencies are ordered
    before it.

    - Every bare name anywhere in an annotation is followed, so
      ``Optional[A]``, ``Dict[str, List[A]]``, ``Callable[[A], B]``,
      ``A | None`` and the generic itself in ``Box[A]`` are all recognised.
    - Names that are not top-level classes of this module (``int``,
      ``Optional``, imported types, ...) are ignored; they cannot constrain the
      order of the classes being sorted.
    - A class referring to itself (e.g. ``next: Optional[Node]``) is not
      recorded: such a self-loop would make ``TopologicalSorter`` raise
      ``CycleError``.
    - Quoted forward references (``Optional["A"]``) are ignored: they are
      plain strings and do not require the referenced class to exist yet.
    - Each dependency is listed at most once per class.
    """
    classes = list(ast_filter(tree, ast.ClassDef))
    class_names = {cls.name for cls in classes}
    graph = defaultdict(list)
    for cls in classes:
        deps = []
        # A redefinition of the same class name replaces the earlier entry.
        graph[cls.name] = deps
        # Only direct class-body annotations are evaluated when the class
        # statement runs; annotations inside methods are not inspected.
        for prop in ast_filter(cls, ast.AnnAssign):
            for node in ast.walk(prop.annotation):
                if not isinstance(node, ast.Name):
                    continue
                name = node.id
                if name in class_names and name != cls.name and name not in deps:
                    deps.append(name)
    return graph
def sort_classes(tree):
    """Reorder the top-level class definitions of *tree* (an ``ast.Module``) in place.

    Classes are placed in the order produced by ``graphlib.TopologicalSorter``
    over ``populate_class_graph(tree)``, i.e. each class after the classes it
    depends on. Only the positions already occupied by class definitions are
    reshuffled; imports, functions, assignments and every other statement keep
    their original slots.

    Returns *tree* for convenience.
    Raises ``graphlib.CycleError`` if the class graph contains a cycle.
    """
    # Map each class name to its rank once, instead of tuple.index() per node.
    rank = {
        name: position
        for position, name in enumerate(
            TopologicalSorter(populate_class_graph(tree)).static_order()
        )
    }
    slots = [i for i, node in enumerate(tree.body) if isinstance(node, ast.ClassDef)]
    # sorted() is stable, so same-named (redefined) classes keep their order.
    ordered = sorted((tree.body[i] for i in slots), key=lambda node: rank[node.name])
    for slot, node in zip(slots, ordered):
        tree.body[slot] = node
    return tree


def main(argv=None):
    """Command-line entry point: ``python <script> FILE``.

    Parses FILE, reorders its top-level classes with ``sort_classes`` and
    writes the result to ``fixed_<name>`` in the same directory as FILE. Note
    that ``ast.unparse`` does not preserve comments or original formatting.

    Returns a process exit status: 0 on success, 1 on a class dependency
    cycle, 2 on a usage error.
    """
    argv = sys.argv if argv is None else argv
    if len(argv) < 2:
        prog = argv[0] if argv else "sort_classes.py"
        print(f"usage: {prog} FILE", file=sys.stderr)
        return 2

    source_path = Path(argv[1])
    tree = ast.parse(source_path.read_text(encoding="utf-8"), filename=str(source_path))
    try:
        sort_classes(tree)
    except CycleError as err:
        # Per the graphlib docs, args[1] holds the list of nodes in the cycle.
        cycle = " -> ".join(err.args[1])
        print(f"error: classes reference each other cyclically: {cycle}", file=sys.stderr)
        return 1

    # Write next to the input; "fixed_" + a path with directories would
    # point into a non-existent "fixed_<dir>" directory.
    target_path = source_path.with_name("fixed_" + source_path.name)
    target_path.write_text(ast.unparse(tree) + "\n", encoding="utf-8")
    return 0


if __name__ == "__main__":
    sys.exit(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment