Last active
January 21, 2025 06:29
-
-
Save f0lie/f1b0383fe06113f2e9756435cfbbc8ed to your computer and use it in GitHub Desktop.
Add UTF-8 to open() while maintaining formatting and comments using libCST
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# /// script
# dependencies = [
#   "libcst",
# ]
# ///
"""
Author: https://github.com/f0lie

This script adds an explicit encoding="utf-8" to open(...) calls, for both reads and writes.

This is useful when migrating a codebase to handle Unicode data consistently across
platforms: when no encoding is given, Python uses the locale's preferred encoding, which on
Windows is typically a legacy code page such as CP-1252 rather than UTF-8.

libcst has the advantage of preserving the formatting and all comments of the code files,
which is necessary for codebase-wide changes.

I still suggest running a formatter before applying this script, so that the diffs highlight
only the actual changes and unusual formatting edge cases are normalized first.

Run the script with `uv run add_explicit_encoding.py` to automatically pull in the dependencies.
"""
import argparse
from pathlib import Path

import libcst as cst
import libcst.matchers as m
# Mode literals exactly as they appear in source (quotes included) that select a
# text-mode open. Each base letter is combined with every conventional spelling
# of the optional "t" (text) and "+" (update) flags; unusual orderings such as
# "tr" are not listed and are therefore left alone. Binary modes ("rb", "wb",
# ...) are deliberately absent because they reject an ``encoding`` argument.
_TEXT_MODE_SUFFIXES = ("", "t", "+", "t+", "+t")

READ_MODES = {
    quote
    for suffix in _TEXT_MODE_SUFFIXES
    for quote in (f'"r{suffix}"', f"'r{suffix}'")
}
# "w" (truncate), "a" (append) and "x" (exclusive create) all write text.
WRITE_MODES = {
    quote
    for base in ("w", "a", "x")
    for suffix in _TEXT_MODE_SUFFIXES
    for quote in (f'"{base}{suffix}"', f"'{base}{suffix}'")
}
class _AddUtf8EncodingTransformer(cst.CSTTransformer):
    """Shared logic of OpenWriteVisitor and OpenReadVisitor.

    Appends ``encoding="utf-8"`` to ``open`` calls whose mode is one of the
    string literals in ``MODES``. Two call shapes are recognised; every other
    call is returned unchanged:

    * ``open(file, mode, ...)`` and ``io.open(file, mode, ...)``: the file is the
      first positional argument and the mode is the second positional argument
      or ``mode=...``.
    * ``<expr>.open(mode, ...)``, i.e. the ``pathlib.Path.open`` signature: the
      mode is the first positional argument or ``mode=...``. Treating that first
      argument as a file name instead would turn ``p.open("rb")`` into a binary
      open with an encoding, and would rewrite unrelated APIs such as
      ``gzip.open(path)`` or ``zip_file.open(name)``.

    A call is also left unchanged when it already passes ``encoding=``, uses
    ``*args``/``**kwargs`` (they could carry a binary mode or an encoding), has
    a positional argument after the mode, or passes the mode both positionally
    and by keyword.
    """

    # Mode literals exactly as written in source, quotes included (e.g. '"w"').
    MODES = frozenset()
    # Whether builtin ``open(file)`` with no mode at all (default "r") qualifies.
    MATCH_DEFAULT_MODE = False

    # open(...) / io.open(...): file first, mode second.
    _BUILTIN_OPEN = m.Name("open") | m.Attribute(
        value=m.Name("io"), attr=m.Name("open")
    )
    # Any other <expr>.open(...): mode first, like pathlib.Path.open.
    _METHOD_OPEN = m.Attribute(attr=m.Name("open"))

    def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call:
        if self._should_add_encoding(updated_node):
            return self._append_encoding(updated_node)
        return updated_node

    def _should_add_encoding(self, call: cst.Call) -> bool:
        """Return True if *call* is a recognised text-mode open lacking an encoding."""
        if m.matches(call.func, self._BUILTIN_OPEN):
            mode_index = 1
        elif m.matches(call.func, self._METHOD_OPEN):
            mode_index = 0
        else:
            return False

        positional = []
        keyword_mode = None
        for arg in call.args:
            if arg.star:
                # *args / **kwargs could carry a binary mode or an encoding.
                return False
            if arg.keyword is None:
                positional.append(arg.value)
            elif arg.keyword.value == "encoding":
                return False  # never override an explicit encoding
            elif arg.keyword.value == "mode":
                keyword_mode = arg.value

        if mode_index == 1 and not positional:
            return False  # open(file=...): file not positional, left alone as before
        if len(positional) > mode_index + 1:
            return False  # positional buffering/encoding/... after the mode
        if len(positional) == mode_index + 1:
            if keyword_mode is not None:
                return False  # mode given twice: invalid code, don't touch it
            mode = positional[mode_index]
        else:
            mode = keyword_mode

        if mode is None:
            # No mode at all means the default "r". Only trusted for the builtin:
            # a bare ``x.open()`` may belong to an unrelated API.
            return self.MATCH_DEFAULT_MODE and mode_index == 1
        # Only plain literals are checked, so mode="rb" or a mode variable is skipped.
        return isinstance(mode, cst.SimpleString) and mode.value in self.MODES

    @staticmethod
    def _append_encoding(call: cst.Call) -> cst.Call:
        """Return *call* with ``encoding="utf-8"`` appended as its last argument.

        The new argument inherits the old last argument's trailing comma and
        whitespace. The old last argument receives a separator modelled on the
        existing layout. As a result, single-line calls stay on one line, and
        one-argument-per-line calls gain one more indented line rather than a
        stray leading comma before the closing parenthesis.
        """
        encoding_arg = cst.Arg(
            keyword=cst.Name("encoding"),
            value=cst.SimpleString('"utf-8"'),
            equal=cst.AssignEqual(
                whitespace_before=cst.SimpleWhitespace(""),
                whitespace_after=cst.SimpleWhitespace(""),
            ),
        )
        if not call.args:  # not reachable from leave_Call; kept for safety
            return call.with_changes(args=[encoding_arg])

        *leading, last = call.args
        if leading and isinstance(leading[-1].comma, cst.Comma):
            layout = leading[-1].comma.whitespace_after
        else:
            layout = call.whitespace_before_args
        if isinstance(layout, cst.ParenthesizedWhitespace):
            # Multi-line call: newline plus the same indentation, without
            # copying any comment or blank lines attached to the reference.
            separator = cst.Comma(
                whitespace_after=cst.ParenthesizedWhitespace(
                    indent=layout.indent,
                    last_line=cst.SimpleWhitespace(layout.last_line.value),
                )
            )
        else:
            separator = cst.MaybeSentinel.DEFAULT  # rendered as ", "

        new_last = last.with_changes(
            comma=separator, whitespace_after_arg=cst.SimpleWhitespace("")
        )
        encoding_arg = encoding_arg.with_changes(
            comma=last.comma, whitespace_after_arg=last.whitespace_after_arg
        )
        return call.with_changes(args=[*leading, new_last, encoding_arg])


class OpenWriteVisitor(_AddUtf8EncodingTransformer):
    """Adds encoding="utf-8" to open() calls with text write modes (see WRITE_MODES)."""

    MODES = WRITE_MODES


class OpenReadVisitor(_AddUtf8EncodingTransformer):
    """Adds encoding="utf-8" to open() calls with text read modes (see READ_MODES).

    A builtin ``open(file)`` with no mode argument is included too, since its
    mode defaults to "r".
    """

    MODES = READ_MODES
    MATCH_DEFAULT_MODE = True
def update_file(file_path: Path, write_visitor: bool = True):
    """Rewrite *file_path* in place, adding ``encoding="utf-8"`` to matching opens.

    Args:
        file_path: Python source file to transform.
        write_visitor: True applies OpenWriteVisitor (write modes); False applies
            OpenReadVisitor (read modes).

    The file is decoded from and re-encoded to raw UTF-8 bytes, so its line
    endings are preserved exactly. Files that are not valid UTF-8, fail to
    parse, or cannot be read or written are reported and skipped.
    """
    try:
        # Work on raw bytes: read_text() translates CRLF to "\n" and write_text()
        # writes os.linesep, which would silently change the line endings of
        # every file touched (CRLF->LF on POSIX, LF->CRLF on Windows).
        content = file_path.read_bytes().decode("utf-8")
        # Cheap pre-filter: skip parsing files with no open( call at all.
        if "open(" not in content:
            return
        tree = cst.parse_module(content)
        visitor = OpenWriteVisitor() if write_visitor else OpenReadVisitor()
        updated_content = tree.visit(visitor).code
        if updated_content != content:
            file_path.write_bytes(updated_content.encode("utf-8"))
            print(f"Updated: {file_path}")
    except (cst.ParserSyntaxError, UnicodeDecodeError, OSError) as e:
        print(f"Skipping {file_path}: {e}")
def process_directory(directory: Path, excluded_dirs=None, write_visitor: bool = True):
    """Apply update_file() to every ``*.py`` file under *directory*, recursively.

    Args:
        directory: Root directory to search.
        excluded_dirs: Collection of path-component names to skip. A file is
            skipped when any component of its path *relative to directory* is in
            this collection, so it may hold directory names (".venv") as well as
            file names ("__init__.py"). None excludes nothing.
        write_visitor: Passed through to update_file().
    """
    excluded = set(excluded_dirs or ())
    # sorted() gives a deterministic processing (and log) order.
    for file_path in sorted(directory.rglob("*.py")):
        # Only look below `directory`: checking the absolute path would skip the
        # whole tree if the project itself lived under e.g. a "resources" folder.
        if excluded.isdisjoint(file_path.relative_to(directory).parts):
            update_file(file_path, write_visitor)
if __name__ == "__main__": | |
process_directory( | |
Path("your_dictionary/"), | |
{".venv", "resources", "__init__.py"}, | |
write_visitor=False, | |
) | |
exit() | |
code = """\ | |
with open("test.txt", "w") as f: | |
f.write("Hello, world!") | |
with open("test.txt", "wb") as f: | |
f.write("Hello, world!") | |
with open("test.txt", "w", encoding="utf-16") as f: | |
f.write("Hello, world!") | |
with open("test.txt", "a") as f: | |
f.write("Hello, world!") | |
with open("test.txt", "a", encoding="utf-16") as f: | |
f.write("Hello, world!") | |
with open("test.txt") as f: | |
f.read() | |
with open("test.txt", "r") as f: | |
f.read() | |
with open("test.txt", "rb") as f: | |
f.read() | |
with open("test.txt", encoding="utf-16") as f: | |
f.read() | |
with file.open("r") as f: | |
f.read() | |
""" | |
tree = cst.parse_module(code) | |
visitor = OpenWriteVisitor() | |
updated_tree = tree.visit(visitor) | |
updated_content = updated_tree.code | |
print(updated_content) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment