#!/usr/bin/env python
"""Convert a codebase to using the walrus operator.

Hint: in order to explore the AST of a module you can run:

    python -m ast path/to/module.py
"""

from __future__ import annotations

import sys
import logging
import subprocess as sp
from typing import Any
from pathlib import Path
from argparse import ArgumentParser, Namespace, RawDescriptionHelpFormatter

import ast_comments as ast  # type: ignore

PROJECT_DIR = Path(__file__).parent.parent

logger = logging.getLogger()


def main() -> int:
    opt = parse_cmdline()
    logging.basicConfig(level=opt.log_level, format="%(levelname)s %(message)s")

    for fp in opt.inputs:
        convert(fp, opt.inplace)

    return 0


def convert(fp: Path, inplace: bool) -> None:
    logger.info("reading %s", fp)
    with fp.open() as f:
        source = f.read()

    tree = ast.parse(source, filename=str(fp))
    tree = (w := Walrusifier()).visit(tree)
    tree = BlanksInserter().visit(tree)
    output = unparse(tree)

    if inplace:
        if w.changed:
            with fp.open("w") as f:
                print(output, file=f)
            # Reformat the rewritten file (requires black to be installed).
            sp.check_call(["black", "-q", str(fp)])
    else:
        print(output)


class Walrusifier(ast.NodeTransformer):  # type: ignore
    """Replace an assignment followed by an if using it with a named expression."""

    def __init__(self) -> None:
        super().__init__()
        self.changed = False
        self.bodies: list[list[ast.AST]] = []
        self.name_to_replace: ast.Name | None = None
        self.value_to_replace: ast.AST | None = None

    def generic_visit(self, node: ast.AST) -> ast.AST:
        if not isinstance(getattr(node, "body", None), list):
            super().generic_visit(node)
            return node

        body = node.body[:]
        for attr in ("orelse", "finalbody"):
            if isinstance((otherbody := getattr(node, attr, None)), list):
                # Use a None to separate the if and else bodies, or there could
                # be a mix-up.
                body += [None] + otherbody[:]

        self.bodies.append(body)
        super().generic_visit(node)
        self.bodies.pop()
        return node

    def visit_Assign(self, node: ast.Assign) -> ast.AST | None:
        # Check if this is a single assignment to a name.
        name: str
        match node:
            case ast.Assign(targets=[ast.Name(id=str(name))]):
                pass
            case _:
                return node

        # Check if the assignment is followed by an if.
        if not (self.bodies and (body := self.bodies[-1])):
            return node

        try:
            idx = body.index(node)
        except ValueError:
            # Should not happen: drop into the debugger to investigate.
            breakpoint()
            return node

        if not (idx + 1 < len(body) and isinstance(body[idx + 1], ast.If)):
            return node

        # We can only replace the expression if the name is used exactly once
        # in the test.
        counter = NameCounter(name)
        counter.visit(body[idx + 1].test)
        if counter.count != 1:
            return node

        # Write down the node to replace (by the If visitor).
        self.name_to_replace = node.targets[0]
        self.value_to_replace = node.value
        return None

    def visit_Name(self, node: ast.Name) -> ast.AST:
        if not (self.name_to_replace and node.id == self.name_to_replace.id):
            return node

        rv = ast.NamedExpr(target=self.name_to_replace, value=self.value_to_replace)
        self.name_to_replace = self.value_to_replace = None
        self.changed = True
        return rv


class NameCounter(ast.NodeVisitor):  # type: ignore
    """Count how many times a given name appears in a subtree."""

    def __init__(self, name: str):
        self.name = name
        self.count = 0
        super().__init__()

    def visit_Name(self, node: ast.Name) -> None:
        if node.id == self.name:
            self.count += 1


class BlanksInserter(ast.NodeTransformer):  # type: ignore
    """
    Restore the missing blank lines in the source (or something similar).
    """

    def generic_visit(self, node: ast.AST) -> ast.AST:
        if isinstance(getattr(node, "body", None), list):
            node.body = self._inject_blanks(node.body)
        super().generic_visit(node)
        return node

    def _inject_blanks(self, body: list[ast.AST]) -> list[ast.AST]:
        if not body:
            return body

        new_body = []
        before = body[0]
        new_body.append(before)
        for i in range(1, len(body)):
            after = body[i]
            nblanks = after.lineno - before.end_lineno - 1
            if nblanks > 0:
                # Inserting one blank is enough.
                blank = ast.Comment(
                    value="",
                    inline=False,
                    lineno=before.end_lineno + 1,
                    end_lineno=before.end_lineno + 1,
                    col_offset=0,
                    end_col_offset=0,
                )
                new_body.append(blank)
            new_body.append(after)
            before = after

        return new_body


def unparse(tree: ast.AST) -> str:
    rv: str = Unparser().visit(tree)
    rv = _fix_comment_on_decorators(rv)
    return rv


def _fix_comment_on_decorators(source: str) -> str:
    """
    Re-associate comments to decorators.

    In a case like:

        1  @deco  # comment
        2  def func(x):
        3      pass

    it seems that the Function lineno is 2 instead of 1 (Python 3.10). Because
    the Comment lineno is 1, it ends up printed above the function, instead
    of inline. This is a problem for '# type: ignore' comments.

    Maybe the problem could be fixed in the tree, but this solution is a
    simpler way to start.
    """
    lines = source.splitlines()

    comment_at = None
    for i, line in enumerate(lines):
        if line.lstrip().startswith("#"):
            comment_at = i
        elif not line.strip():
            pass
        elif line.lstrip().startswith("@classmethod"):
            if comment_at is not None:
                lines[i] = lines[i] + " " + lines[comment_at].lstrip()
                lines[comment_at] = ""
        else:
            comment_at = None

    return "\n".join(lines)


class Unparser(ast._Unparser):  # type: ignore
    """
    Try to emit long strings as multiline.

    The normal class only tries to emit docstrings as multiline,
    but the resulting source doesn't pass flake8.
    """

    # Beware: private method. Tested with Python 3.10 and 3.11.
    def _write_constant(self, value: Any) -> None:
        if isinstance(value, str) and len(value) > 50:
            self._write_str_avoiding_backslashes(value)
        else:
            super()._write_constant(value)


def parse_cmdline() -> Namespace:
    parser = ArgumentParser(
        description=__doc__, formatter_class=RawDescriptionHelpFormatter
    )
    parser.add_argument(
        "-L",
        "--log-level",
        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
        default="INFO",
        help="Logger level.",
    )
    parser.add_argument(
        "inputs",
        metavar="FILE",
        nargs="+",
        type=Path,
        help="the files to process",
    )
    parser.add_argument(
        "-i",
        "--inplace",
        action="store_true",
        default=False,
        help="change the input files in place (instead of dumping to stdout)",
    )

    opt = parser.parse_args()
    return opt


if __name__ == "__main__":
    sys.exit(main())

I think pyupgrade should propose something like this, yes. I am considering opening them a ticket to suggest they take a look at the psycopg refactoring branch once I'm finished with it.
This will be a great addition to pyupgrade.
Also interested in seeing this make its way into pyupgrade.
(Aside: if it did, would there be a way to run this upgrade and only this upgrade? In the past I've not found a way to cherry-pick individual migrations to run with pyupgrade, which has prevented me from using it.)
FYI, I have opened asottile/pyupgrade#1006; however, I also found asottile/pyupgrade#601, which was closed quickly and without a comment.
As you can see from the answers to the above two tickets, the author is not exactly open to the proposal, and I don't care to bother further with their project.
What a weird response...
AFAICT your asottile/pyupgrade#1006 was closed as a duplicate of asottile/pyupgrade#601, but then he commented that 601 was also a duplicate, though I'm not sure of what: I searched for issues mentioning "walrus", "assignment expression", and "PEP-572" hoping to find out more, but I couldn't find any other relevant issues.
And now I can't ask, because both of these issues are marked "Repository owner locked as off topic and limited conversation to collaborators".
I guess I'll continue not using pyupgrade.
I can confirm the maintainer of the project is not the most open to outside proposals.
I personally started using ruff long ago, instead of pyupgrade.
Note: the script only converts a pattern such as:

    var = EXPR
    if PRED:
        CODE

into an equivalent:

    if ASSIGNMENT_PRED:
        CODE

where PRED is an expression containing precisely one reference to var, and ASSIGNMENT_PRED is PRED with that reference replaced by var := EXPR. This is the most common application of a named expression. For example:

    results = yield from execute(self._pgconn)
    if len(results) != 1:

    # ...becomes...

    if len(results := (yield from execute(self._pgconn))) != 1:
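
To make the detection step concrete, here is a minimal sketch using only the stdlib ast module (it is my own illustration, not part of the gist; the helpers count_name and find_candidates are hypothetical names, and it assumes Python 3.10+ for match). It recognizes an assignment to a name that is immediately followed by an if whose test uses that name exactly once, which is the same check the script performs with NameCounter:

    import ast

    def count_name(tree: ast.AST, name: str) -> int:
        """Count how many times `name` appears in a subtree."""
        return sum(
            isinstance(node, ast.Name) and node.id == name
            for node in ast.walk(tree)
        )

    def find_candidates(source: str) -> list[tuple[int, str]]:
        """Return (lineno, name) pairs for `name = EXPR` followed by an if using name once."""
        module = ast.parse(source)
        found = []
        for parent in ast.walk(module):
            body = getattr(parent, "body", None)
            if not isinstance(body, list):
                continue
            for stmt, nxt in zip(body, body[1:]):
                match stmt:
                    case ast.Assign(targets=[ast.Name(id=name)]) if isinstance(nxt, ast.If):
                        if count_name(nxt.test, name) == 1:
                            found.append((stmt.lineno, name))
        return found

    print(find_candidates("x = compute()\nif x is None:\n    raise ValueError()\n"))
    # -> [(1, 'x')]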
Another common case is the loop-and-a-half:

    while True:
        var = EXPR
        if PRED:
            break
        CODE

which can be rewritten as:

    while NEGATED_ASSIGNMENT_PRED:
        CODE

if var appears exactly once in PRED. For example:

    while True:
        data = await self._queue.get()
        if not data:
            break

    # ...becomes...

    while data := (await self._queue.get()):
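
To show the negation step on a case where the test is a comparison (a hypothetical example of mine, not from the gist: stream, BLOCK_SIZE, and sink are made-up names), the operator gets flipped rather than the whole test being wrapped in not, following the rules listed next:

    while True:
        chunk = stream.read(BLOCK_SIZE)
        if len(chunk) == 0:
            break
        sink.write(chunk)

    # ...becomes...

    while len(chunk := stream.read(BLOCK_SIZE)) != 0:
        sink.write(chunk)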
In order to negate the expression, the first of these match rules that applies can be used:

1. not EXPR -> EXPR
2. EXPR1 op EXPR2 -> EXPR1 nop EXPR2, for certain operators that can be listed in a table (e.g. == -> !=, in -> not in, ...)
3. EXPR -> not EXPR
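
A sketch of these three rules as an AST transformation (my own illustration, not part of the gist; it uses only the stdlib ast module, and the negate helper and _FLIP table are hypothetical names):

    import ast

    # Comparison operators and their negated counterparts (rule 2).
    _FLIP = {
        ast.Eq: ast.NotEq, ast.NotEq: ast.Eq,
        ast.Lt: ast.GtE, ast.GtE: ast.Lt,
        ast.Gt: ast.LtE, ast.LtE: ast.Gt,
        ast.In: ast.NotIn, ast.NotIn: ast.In,
        ast.Is: ast.IsNot, ast.IsNot: ast.Is,
    }

    def negate(test: ast.expr) -> ast.expr:
        """Return the logical negation of `test`, applying the first rule that matches."""
        match test:
            # Rule 1: not EXPR -> EXPR
            case ast.UnaryOp(op=ast.Not(), operand=inner):
                return inner
            # Rule 2: flip a single comparison operator, e.g. == -> !=, in -> not in
            case ast.Compare(ops=[op]) if type(op) in _FLIP:
                return ast.Compare(
                    left=test.left, ops=[_FLIP[type(op)]()], comparators=test.comparators
                )
            # Rule 3: anything else -> not EXPR
            case _:
                return ast.UnaryOp(op=ast.Not(), operand=test)

    print(ast.unparse(negate(ast.parse("len(results) == 1", mode="eval").body)))
    # -> len(results) != 1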
In the codebase I wanted to transform, the examples were few enough, and easy enough to find with grep, that I didn't add this transformation to the script and refactored them by hand. However, if anyone wants to play with AST transformations, I'll leave it as an exercise...
Do you think this approach can be used to create a pyupgrade plugin?