Skip to content

Instantly share code, notes, and snippets.

@dvarrazzo
Last active March 28, 2025 11:25
Show Gist options
  • Save dvarrazzo/a0e9bd4bee4bb64b543332721de294c0 to your computer and use it in GitHub Desktop.
Save dvarrazzo/a0e9bd4bee4bb64b543332721de294c0 to your computer and use it in GitHub Desktop.
A script to merge assignments followed by `if` statements into a single `if` using the walrus operator
#!/usr/bin/env python
"""Convert a codebase to using the walrus operator.
Hint: in order to explore the AST of a module you can run:
python -m ast path/to/module.py
"""
from __future__ import annotations
import sys
import logging
import subprocess as sp
from typing import Any
from pathlib import Path
from argparse import ArgumentParser, Namespace, RawDescriptionHelpFormatter
import ast_comments as ast # type: ignore
# Parent of the directory containing this script. NOTE(review): not
# referenced elsewhere in this file as shown; possibly a leftover.
PROJECT_DIR = Path(__file__).parent.parent

# Module logger; records propagate to the root handler configured by
# `logging.basicConfig()` in `main()`.
logger = logging.getLogger(__name__)
def main() -> int:
    """
    Script entry point: convert every file given on the command line.

    Return the process exit status (always 0 when no exception is raised).
    """
    options = parse_cmdline()
    logging.basicConfig(
        level=options.log_level, format="%(levelname)s %(message)s"
    )
    for path in options.inputs:
        convert(path, options.inplace)
    return 0
def convert(fp: Path, inplace: bool) -> None:
    """
    Rewrite the Python module at *fp* using assignment expressions.

    The module is parsed keeping its comments, transformed by `Walrusifier`,
    has its blank lines approximately restored by `BlanksInserter` and is
    unparsed back to source.

    :param fp: path of the module to convert.
    :param inplace: if true, overwrite *fp* with the result - only if at
        least one conversion was made - and reformat it with ``black``;
        if false, print the result on stdout and leave *fp* untouched.
    """
    logger.info("reading %s", fp)
    # Python source files are UTF-8 by default (PEP 3120): don't depend on
    # the locale encoding.
    with fp.open(encoding="utf-8") as f:
        source = f.read()

    tree = ast.parse(source, filename=str(fp))
    walrusifier = Walrusifier()
    tree = walrusifier.visit(tree)
    tree = BlanksInserter().visit(tree)
    output = unparse(tree)

    if not inplace:
        print(output)
        return

    if not walrusifier.changed:
        return

    with fp.open("w", encoding="utf-8") as f:
        print(output, file=f)
    # Only reformat a file we have just rewritten: never touch the user's
    # input when dumping to stdout.
    sp.check_call(["black", "-q", str(fp)])
class Walrusifier(ast.NodeTransformer): # type: ignore
def __init__(self) -> None:
super().__init__()
self.changed = False
self.bodies: list[list[ast.AST]] = []
self.name_to_replace: ast.Name | None = None
self.value_to_replace: ast.AST | None = None
def generic_visit(self, node: ast.AST) -> ast.AST:
if not isinstance(getattr(node, "body", None), list):
super().generic_visit(node)
return node
body = node.body[:]
for attr in ("orelse", "finalbody"):
if isinstance((otherbody := getattr(node, attr, None)), list):
# Use a None to separate if and else or there could be a mixup
body += [None] + otherbody[:]
self.bodies.append(body)
super().generic_visit(node)
self.bodies.pop()
return node
def visit_Assign(self, node: ast.Assign) -> ast.AST:
# Check if this is a single assignment to a name.
name: str
match node:
case ast.Assign(targets=[ast.Name(id=str(name))]):
pass
case _:
return node
# Check if the assignment is followed by an if.
if not (self.bodies and (body := self.bodies[-1])):
return node
try:
idx = body.index(node)
except ValueError:
breakpoint()
if not (idx + 1 < len(body) and isinstance(body[idx + 1], ast.If)):
return node
# We can only replace the expression if the name is used exactly once
# in the test.
counter = NameCounter(name)
counter.visit(body[idx + 1].test)
if counter.count != 1:
return node
# Write down the node to replace (by the If visitor)
self.name_to_replace = node.targets[0]
self.value_to_replace = node.value
return None
def visit_Name(self, node: ast.Name) -> ast.AST:
if not (self.name_to_replace and node.id == self.name_to_replace.id):
return node
rv = ast.NamedExpr(target=self.name_to_replace, value=self.value_to_replace)
self.name_to_replace = self.value_to_replace = None
self.changed = True
return rv
class NameCounter(ast.NodeVisitor):  # type: ignore
    """
    Count the `ast.Name` nodes referring to a given identifier.

    Every occurrence is counted, whatever its context (load, store, del).
    Call `visit()` on a tree, then read the total from `count`.
    """

    def __init__(self, name: str):
        super().__init__()
        self.name = name
        self.count = 0

    def visit_Name(self, node: ast.Name) -> None:
        if node.id != self.name:
            return
        self.count += 1
