Skip to content

Instantly share code, notes, and snippets.

@Sachaa-Thanasius
Created May 27, 2024 15:49
Show Gist options
  • Save Sachaa-Thanasius/49bb1a3ed312746d30e23ad5cb80ce6f to your computer and use it in GitHub Desktop.
Save Sachaa-Thanasius/49bb1a3ed312746d30e23ad5cb80ce6f to your computer and use it in GitHub Desktop.
import ast
import os
import sys
from io import BytesIO
from typing import TYPE_CHECKING, Tuple, Union
# `StrPath` is the alias used to annotate the `filename` parameter of `fix_file`: either a plain `str` or an
# `os.PathLike` object.
if sys.version_info >= (3, 10) or TYPE_CHECKING:
    # Subscripted form: chosen on Python 3.10+ at runtime, and whenever `TYPE_CHECKING` is true.
    StrPath = Union[str, os.PathLike[str]]
else:
    # Unsubscripted runtime fallback for older interpreters.
    # NOTE(review): presumably because `os.PathLike[str]` cannot be relied on at runtime there -- confirm the
    # exact version cutoff before changing the condition.
    StrPath = Union[str, os.PathLike]
class AnnotationVisitor(ast.NodeVisitor):
    """Node visitor that finds the locations of all parameter annotations that contain `float` or `complex`.

    It also tracks the name that `typing` or `typing.Union` was imported under, as well as a position to place an
    import for `typing.Union` if a) it doesn't already exist and b) it is necessary to accommodate pre-3.10 code.

    Only annotations attached to `ast.arg` nodes (function parameters) are collected; no visitor method is defined
    for annotated assignments or return annotations.
    """

    def __init__(self):
        # (lineno, col_offset) of every `float` name found inside a parameter annotation.
        self.float_ann_locations: list[tuple[int, int]] = []
        # (lineno, col_offset) of every `complex` name found inside a parameter annotation.
        self.complex_ann_locations: list[tuple[int, int]] = []
        # (lineno, col_offset) of the first module-level statement that an import can be placed in front of.
        # Stays at (-1, -1) when no such statement exists.
        self.import_position: tuple[int, int] = (-1, -1)
        # Whether `Union` is still missing from the module, i.e. neither `import typing` nor
        # `from typing import Union` was seen.
        self.union_import_needed = True
        # The (encoded) spelling under which `Union` can be referenced in the visited module.
        self.union_name = b"Union"

    def _visit_annotation(self, node: Union[ast.arg, ast.AnnAssign]) -> None:
        """Record the position of every `float`/`complex` name inside the annotation of `node`, if it has one."""

        if node.annotation is None:
            return

        # Walk the whole annotation expression so that nested uses such as `Optional[float]` are found as well.
        for ann_component in ast.walk(node.annotation):
            if not isinstance(ann_component, ast.Name):
                continue
            location = (ann_component.lineno, ann_component.col_offset)
            if ann_component.id == "float":
                self.float_ann_locations.append(location)
            elif ann_component.id == "complex":
                self.complex_ann_locations.append(location)

    def visit_arg(self, node: ast.arg):
        """Find all the parts of annotations in parameters that use either `float` or `complex`."""

        self._visit_annotation(node)
        self.generic_visit(node)

    def visit_Import(self, node: ast.Import) -> None:
        """Save what name `typing` was imported under, if it's already imported."""

        # The `None` default matters: without it, any import statement that doesn't mention `typing`
        # (e.g. `import os`) would raise StopIteration and abort the whole traversal.
        typing_alias = next((alias for alias in node.names if alias.name == "typing"), None)
        if typing_alias:
            self.union_import_needed = False
            self.union_name = (f"{typing_alias.asname or typing_alias.name}.Union").encode()
        self.generic_visit(node)

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        """Save what name `typing.Union` was imported under, if it's already imported."""

        if node.module == "typing":
            union_alias = next((alias for alias in node.names if alias.name == "Union"), None)
            if union_alias:
                self.union_import_needed = False
                self.union_name = (union_alias.asname or union_alias.name).encode()
        self.generic_visit(node)

    @staticmethod
    def _first_lineno(node: ast.stmt) -> int:
        """Return the first source line occupied by a statement, taking decorators into account.

        On Python 3.8+, the `lineno` of a decorated function or class is the line of the `def`/`class` keyword, not
        of its first decorator. Inserting a line there would separate the decorators from their definition.
        """

        decorators = getattr(node, "decorator_list", None) or ()
        return min([node.lineno, *(decorator.lineno for decorator in decorators)])

    def visit_Module(self, node: ast.Module) -> None:
        """Find the first available place in the module to insert an import if need be.

        This may be used to place an import for typing.Union in pre-3.10 Python code. A leading docstring and any
        `from __future__ import ...` statements are skipped, since an inserted import must come after them.
        """

        self.generic_visit(node)

        expect_docstring = True
        for sub_node in node.body:
            if (
                isinstance(sub_node, ast.Expr)
                and isinstance(sub_node.value, ast.Constant)
                and isinstance(sub_node.value.value, str)
                and expect_docstring
            ):
                expect_docstring = False
            elif isinstance(sub_node, ast.ImportFrom) and sub_node.module == "__future__" and sub_node.level == 0:
                pass
            else:
                self.import_position = (self._first_lineno(sub_node), sub_node.col_offset)
                break
def fix_file(filename: StrPath, py_version: Tuple[int, int]) -> None:
    """Update a file to adjust annotations for `float` and `complex` to `float | int` and `complex | float | int` respectively.

    This accounts for python versions without access to the `|` syntax, though it doesn't account for string
    annotations that use that syntax in pre-3.10 versions.

    The file is read and rewritten as bytes. The complete new content is computed before the file is reopened for
    writing, so a failure while parsing or transforming leaves the original file untouched.

    Parameters
    ----------
    filename: StrPath
        A path-like object that can be used to open a file. The resulting file will be modified.
    py_version: tuple[int, int]
        The Python version this script should assume the given file's code is written in. If below (3, 10), `Union` will be
        used instead of `|` when substituting the annotations, resulting in an added import from `typing` when `Union`
        isn't already available and at least one annotation is actually rewritten.

    Raises
    ------
    OSError
        If the file cannot be opened for reading or writing.
    SyntaxError
        If the file's content cannot be parsed by `ast.parse`.

    Examples
    --------
    For <3.10 code:

    .. code-block:: python

        # Initial state of test.py in 3.8:
        from typing import Optional
        def example(a: float, *, b: complex, **kwargs: Optional[float]) -> float:
            result: float = a + b + sum(kwargs.values())
            return result

        # After calling fix_file("test.py", (3, 8)):
        from typing import Union
        from typing import Optional
        def example(a: Union[float, int], *, b: Union[complex, float, int], **kwargs: Optional[Union[float, int]]) -> float:
            result: float = a + b + sum(kwargs.values())
            return result

    For >=3.10 code:

    .. code-block:: python

        # Initial state of test.py in 3.10:
        def example(a: float, *, b: complex, **kwargs: float | None) -> float:
            result: float = a + b + sum(kwargs.values())
            return result

        # After calling fix_file("test.py", (3, 10)):
        def example(a: float | int, *, b: complex | float | int, **kwargs: float | int | None) -> float:
            result: float = a + b + sum(kwargs.values())
            return result
    """

    with open(filename, "rb") as fp:
        source = fp.read()

    visitor = AnnotationVisitor()
    visitor.visit(ast.parse(source))

    use_pipe_syntax = py_version >= (3, 10)
    if use_pipe_syntax:
        replacements = {b"float": b"float | int", b"complex": b"complex | float | int"}
    else:
        replacements = {
            b"float": visitor.union_name + b"[float, int]",
            b"complex": visitor.union_name + b"[complex, float, int]",
        }

    # Group the recorded locations by line once, instead of rescanning every location for every line.
    edits_by_line: dict[int, list[tuple[int, bytes]]] = {}
    for type_, locations in (
        (b"float", visitor.float_ann_locations),
        (b"complex", visitor.complex_ann_locations),
    ):
        for ann_lineno, col_offset in locations:
            edits_by_line.setdefault(ann_lineno, []).append((col_offset, type_))

    # The import is only worth adding when `Union` will actually be written into the file.
    add_union_import = bool(edits_by_line) and not use_pipe_syntax and visitor.union_import_needed

    # Build the whole output up front: opening the file with "wb" truncates it, so nothing may fail after that.
    output = []
    for lineno, line in enumerate(BytesIO(source), start=1):
        if add_union_import and lineno == visitor.import_position[0]:
            # Reuse the line ending of the line the import is placed in front of.
            newline = b"\r\n" if line.endswith(b"\r\n") else b"\n"
            output.append(b"from typing import Union" + newline)

        # Needs to be in reverse column order to avoid column offsets on the same line from becoming incorrect as
        # the line is mutated. The offsets are byte offsets, which is why the line is sliced as bytes.
        for col_offset, type_ in sorted(edits_by_line.get(lineno, ()), reverse=True):
            line = line[:col_offset] + replacements[type_] + line[col_offset + len(type_) :]

        output.append(line)

    with open(filename, "wb") as fp:
        fp.writelines(output)
def main():
    """Command-line entry point: parse the arguments and rewrite the given file in place.

    Expects two positional arguments, the file to modify and the Python version (as '3.x') that the file's code
    targets. A malformed version is reported through the argument parser (usage message, exit status 2) before the
    file is touched, rather than surfacing as an unhandled `ValueError`.
    """

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("filename", help="The file that will be modified.")
    parser.add_argument(
        "python_version",
        help=(
            "The python version to take into account when modifying the file. Determines whether 'Union' or '|' is used. "
            "Must be in the format '3.x', with x being a number."
        ),
    )

    args = parser.parse_args()

    py_major, _, py_minor = args.python_version.partition(".")
    # Both halves must be plain numbers; this rejects inputs such as "3", "3.10.1" or "abc".
    if not (py_major.isdecimal() and py_minor.isdecimal()):
        parser.error(
            f"invalid python_version {args.python_version!r}: must be in the format '3.x', with x being a number."
        )

    fix_file(args.filename, (int(py_major), int(py_minor)))
# Run the command-line interface only when this file is executed as a script; the return value of `main()` is
# used as the process exit status.
if __name__ == "__main__":
    sys.exit(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment