Created
May 27, 2024 15:49
-
-
Save Sachaa-Thanasius/49bb1a3ed312746d30e23ad5cb80ce6f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ast | |
import os | |
import sys | |
from io import BytesIO | |
from typing import TYPE_CHECKING, Tuple, Union | |
# `StrPath` describes anything accepted as a filesystem path by `open()`.
# The subscripted `os.PathLike[str]` form is only used on 3.10+ interpreters (or when a static type checker is
# reading the file); older interpreters get the bare `os.PathLike` form.
if sys.version_info < (3, 10) and not TYPE_CHECKING:
    StrPath = Union[str, os.PathLike]
else:
    StrPath = Union[str, os.PathLike[str]]
class AnnotationVisitor(ast.NodeVisitor):
    """Node visitor that finds the locations of parameter annotations that contain `float` or `complex`.

    It also tracks the name that `typing` or `typing.Union` was imported under, as well as a position to place an
    import for `typing.Union` if a) it doesn't already exist and b) it is necessary to accommodate pre-3.10 code.
    """

    def __init__(self):
        # (lineno, col_offset) pairs, as reported by `ast`, for every `float` / `complex` name found in an
        # annotation that was visited.
        self.float_ann_locations: list[tuple[int, int]] = []
        self.complex_ann_locations: list[tuple[int, int]] = []
        # Position of the first module-level statement that is neither the docstring nor a `__future__` import.
        # Stays at (-1, -1) when the module has no such statement.
        self.import_position: tuple[int, int] = (-1, -1)
        # Flipped to False as soon as an existing import of `typing` or `typing.Union` is seen.
        self.union_import_needed = True
        # The source text (as bytes) that refers to `Union` in the visited module, e.g. b"Union" or b"typing.Union".
        self.union_name = b"Union"

    def _visit_annotation(self, node: Union[ast.arg, ast.AnnAssign]) -> None:
        """Record the position of every `float` and `complex` name inside the annotation of `node`, if it has one."""

        if node.annotation is None:
            return

        # Walk the whole annotation expression so that nested uses, e.g. `Optional[float]`, are found as well.
        for ann_component in ast.walk(node.annotation):
            if not isinstance(ann_component, ast.Name):
                continue
            location = (ann_component.lineno, ann_component.col_offset)
            if ann_component.id == "float":
                self.float_ann_locations.append(location)
            elif ann_component.id == "complex":
                self.complex_ann_locations.append(location)

    def visit_arg(self, node: ast.arg):
        """Find all the parts of annotations in parameters that use either `float` or `complex`."""

        self._visit_annotation(node)
        self.generic_visit(node)

    def visit_Import(self, node: ast.Import) -> None:
        """Save what name `typing` was imported under, if it's already imported."""

        # A default is required here: most `import` statements do not mention `typing`, and a bare `next()` on an
        # exhausted generator would raise StopIteration and abort the whole visit.
        typing_alias = next((alias for alias in node.names if alias.name == "typing"), None)
        if typing_alias is not None:
            self.union_import_needed = False
            self.union_name = (f"{typing_alias.asname or typing_alias.name}.Union").encode()
        self.generic_visit(node)

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        """Save what name `typing.Union` was imported under, if it's already imported."""

        if node.module == "typing":
            union_alias = next((alias for alias in node.names if alias.name == "Union"), None)
            if union_alias is not None:
                self.union_import_needed = False
                self.union_name = (union_alias.asname or union_alias.name).encode()
        self.generic_visit(node)

    def visit_Module(self, node: ast.Module) -> None:
        """Find the first available place in the module to insert an import if need be.

        This may be used to place an import for typing.Union in pre-3.10 Python code.
        """

        self.generic_visit(node)

        expect_docstring = True
        for sub_node in node.body:
            if (
                isinstance(sub_node, ast.Expr)
                and isinstance(sub_node.value, ast.Constant)
                and isinstance(sub_node.value.value, str)
                and expect_docstring
            ):
                # Only one leading string expression is skipped as the module docstring.
                expect_docstring = False
            elif isinstance(sub_node, ast.ImportFrom) and sub_node.module == "__future__" and sub_node.level == 0:
                # `from __future__ import ...` must stay ahead of any other import, so skip past it.
                pass
            else:
                # On Python 3.8+ the `lineno` of a decorated function or class is the `def`/`class` line, not the
                # first decorator line. Anchor on the earliest decorator so that an inserted import can never end
                # up between a decorator and its definition.
                decorators = getattr(sub_node, "decorator_list", ())
                first_line = min([sub_node.lineno, *(decorator.lineno for decorator in decorators)])
                self.import_position = (first_line, sub_node.col_offset)
                break
def fix_file(filename: StrPath, py_version: Tuple[int, int]) -> None:
    """Update a file to adjust annotations for `float` and `complex` to `float | int` and `complex | float | int` respectively.

    This accounts for python versions without access to the `|` syntax, though it doesn't account for string
    annotations that use that syntax in pre-3.10 versions.

    The file is left untouched when `AnnotationVisitor` reports no `float` or `complex` annotation to rewrite.

    Parameters
    ----------
    filename: StrPath
        A path-like object that can be used to open a file. The resulting file will be modified.
    py_version: tuple[int, int]
        The Python version this script should assume the given file's code is written in. If below (3, 10), `Union` will be
        used instead of `|` when substituting the annotations, resulting in an added import from `typing`.

    Examples
    --------
    For <3.10 code:

    .. code-block:: python

        # Initial state of test.py in 3.8:
        from typing import Optional

        def example(a: float, *, b: complex, **kwargs: Optional[float]) -> float:
            result: float = a + b + sum(kwargs.values())
            return result

        # After calling fix_file("test.py", (3, 8)):
        from typing import Union
        from typing import Optional

        def example(a: Union[float, int], *, b: Union[complex, float, int], **kwargs: Optional[Union[float, int]]) -> float:
            result: float = a + b + sum(kwargs.values())
            return result

    For >=3.10 code:

    .. code-block:: python

        # Initial state of test.py in 3.10:
        def example(a: float, *, b: complex, **kwargs: float | None) -> float:
            result: float = a + b + sum(kwargs.values())
            return result

        # After calling fix_file("test.py", (3, 10)):
        def example(a: float | int, *, b: complex | float | int, **kwargs: float | int | None) -> float:
            result: float = a + b + sum(kwargs.values())
            return result
    """

    with open(filename, "rb") as fp:
        source = fp.read()

    visitor = AnnotationVisitor()
    visitor.visit(ast.parse(source))

    # Group the replacements by line number once, instead of rescanning every location for every source line.
    replacements_by_line = {}
    for type_name, locations in (
        (b"float", visitor.float_ann_locations),
        (b"complex", visitor.complex_ann_locations),
    ):
        for lineno, col_offset in locations:
            replacements_by_line.setdefault(lineno, []).append((col_offset, type_name))

    # Nothing to rewrite: leave the file alone rather than adding an unused `Union` import.
    if not replacements_by_line:
        return

    if py_version >= (3, 10):
        replacement_for = {b"float": b"float | int", b"complex": b"complex | float | int"}
        import_lineno = None
    else:
        replacement_for = {
            b"float": visitor.union_name + b"[float, int]",
            b"complex": visitor.union_name + b"[complex, float, int]",
        }
        # Only add an import when the module doesn't already provide a way to spell `Union`.
        import_lineno = visitor.import_position[0] if visitor.union_import_needed else None

    # Build the complete output before reopening the file for writing, so that an unexpected error cannot leave
    # the file truncated. `splitlines` is used because the parser counts "\r" and "\r\n" as line ends too.
    new_lines = []
    for lineno, line in enumerate(source.splitlines(keepends=True), start=1):
        if lineno == import_lineno:
            new_lines.append(b"from typing import Union\n")

        # Needs to be in reverse column order to avoid column offsets on the same line from becoming incorrect as
        # the line is mutated.
        for col_offset, type_name in sorted(replacements_by_line.get(lineno, ()), reverse=True):
            line = line[:col_offset] + replacement_for[type_name] + line[col_offset + len(type_name) :]

        new_lines.append(line)

    with open(filename, "wb") as fp:
        fp.writelines(new_lines)
def main():
    """Parse the command-line arguments and run `fix_file` on the requested file.

    Expects two positional arguments: the file to modify and the Python version ('3.x') to assume for it. A
    version that is not two integers separated by a single dot is reported through the argument parser (usage
    message, exit status 2) instead of surfacing as a `ValueError` traceback.
    """

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("filename", help="The file that will be modified.")
    parser.add_argument(
        "python_version",
        help=(
            "The python version to take into account when modifying the file. Determines whether 'Union' or '|' is used. "
            "Must be in the format '3.x', with x being a number."
        ),
    )

    args = parser.parse_args()

    py_major, _, py_minor = args.python_version.partition(".")
    try:
        py_version = (int(py_major), int(py_minor))
    except ValueError:
        # Covers inputs such as "3", "3.10.2" or "abc"; `parser.error` prints the usage and exits.
        parser.error(f"invalid python_version {args.python_version!r}: must be in the format '3.x'")

    fix_file(args.filename, py_version)
# Script entry point: exit with whatever `main()` returns as the process status.
if __name__ == "__main__":
    sys.exit(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment