Skip to content

Instantly share code, notes, and snippets.

@f0lie
Last active January 21, 2025 06:29
Show Gist options
  • Save f0lie/f1b0383fe06113f2e9756435cfbbc8ed to your computer and use it in GitHub Desktop.
Save f0lie/f1b0383fe06113f2e9756435cfbbc8ed to your computer and use it in GitHub Desktop.
Add UTF-8 to open() while maintaining formatting and comments using libCST
# /// script
# dependencies = [
# "libcst",
# ]
# ///
"""
Author: https://github.com/f0lie
This script adds an explicit UTF-8 encoding to open(...) calls, for both reads and writes.
This is useful for migrating a codebase to work cross platform to handle unicode data because
Windows by default uses CP-1252. The default is supposed to be UTF-16, but Windows is Windows.
libcst has the advantage of being able to maintain formatting and all comments of code files
which is necessary for codebase wide changes.
I still suggest running a formatter before applying this code so the diffs can highlight the
changes properly and for edge cases of formatting.
Run script with `uv run add_explicit_encoding.py` to automagically pull the dependencies.
"""
import libcst as cst
import libcst.matchers as m
from pathlib import Path
# The visitors compare these sets against ``SimpleString.value``, which is the
# literal's source text *including* its quotes, so every mode is stored once per
# quote style. Only text modes are listed: a binary mode must never be given an
# ``encoding`` argument.
READ_MODES = {
    f"{quote}{mode}{quote}"
    for mode in ("r", "rt", "r+")
    for quote in ('"', "'")
}

# Besides write/append this also covers exclusive creation ("x") and the
# explicit text spelling of append ("at"); both open the file in text mode.
WRITE_MODES = {
    f"{quote}{mode}{quote}"
    for mode in ("w", "wt", "w+", "a", "at", "a+", "x", "xt", "x+")
    for quote in ('"', "'")
}
class OpenWriteVisitor(cst.CSTTransformer):
    """Add ``encoding="utf-8"`` to ``open()`` calls that use a text write mode.

    A call is rewritten only when all of the following hold:

    * the callee is the name ``open`` or any attribute named ``open``;
    * the first argument is a plain positional argument (the file name);
    * the second argument is a string literal listed in ``WRITE_MODES``,
      passed positionally or as ``mode=...``;
    * every remaining argument is a keyword argument;
    * no ``encoding=`` keyword is present yet.

    Any other call is returned unchanged.

    NOTE(review): the callee check is purely syntactic, so ``something.open``
    matches whatever ``something`` is. Review the resulting diff for ``open``
    methods that are not the builtin.
    """

    def __init__(self) -> None:
        super().__init__()
        # Built once per visitor instead of once per visited call.
        self._write_open_matcher = self._build_matcher()

    @staticmethod
    def _build_matcher() -> m.Call:
        """Return the matcher describing a rewritable write-mode ``open`` call."""
        # 1) open(...) or <anything>.open(...)
        callee = m.OneOf(
            m.Name("open"),
            m.Attribute(value=m.DoNotCare(), attr=m.Name("open")),
        )
        # 2) Mandatory file name: positional and not an unpacking. Without
        #    star="" a leading `*args` / `**kwargs` would be taken for the file
        #    name, and appending `encoding=` could then duplicate a keyword
        #    hidden inside the unpacked mapping.
        filename_arg = m.Arg(value=m.DoNotCare(), keyword=None, star="")
        # 3) Mandatory write mode, positional or mode="...".
        write_mode = m.SimpleString(m.MatchIfTrue(lambda text: text in WRITE_MODES))
        write_mode_arg = m.Arg(value=write_mode, keyword=None) | m.Arg(
            value=write_mode, keyword=m.Name("mode")
        )
        # 4) Only keyword arguments may follow (no further positionals).
        keyword_tail = m.ZeroOrMore(m.Arg(keyword=m.Name(), value=m.DoNotCare()))
        return m.Call(func=callee, args=[filename_arg, write_mode_arg, keyword_tail])

    def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call:
        """Return ``updated_node`` with ``encoding="utf-8"`` appended when it qualifies."""
        if not m.matches(updated_node, self._write_open_matcher):
            return updated_node

        has_encoding = any(
            isinstance(arg.keyword, cst.Name) and arg.keyword.value == "encoding"
            for arg in updated_node.args
        )
        if has_encoding:
            # An explicit encoding (whatever its value) is left alone.
            return updated_node

        new_encoding_arg = cst.Arg(
            keyword=cst.Name("encoding"),
            value=cst.SimpleString('"utf-8"'),
            # Emit `encoding="utf-8"` without spaces around the `=`.
            equal=cst.AssignEqual(
                whitespace_before=cst.SimpleWhitespace(""),
                whitespace_after=cst.SimpleWhitespace(""),
            ),
        )
        return updated_node.with_changes(args=[*updated_node.args, new_encoding_arg])
class OpenReadVisitor(cst.CSTTransformer):
    """Add ``encoding="utf-8"`` to ``open()`` calls that read text.

    A call is rewritten only when all of the following hold:

    * the callee is the name ``open`` or any attribute named ``open``;
    * the first argument is a plain positional argument;
    * the mode is either absent or a string literal listed in ``READ_MODES``
      (positional second argument or ``mode=...``);
    * every other argument is a keyword argument with a string-literal value;
    * no ``encoding=`` keyword is present yet.

    Calls that would open the file in binary mode are never touched, because
    ``open`` rejects an ``encoding`` in binary mode. That covers
    ``open(p, mode="rb")`` as well as the method form ``path.open("rb")``,
    where the mode is the first positional argument.

    NOTE(review): extra keyword arguments are only accepted when their value
    is a string literal (e.g. ``newline=""``), so ``open(p, buffering=1)`` is
    skipped, whereas ``OpenWriteVisitor`` accepts any value. Confirm whether
    this difference is intended.
    """

    def __init__(self) -> None:
        super().__init__()
        # Built once per visitor instead of once per visited call.
        self._read_open_matcher = self._build_matcher()

    @staticmethod
    def _build_matcher() -> m.Call:
        """Return the matcher describing a rewritable read-mode ``open`` call."""
        # 1) open(...) or <anything>.open(...)
        callee = m.OneOf(
            m.Name("open"),
            m.Attribute(value=m.DoNotCare(), attr=m.Name("open")),
        )
        # 2) Mandatory first argument: positional and not an unpacking.
        #    Without star="" both `open(*args)` and `open(**kwargs)` would
        #    match, although the mode/encoding they carry is unknown.
        filename_arg = m.Arg(value=m.DoNotCare(), keyword=None, star="")
        # 3) Optional read mode, positional or mode="...".
        read_mode = m.SimpleString(m.MatchIfTrue(lambda text: text in READ_MODES))
        optional_read_mode_arg = m.ZeroOrOne(
            m.Arg(value=read_mode, keyword=None)
            | m.Arg(value=read_mode, keyword=m.Name("mode"))
        )
        # 4) Remaining arguments must be keywords. `mode=` is only accepted
        #    with a read mode; otherwise `mode="rb"` / `mode="wb"` would pass
        #    as "some string-valued keyword" and receive an encoding.
        keyword_tail = m.ZeroOrMore(
            m.Arg(
                keyword=m.Name(m.MatchIfTrue(lambda name: name != "mode")),
                value=m.SimpleString(),
            )
            | m.Arg(keyword=m.Name("mode"), value=read_mode)
        )
        return m.Call(
            func=callee,
            args=[filename_arg, optional_read_mode_arg, keyword_tail],
        )

    @staticmethod
    def _is_binary_mode_literal(literal: str) -> bool:
        """Return True if ``literal`` (source text, quotes included) spells a binary mode.

        A binary mode has exactly one of ``r``/``w``/``x``/``a``, a ``b``,
        optionally a ``+``, and no repeated characters.
        """
        mode = literal.strip("\"'")
        return (
            "b" in mode
            and set(mode) <= set("rwxab+")
            and len(set(mode)) == len(mode)
            and sum(char in "rwxa" for char in mode) == 1
        )

    def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call:
        """Return ``updated_node`` with ``encoding="utf-8"`` appended when it qualifies."""
        if not m.matches(updated_node, self._read_open_matcher):
            return updated_node

        # The matcher guarantees a leading positional argument.
        first_arg, *other_args = updated_node.args

        # Method form such as `path.open("rb")`: with no separate mode argument,
        # a first argument that spells a binary mode is treated as the mode.
        has_explicit_mode = any(
            arg.keyword is None or arg.keyword.value == "mode" for arg in other_args
        )
        if (
            isinstance(updated_node.func, cst.Attribute)
            and not has_explicit_mode
            and isinstance(first_arg.value, cst.SimpleString)
            and self._is_binary_mode_literal(first_arg.value.value)
        ):
            return updated_node

        has_encoding = any(
            isinstance(arg.keyword, cst.Name) and arg.keyword.value == "encoding"
            for arg in updated_node.args
        )
        if has_encoding:
            # An explicit encoding (whatever its value) is left alone.
            return updated_node

        new_encoding_arg = cst.Arg(
            keyword=cst.Name("encoding"),
            value=cst.SimpleString('"utf-8"'),
            # Emit `encoding="utf-8"` without spaces around the `=`.
            equal=cst.AssignEqual(
                whitespace_before=cst.SimpleWhitespace(""),
                whitespace_after=cst.SimpleWhitespace(""),
            ),
        )
        return updated_node.with_changes(args=[*updated_node.args, new_encoding_arg])
def update_file(file_path: Path, write_visitor: bool = True) -> None:
    """Rewrite ``file_path`` in place, adding explicit encodings to ``open`` calls.

    Args:
        file_path: Python source file to process; it is read and written as UTF-8.
        write_visitor: When true, apply ``OpenWriteVisitor``; otherwise apply
            ``OpenReadVisitor``.

    The file is only written when the transformed source differs from the
    original. Parse errors, decode errors and I/O errors are reported on
    stdout and the file is skipped.
    """
    try:
        # newline="" turns off newline translation, so the file's own line
        # endings reach the parser unchanged and are written back unchanged.
        # With the default translation every rewritten file would have all of
        # its line endings normalised, hiding the real change in the diff.
        with file_path.open("r", encoding="utf-8", newline="") as source:
            content = source.read()

        # Cheap pre-filter: skip parsing files that cannot contain a call to open.
        if "open(" not in content:
            return

        tree = cst.parse_module(content)
        visitor = OpenWriteVisitor() if write_visitor else OpenReadVisitor()
        updated_content = tree.visit(visitor).code

        if updated_content != content:
            with file_path.open("w", encoding="utf-8", newline="") as target:
                target.write(updated_content)
            print(f"Updated: {file_path}")
    except (cst.ParserSyntaxError, UnicodeDecodeError, IOError) as e:
        print(f"Skipping {file_path}: {e}")
def process_directory(directory: Path, excluded_dirs, write_visitor: bool = True) -> None:
    """Run ``update_file`` on every ``*.py`` file below ``directory``.

    Args:
        directory: Root of the tree to scan recursively.
        excluded_dirs: Iterable of path component names (directory or file
            names). A file is skipped when any component of its path below
            ``directory`` is one of them.
        write_visitor: Forwarded to ``update_file``.
    """
    # Accept any iterable of names (list, tuple, set, ...), not only a set.
    excluded = frozenset(excluded_dirs)
    for file_path in directory.rglob("*.py"):
        # Compare only the part below `directory`: an excluded name that
        # happens to appear in the root's own ancestors must not exclude
        # the whole tree.
        if excluded.isdisjoint(file_path.relative_to(directory).parts):
            update_file(file_path, write_visitor)
def _demo() -> None:
    """Print the result of running ``OpenWriteVisitor`` over a small sample.

    Not run by default; call it manually to see which calls the write
    visitor rewrites and which it leaves alone.
    """
    code = """\
with open("test.txt", "w") as f:
    f.write("Hello, world!")
with open("test.txt", "wb") as f:
    f.write("Hello, world!")
with open("test.txt", "w", encoding="utf-16") as f:
    f.write("Hello, world!")
with open("test.txt", "a") as f:
    f.write("Hello, world!")
with open("test.txt", "a", encoding="utf-16") as f:
    f.write("Hello, world!")
with open("test.txt") as f:
    f.read()
with open("test.txt", "r") as f:
    f.read()
with open("test.txt", "rb") as f:
    f.read()
with open("test.txt", encoding="utf-16") as f:
    f.read()
with file.open("r") as f:
    f.read()
"""
    tree = cst.parse_module(code)
    visitor = OpenWriteVisitor()
    updated_tree = tree.visit(visitor)
    updated_content = updated_tree.code
    print(updated_content)


if __name__ == "__main__":
    # The path below is a placeholder: point it at the project to migrate.
    # write_visitor=False runs the read pass only; run again with
    # write_visitor=True to cover write-mode calls.
    process_directory(
        Path("your_dictionary/"),
        {".venv", "resources", "__init__.py"},
        write_visitor=False,
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment