Created
September 24, 2024 11:31
-
-
Save harupy/6a4451cd6ec1ef99d4424bdba0c59393 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ast | |
from pathlib import Path | |
import subprocess | |
from typing import List, Dict | |
from dataclasses import dataclass | |
def line_no_to_offset(lines: List[str]) -> Dict[int, int]:
    """Map each 1-based line number to the index where that line starts.

    ``lines`` is expected to come from ``text.split("\\n")``: every line is
    assumed to have been followed by exactly one newline character, which is
    why one extra position is counted per line.
    """
    starts: Dict[int, int] = {}
    cursor = 0
    for number, text in enumerate(lines, start=1):
        starts[number] = cursor
        # +1 accounts for the newline removed by split().
        cursor += len(text) + 1
    return starts
@dataclass
class Fix:
    """One textual edit to apply to a source string.

    The span ``src[start:end]`` of the text the offsets were computed against
    is to be replaced by ``replacement``.
    """

    # Index where the replaced span begins (inclusive).
    start: int
    # Index where the replaced span ends (exclusive).
    end: int
    # Size of the replaced span; every construction in this script sets it
    # to ``end - start``.
    length: int
    # Text substituted for the span.
    replacement: str
    # The difference in length between the original and replacement strings
    diff: int
class Visitor(ast.NodeVisitor):
    """Collect and apply source rewrites for the ``artifact_path`` -> ``name`` migration.

    Usage: construct with a file path (the file is read immediately), call
    ``visit()`` on a tree parsed from ``self.src``, then call ``fix()`` to
    write the rewritten text back to the same path.
    """

    # Keyword arguments appended to each `Model.log(...)` call.
    _CALL_EXTRA_KWARGS = (
        "params=params",
        "tags=tags",
        "model_type=model_type",
        "step=step",
        "model_id=model_id",
    )
    # Parameters appended to module-level `log_model` definitions
    # (only when ``rewrite_signatures`` is enabled).
    _DEF_EXTRA_PARAMS = (
        "params: Optional[Dict[str, Any]] = None",
        "tags: Optional[Dict[str, Any]] = None",
        "model_type: Optional[str] = None",
        "step: int = 0",
        "model_id: Optional[str] = None",
    )

    def __init__(self, path: Path, rewrite_signatures: bool = False):
        """Read ``path`` and prepare position lookup tables.

        Args:
            path: Python source file to analyse and, via ``fix()``, rewrite.
            rewrite_signatures: Also rewrite module-level ``log_model``
                definitions. Off by default; when False, ``visit_FunctionDef``
                only recurses into the function body.
        """
        self.path = path
        # Python source files are UTF-8 by default (PEP 3120); do not depend on
        # the locale's preferred encoding.
        self.src = path.read_text(encoding="utf-8")
        self._lines = self.src.split("\n")
        # Despite its name, this maps 1-based line numbers to the index at
        # which each line starts in ``self.src``.
        self.col_offset = line_no_to_offset(self._lines)
        self.rewrite_signatures = rewrite_signatures
        self.fixes: List[Fix] = []
        self.classes: List[ast.ClassDef] = []

    # -- scope tracking ----------------------------------------------------

    def visit_ClassDef(self, node: ast.ClassDef):
        # Keep a stack of enclosing classes so `in_class` can tell methods
        # apart from module-level functions.
        self.classes.append(node)
        self.generic_visit(node)
        self.classes.pop()

    def in_class(self, node: ast.AST) -> bool:
        """Return True while visiting inside any class body (``node`` is unused)."""
        return bool(self.classes)

    def report(self, node: ast.AST):
        """Print ``path:line:col`` for the start and the end of ``node``."""
        print(f"{self.path}:{node.lineno}:{node.col_offset}")
        print(f"{self.path}:{node.end_lineno}:{node.end_col_offset}")

    # -- position helpers --------------------------------------------------

    def _offset(self, lineno: int, col: int) -> int:
        """Convert an AST (line, column) position to an index into ``self.src``.

        AST column offsets count UTF-8 *bytes*, while ``self.src`` is indexed
        by characters; the two differ on lines containing non-ASCII text.
        """
        prefix = self._lines[lineno - 1].encode("utf-8")[:col]
        return self.col_offset[lineno] + len(prefix.decode("utf-8"))

    def _start(self, node: ast.AST) -> int:
        """Index in ``self.src`` where ``node`` begins."""
        return self._offset(node.lineno, node.col_offset)

    def _end(self, node: ast.AST) -> int:
        """Index in ``self.src`` just past ``node`` (which may span several lines)."""
        return self._offset(node.end_lineno, node.end_col_offset)

    def _add_fix(self, start: int, end: int, replacement: str) -> None:
        """Record replacing ``self.src[start:end]`` with ``replacement``."""
        self.fixes.append(
            Fix(
                start=start,
                end=end,
                length=end - start,
                replacement=replacement,
                diff=len(replacement) - (end - start),
            )
        )

    def _append_after(self, pos: int, items) -> None:
        """Insert ``items`` as extra comma-terminated entries after index ``pos``.

        ``pos`` is the end of the last existing argument/parameter. If a comma
        already follows it (possibly after whitespace) the new entries go after
        that comma; otherwise a separating comma is inserted first. The edit is
        a zero-width insertion, so it never overlaps a fix on the preceding
        argument.
        """
        i = pos
        while i < len(self.src) and self.src[i] in " \t\n":
            i += 1
        if i < len(self.src) and self.src[i] == ",":
            pos, lead = i + 1, ""
        else:
            lead = ","
        text = lead + "\n" + "".join(item + ",\n" for item in items)
        self._add_fix(pos, pos, text)

    # -- Model.log(...) calls ----------------------------------------------

    def visit_Call(self, node: ast.Call):
        """
        Find `artifact_path` in `Model.log` calls and migrate them: rename it
        to `name=name` and append the new keyword arguments.
        """
        if self._is_model_log(node):
            self._rewrite_model_log_call(node)
        self.generic_visit(node)

    @staticmethod
    def _is_model_log(node: ast.Call) -> bool:
        """True for calls spelled literally as ``Model.log(...)``."""
        func = node.func
        return (
            isinstance(func, ast.Attribute)
            and func.attr == "log"
            and isinstance(func.value, ast.Name)
            and func.value.id == "Model"
        )

    def _rewrite_model_log_call(self, node: ast.Call) -> None:
        """Record the fixes for a single ``Model.log(...)`` call."""
        for kw in node.keywords:
            if kw.arg == "artifact_path":
                self.report(kw)
                self._add_fix(self._start(kw), self._end(kw), "name=name")

        # `arg` is None for `**kwargs` entries; those cannot anchor new kwargs.
        named = [kw for kw in node.keywords if kw.arg]
        # Skip calls with nothing to anchor to, and calls already migrated
        # (re-adding `model_id=` would be a duplicate-keyword SyntaxError).
        if not named or any(kw.arg == "model_id" for kw in named):
            return
        self._append_after(self._end(named[-1].value), self._CALL_EXTRA_KWARGS)

    # -- log_model(...) definitions ----------------------------------------

    def visit_FunctionDef(self, node: ast.FunctionDef):
        """Optionally migrate module-level ``log_model`` signatures.

        Does nothing beyond recursing unless ``rewrite_signatures`` was
        enabled in the constructor.
        """
        if (
            self.rewrite_signatures
            and not self.in_class(node)
            and node.name == "log_model"
        ):
            self._rewrite_log_model_def(node)
        self.generic_visit(node)

    def _rewrite_log_model_def(self, node: ast.FunctionDef) -> None:
        """Rename ``artifact_path`` to ``name`` and append the new parameters."""
        args = node.args.args
        defaults = node.args.defaults
        # `defaults` is right-aligned: it belongs to the *last* len(defaults)
        # positional parameters.
        first_default = len(args) - len(defaults)
        for i, a in enumerate(args):
            if a.arg == "artifact_path":
                self.report(a)
                # The `arg` node spans `name: annotation` but not `= default`,
                # so only add a default when the parameter has none.
                if i >= first_default:
                    replacement = "name: Optional[str]"
                else:
                    replacement = "name: Optional[str]=None"
                self._add_fix(self._start(a), self._end(a), replacement)

        all_args = node.args.posonlyargs + args + node.args.kwonlyargs
        if not args or any(a.arg == "model_id" for a in all_args):
            return
        # If any defaults exist, the last one belongs to the last positional
        # parameter; anchor after it, otherwise after the parameter itself.
        anchor = defaults[-1] if defaults else args[-1]
        self._append_after(self._end(anchor), self._DEF_EXTRA_PARAMS)

    # -- output --------------------------------------------------------------

    def fix(self) -> str:
        """Apply all recorded fixes to ``self.src`` and write it to ``self.path``.

        Fixes are applied from the end of the text backwards, so each edit
        leaves the offsets of the not-yet-applied (earlier) ones valid, no
        matter in which order they were recorded.

        Returns:
            The rewritten source text.

        Raises:
            ValueError: if two fixes overlap (applying both would corrupt code).
        """
        ordered = sorted(self.fixes, key=lambda f: (f.start, f.end))
        for prev, cur in zip(ordered, ordered[1:]):
            if cur.start < prev.end:
                raise ValueError(f"{self.path}: overlapping fixes {prev} and {cur}")
        src = self.src
        for f in reversed(ordered):
            src = src[: f.start] + f.replacement + src[f.end :]
        self.path.write_text(src, encoding="utf-8")
        return src
# Rewrite every git-tracked Python file under mlflow/ in place.
# `-z` makes git emit NUL-terminated, unquoted paths, so file names containing
# spaces or non-ASCII characters survive intact (a plain `.split()` breaks
# on whitespace).
for p in map(
    Path,
    filter(
        None,  # drop the empty string after the final NUL terminator
        subprocess.check_output(["git", "ls-files", "-z", "mlflow/*.py"])
        .decode()
        .split("\0"),
    ),
):
    v = Visitor(p)
    # Parse the same text the visitor computes offsets from (`v.src`) instead
    # of re-reading the file, so AST positions and fix offsets always agree.
    v.visit(ast.parse(v.src, filename=str(p)))
    if v.fixes:
        v.fix()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment