Last active
March 28, 2025 11:25
-
-
Save dvarrazzo/a0e9bd4bee4bb64b543332721de294c0 to your computer and use it in GitHub Desktop.
A script to convert assignments followed by `if` statements into assignment expressions using the walrus operator
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
"""Convert a codebase to using the walrus operator. | |
Hint: in order to explore the AST of a module you can run: | |
python -m ast path/to/module.py | |
""" | |
from __future__ import annotations | |
import sys | |
import logging | |
import subprocess as sp | |
from typing import Any | |
from pathlib import Path | |
from argparse import ArgumentParser, Namespace, RawDescriptionHelpFormatter | |
import ast_comments as ast # type: ignore | |
# Parent of the directory that contains this script.
PROJECT_DIR = Path(__file__).parent.parent

# Root logger (getLogger without a name); main() configures it via basicConfig.
logger = logging.getLogger(name=None)
def main() -> int:
    """Script entry point.

    Parse the command line, configure logging at the requested level and
    convert every input file in turn.

    :return: the process exit status (always 0 when no exception escapes).
    """
    options = parse_cmdline()
    logging.basicConfig(format="%(levelname)s %(message)s", level=options.log_level)

    for path in options.inputs:
        convert(path, options.inplace)

    return 0
def convert(fp: Path, inplace: bool) -> None:
    """Rewrite one Python source file to use the walrus operator.

    :param fp: path of the file to convert.
    :param inplace: if true, overwrite *fp* with the converted source (only
        when at least one conversion happened) and reformat it with ``black``;
        otherwise print the converted source to stdout and leave *fp* on disk
        untouched.
    :raises subprocess.CalledProcessError: if ``black`` exits with an error.
    """
    logger.info("reading %s", fp)
    # Python source files are UTF-8 by default (PEP 3120): don't depend on
    # the locale encoding a bare open() would use.
    source = fp.read_text(encoding="utf-8")

    tree = ast.parse(source, filename=str(fp))
    walrusifier = Walrusifier()
    tree = walrusifier.visit(tree)
    tree = BlanksInserter().visit(tree)
    output = unparse(tree)

    if not inplace:
        print(output)
        return

    if not walrusifier.changed:
        logger.debug("no change in %s", fp)
        return

    with fp.open("w", encoding="utf-8") as f:
        print(output, file=f)

    # Only reformat a file we have just rewritten: running black in stdout
    # mode, or on files we didn't touch, would silently modify user sources.
    sp.check_call(["black", "-q", str(fp)])
class Walrusifier(ast.NodeTransformer): # type: ignore | |
def __init__(self) -> None: | |
super().__init__() | |
self.changed = False | |
self.bodies: list[list[ast.AST]] = [] | |
self.name_to_replace: ast.Name | None = None | |
self.value_to_replace: ast.AST | None = None | |
def generic_visit(self, node: ast.AST) -> ast.AST: | |
if not isinstance(getattr(node, "body", None), list): | |
super().generic_visit(node) | |
return node | |
body = node.body[:] | |
for attr in ("orelse", "finalbody"): | |
if isinstance((otherbody := getattr(node, attr, None)), list): | |
# Use a None to separate if and else or there could be a mixup | |
body += [None] + otherbody[:] | |
self.bodies.append(body) | |
super().generic_visit(node) | |
self.bodies.pop() | |
return node | |
def visit_Assign(self, node: ast.Assign) -> ast.AST: | |
# Check if this is a single assignment to a name. | |
name: str | |
match node: | |
case ast.Assign(targets=[ast.Name(id=str(name))]): | |
pass | |
case _: | |
return node | |
# Check if the assignment is followed by an if. | |
if not (self.bodies and (body := self.bodies[-1])): | |
return node | |
try: | |
idx = body.index(node) | |
except ValueError: | |
breakpoint() | |
if not (idx + 1 < len(body) and isinstance(body[idx + 1], ast.If)): | |
return node | |
# We can only replace the expression if the name is used exactly once | |
# in the test. | |
counter = NameCounter(name) | |
counter.visit(body[idx + 1].test) | |
if counter.count != 1: | |
return node | |
# Write down the node to replace (by the If visitor) | |
self.name_to_replace = node.targets[0] | |
self.value_to_replace = node.value | |
return None | |
def visit_Name(self, node: ast.Name) -> ast.AST: | |
if not (self.name_to_replace and node.id == self.name_to_replace.id): | |
return node | |
rv = ast.NamedExpr(target=self.name_to_replace, value=self.value_to_replace) | |
self.name_to_replace = self.value_to_replace = None | |
self.changed = True | |
return rv | |
class NameCounter(ast.NodeVisitor):  # type: ignore
    """Count the ``ast.Name`` nodes that refer to a given identifier.

    After ``visit(tree)``, ``count`` is the number of Name nodes in *tree*
    whose ``id`` equals *name*, regardless of their context (load/store/del).
    """

    def __init__(self, name: str):
        super().__init__()
        self.count = 0
        self.name = name

    def visit_Name(self, node: ast.Name) -> None:
        # A match adds one; no need to descend into the Name's children.
        self.count += int(node.id == self.name)
class BlanksInserter(ast.NodeTransformer):  # type: ignore
    """
    Restore the missing spaces in the source (or something similar).

    Blank lines are not represented in the tree. Wherever two consecutive
    statements of a block were separated by at least one line in the original
    source (judging from their ``lineno``/``end_lineno``), insert a single
    empty ``ast.Comment`` between them.
    """

    # Attributes holding statement lists: handle ``else:``/``finally:`` blocks
    # too, not only the main body.
    _BLOCK_ATTRS = ("body", "orelse", "finalbody")

    def generic_visit(self, node: ast.AST) -> ast.AST:
        for attr in self._BLOCK_ATTRS:
            stmts = getattr(node, attr, None)
            # Some nodes (e.g. IfExp) have non-list ``body``/``orelse`` fields.
            if isinstance(stmts, list):
                setattr(node, attr, self._inject_blanks(stmts))
        super().generic_visit(node)
        return node

    def _inject_blanks(self, body: list[ast.AST]) -> list[ast.AST]:
        """Return *body* with an empty comment between statements that were
        separated by blank lines in the source."""
        if not body:
            return body

        new_body = [body[0]]
        for before, after in zip(body, body[1:]):
            nblanks = after.lineno - before.end_lineno - 1
            if nblanks > 0:
                # Inserting one blank is enough.
                blank_line = before.end_lineno + 1
                new_body.append(
                    ast.Comment(
                        value="",
                        inline=False,
                        lineno=blank_line,
                        end_lineno=blank_line,
                        col_offset=0,
                        end_col_offset=0,
                    )
                )
            new_body.append(after)
        return new_body
def unparse(tree: ast.AST) -> str:
    """Turn *tree* back into Python source text.

    The tree is rendered by the custom Unparser. Comments that end up on the
    line above an ``@classmethod`` decorator are then moved back onto it by
    _fix_comment_on_decorators.
    """
    source: str = Unparser().visit(tree)
    return _fix_comment_on_decorators(source)
def _fix_comment_on_decorators(source: str) -> str: | |
""" | |
Re-associate comments to decorators. | |
In a case like: | |
1 @deco # comment | |
2 def func(x): | |
3 pass | |
it seems that Function lineno is 2 instead of 1 (Python 3.10). Because | |
the Comment lineno is 1, it ends up printed above the function, instead | |
of inline. This is a problem for '# type: ignore' comments. | |
Maybe the problem could be fixed in the tree, but this solution is a | |
simpler way to start. | |
""" | |
lines = source.splitlines() | |
comment_at = None | |
for i, line in enumerate(lines): | |
if line.lstrip().startswith("#"): | |
comment_at = i | |
elif not line.strip(): | |
pass | |
elif line.lstrip().startswith("@classmethod"): | |
if comment_at is not None: | |
lines[i] = lines[i] + " " + lines[comment_at].lstrip() | |
lines[comment_at] = "" | |
else: | |
comment_at = None | |
return "\n".join(lines) | |
class Unparser(ast._Unparser):  # type: ignore
    """
    Unparser that tries to emit long string constants as multiline literals.

    The stock unparser only tries to emit docstrings as multiline; with other
    long strings the resulting source doesn't pass flake8. Strings longer
    than 50 characters are therefore written with the backslash-avoiding
    writer; every other constant goes through the default path.
    """

    # Beware: overrides a private method. Tested with Python 3.10, 3.11.
    def _write_constant(self, value: Any) -> None:
        is_long_str = isinstance(value, str) and len(value) > 50
        if not is_long_str:
            super()._write_constant(value)
            return
        self._write_str_avoiding_backslashes(value)
def parse_cmdline() -> Namespace:
    """Parse the script's command line (``sys.argv``).

    :return: a Namespace with ``log_level`` (one of the logging level names,
        default "INFO"), ``inputs`` (list of Path, at least one) and
        ``inplace`` (bool, default False).
    """
    parser = ArgumentParser(
        formatter_class=RawDescriptionHelpFormatter,
        description=__doc__,
    )

    level_names = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
    parser.add_argument(
        "-L", "--log-level", choices=level_names, default="INFO", help="Logger level."
    )
    parser.add_argument(
        "inputs", nargs="+", type=Path, metavar="FILE", help="the files to process"
    )
    parser.add_argument(
        "-i",
        "--inplace",
        default=False,
        action="store_true",
        help="change the input files inplace (instead of dumping to stdout)",
    )

    return parser.parse_args()
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())
Note: the script only converts a pattern such as:
var = EXPR
if PRED:
CODE
into an equivalent:
if ASSIGNMENT_PRED:
CODE
where `PRED` is an expression containing precisely one reference to `var`, and `ASSIGNMENT_PRED` is `PRED` with that reference replaced by `(var := EXPR)`. This is the most common application of a named expression. For example:
results = yield from execute(self._pgconn)
if len(results) != 1:
# ...becomes...
if len(results := (yield from execute(self._pgconn))) != 1:
Another common case is the loop-and-a-half:
while True:
var = EXPR
if PRED:
break
CODE
which can be rewritten as:
while NEGATED_ASSIGNMENT_PRED:
CODE
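if `var` appears exactly once in `PRED` (`NEGATED_ASSIGNMENT_PRED` being the negation of `PRED`, with `var` replaced by `(var := EXPR)`). For example: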
while True:
data = await self._queue.get()
if not data:
break
# ...becomes...
while data := (await self._queue.get()):
In order to negate the expression, the first of these rules that matches can be used:
- `not EXPR` → `EXPR`
- `EXPR1 op EXPR2` → `EXPR1 nop EXPR2`, for certain operators whose negation can be listed in a table (e.g. `==` → `!=`, `in` → `not in`, ...)
- `EXPR` → `not EXPR`
In the codebase I wanted to transform, these cases were few enough, and easy enough to find with grep, that I didn't add this transformation to the script and refactored them by hand. However, if anyone wants to play with AST transformations, I'll leave it as an exercise...
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
I can confirm that the project's maintainer is not the most open to other proposals.
I personally started using ruff long ago, instead of pyupgrade.