class BlanksInserter(ast.NodeTransformer):  # type: ignore
    """
    Restore the missing spaces in the source (or something similar)

    The unparser doesn't preserve blank lines. This transformer scans every
    statement list (``body``, ``orelse``, ``finalbody``). Wherever the line
    numbers of two consecutive statements show a gap, it inserts a single
    empty `ast.Comment`, presumably rendered by the comment-aware unparser
    as an empty line.
    """

    # Node attributes that may hold a list of statements.
    _BODY_ATTRS = ("body", "orelse", "finalbody")

    def generic_visit(self, node: ast.AST) -> ast.AST:
        for attr in self._BODY_ATTRS:
            # The isinstance check skips non-list fields with the same name,
            # e.g. IfExp.body or Lambda.body.
            if isinstance(stmts := getattr(node, attr, None), list):
                setattr(node, attr, self._inject_blanks(stmts))
        super().generic_visit(node)
        return node

    def _inject_blanks(self, body: list[ast.AST]) -> list[ast.AST]:
        """Return *body* with a blank node wherever statements had a gap."""
        if not body:
            return body

        new_body = [body[0]]
        for before, after in zip(body, body[1:]):
            if after.lineno - before.end_lineno - 1 > 0:
                # Inserting one blank is enough.
                new_body.append(self._make_blank(before.end_lineno + 1))
            new_body.append(after)
        return new_body

    def _make_blank(self, lineno: int) -> ast.AST:
        """Create the node standing for a blank line at *lineno*."""
        return ast.Comment(
            value="",
            inline=False,
            lineno=lineno,
            end_lineno=lineno,
            col_offset=0,
            end_col_offset=0,
        )
def unparse(tree: ast.AST) -> str:
    """
    Convert *tree* back to Python source.

    Uses the custom `Unparser`, then moves comments that ended up above a
    decorator back to the end of the decorator line.
    """
    source: str = Unparser().visit(tree)
    return _fix_comment_on_decorators(source)
def _fix_comment_on_decorators(source: str) -> str:
"""
Re-associate comments to decorators.
In a case like:
1 @deco # comment
2 def func(x):
3 pass
it seems that Function lineno is 2 instead of 1 (Python 3.10). Because
the Comment lineno is 1, it ends up printed above the function, instead
of inline. This is a problem for '# type: ignore' comments.
Maybe the problem could be fixed in the tree, but this solution is a
simpler way to start.
"""
lines = source.splitlines()
comment_at = None
for i, line in enumerate(lines):
if line.lstrip().startswith("#"):
comment_at = i
elif not line.strip():
pass
elif line.lstrip().startswith("@classmethod"):
if comment_at is not None:
lines[i] = lines[i] + " " + lines[comment_at].lstrip()
lines[comment_at] = ""
else:
comment_at = None
return "\n".join(lines)
class Unparser(ast._Unparser):  # type: ignore
    """
    Unparser writing long string constants avoiding backslash escapes.

    The base class writes only docstrings as multi-line strings. Other
    strings are written with their repr, on a single line full of escapes,
    and the resulting source doesn't pass flake8. Strings longer than
    `_LONG_STRING` characters are therefore written the way the base class
    writes strings it wants free of backslashes. That means triple-quoted
    multi-line literals when they contain newlines.
    """

    # Length above which a str constant is written avoiding backslashes.
    _LONG_STRING = 50

    # Beware: overrides and calls private methods of the base class.
    # Tested with Python 3.10 and 3.11.
    def _write_constant(self, value: Any) -> None:
        is_long_str = isinstance(value, str) and len(value) > self._LONG_STRING
        if not is_long_str:
            super()._write_constant(value)
            return
        self._write_str_avoiding_backslashes(value)
def parse_cmdline() -> Namespace:
parser = ArgumentParser(
description=__doc__, formatter_class=RawDescriptionHelpFormatter
)
parser.add_argument(
"-L",
"--log-level",
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
default="INFO",
help="Logger level.",
)
parser.add_argument(
"inputs",
metavar="FILE",
nargs="+",
type=Path,
help="the files to process",
)
parser.add_argument(
"-i",
"--inplace",
action="store_true",
default=False,
help="change the input files inplace (instead of dumping to stdout)",
)
opt = parser.parse_args()
return opt
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())
@pauloxnet
Copy link

Do you think this approach can be used to create a pyupgrade plugin?

@dvarrazzo
Copy link
Author

I think pyupgrade should offer something like this, yes. I am considering opening a ticket there, suggesting they take a look at the psycopg refactoring branch once I have finished with it.

@pauloxnet
Copy link

I am considering opening them a ticket to suggest to take a look at the psycopg refactoring branch once finished with it.

This will be a great addition to pyupgrade. 👍

@asqui
Copy link

asqui commented Mar 27, 2025

Also interested in seeing this make its way into pyupgrade 👍

(Aside: If it did, would there be a way to run this upgrade and only this upgrade? In the past I've not found a way to cherry-pick individual migrations to run with pyupgrade, which has prevented me from using it 😞)

@dvarrazzo
Copy link
Author

FYI, I have opened asottile/pyupgrade#1006; however I also found asottile/pyupgrade#601 which was closed quickly and without a comment.

@dvarrazzo
Copy link
Author

As you can see from the answers to the two tickets above, the author is not exactly open to the proposal, and I don't intend to pursue it further with their project.

@asqui
Copy link

asqui commented Mar 27, 2025

What a weird response...

AFAICT your asottile/pyupgrade#1006 was closed as a duplicate of asottile/pyupgrade#601 but then he commented that 601 was also a duplicate, though I'm not sure of what: I searched for issues mentioning "walrus" "assignment expression" and "PEP-572" hoping to find out more but I couldn't find any other relevant issues.

And now I can't ask, because both of these issues are:

Repository owner locked as off topic and limited conversation to collaborators

I guess I'll continue not using pyupgrade 👍 🤷‍♂️

@pauloxnet
Copy link

What a weird response...

AFAICT your asottile/pyupgrade#1006 was closed as a duplicate of asottile/pyupgrade#601 but then he commented that 601 was also a duplicate, though I'm not sure of what: I searched for issues mentioning "walrus" "assignment expression" and "PEP-572" hoping to find out more but I couldn't find any other relevant issues.

I guess I'll continue not using pyupgrade 👍 🤷‍♂️

I can confirm that the project's maintainer is not the most open to other people's proposals.

I personally started using ruff long ago, instead of pyupgrade.

@dvarrazzo
Copy link
Author

dvarrazzo commented Mar 28, 2025

Note: the script only converts a pattern such as:

var = EXPR
if PRED:
    CODE

into an equivalent:

if ASSIGNMENT_PRED:
    CODE

where PRED is an expression containing precisely one reference to var. This is the most common application of a named expression. For example:

        results = yield from execute(self._pgconn)
        if len(results) != 1:

        # ...becomes...

        if len(results := (yield from execute(self._pgconn))) != 1:

Another common case is the loop-and-a-half:

while True:
    var = EXPR
    if PRED:
        break
    CODE

which can be rewritten as:

while NEGATED_ASSIGNMENT_PRED:
    CODE

if var appears exactly once in PRED. For example:

            while True:
                data = await self._queue.get()
                if not data:
                    break

            # ...becomes...

            while data := (await self._queue.get()):

In order to negate the expression, the first of these match rules can be used:

  • not EXPR -> EXPR
  • EXPR1 op EXPR2 -> EXPR1 nop EXPR2, where nop is the negation of op, for certain operators that can be listed in a table (e.g. == -> !=, in -> not in, ...)
  • EXPR -> not EXPR.

In the codebase I wanted to transform, the examples were few enough and easy to find with grep, so I didn't add this transformation to the script and refactored them by hand. However, if anyone wants to play with AST transformations, I'll leave it as an exercise...

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment