Skip to content

Instantly share code, notes, and snippets.

@harupy
Created September 24, 2024 11:31
Show Gist options
  • Save harupy/6a4451cd6ec1ef99d4424bdba0c59393 to your computer and use it in GitHub Desktop.
Save harupy/6a4451cd6ec1ef99d4424bdba0c59393 to your computer and use it in GitHub Desktop.
import ast
from pathlib import Path
import subprocess
from typing import List, Dict
from dataclasses import dataclass
def line_no_to_offset(lines: List[str]) -> Dict[int, int]:
    """Map each 1-based line number to the offset at which that line starts.

    The start of a line is the combined length of every preceding line plus
    one separator character per preceding line.

    Args:
        lines: The lines of a text, without their separators.

    Returns:
        A dict keyed by 1-based line number; empty when ``lines`` is empty.
    """
    starts: Dict[int, int] = {}
    position = 0
    for number, text in enumerate(lines, start=1):
        starts[number] = position
        # Step over the line itself and the single separator that follows it.
        position += len(text) + 1
    return starts
@dataclass
class Fix:
    """One pending text replacement, addressed by offsets into a source string."""

    # Offset of the first character of the replaced range.
    start: int
    # Offset just past the replaced range (used as an exclusive slice bound).
    end: int
    # Size of the replaced range; producers in this file pass ``end - start``.
    length: int
    # Text inserted in place of the replaced range.
    replacement: str
    # How much the text grows (positive) or shrinks (negative) when the
    # replacement is applied: ``len(replacement) - length``.
    diff: int
class Visitor(ast.NodeVisitor):
    """Collect and apply the text edits that migrate one Python source file.

    Visiting a parsed module records ``Fix`` objects in ``self.fixes``:

    * in every ``Model.log(...)`` call, the ``artifact_path=...`` keyword is
      replaced by ``name=name`` and the lines in ``EXTRA_CALL_KWARGS`` are
      appended after the last named keyword;
    * optionally (see ``REWRITE_LOG_MODEL_SIGNATURES``), module-level
      ``log_model`` definitions get the matching signature change.

    ``fix()`` then applies the recorded edits and writes the file back.
    """

    # Lines appended after the last named keyword of a ``Model.log`` call.
    EXTRA_CALL_KWARGS = (
        "params=params,",
        "tags=tags,",
        "model_type=model_type,",
        "step=step,",
        "model_id=model_id,",
    )
    # Lines appended after the last parameter of a ``log_model`` definition.
    EXTRA_DEF_PARAMS = (
        "params: Optional[Dict[str, Any]] = None,",
        "tags: Optional[Dict[str, Any]] = None,",
        "model_type: Optional[str] = None,",
        "step: int = 0,",
        "model_id: Optional[str] = None,",
    )
    # The signature rewrite was switched off in the original script by an
    # unconditional early return in ``visit_FunctionDef``.  It stays off by
    # default; set this to True to run it.
    REWRITE_LOG_MODEL_SIGNATURES = False

    def __init__(self, path: Path):
        """Load ``path`` and prepare the line-start table used for offsets."""
        self.path = path
        # Python source is UTF-8 unless declared otherwise, so do not depend
        # on the platform's default encoding.
        self.src = path.read_text(encoding="utf-8")
        self.lines = self.src.split("\n")
        self.col_offset = line_no_to_offset(self.lines)
        self.fixes: List[Fix] = []
        # Stack of the class definitions currently being visited.
        self.classes: List[ast.ClassDef] = []

    def visit_ClassDef(self, node: ast.ClassDef):
        """Track class nesting while the class body is visited."""
        self.classes.append(node)
        self.generic_visit(node)
        self.classes.pop()

    def in_class(self, node: ast.AST) -> bool:
        """Return True while the traversal is inside at least one class body."""
        return bool(self.classes)

    def report(self, node: ast.AST):
        """Print the start and end position of ``node`` as ``path:line:col``."""
        print(f"{self.path}:{node.lineno}:{node.col_offset}")
        print(f"{self.path}:{node.end_lineno}:{node.end_col_offset}")

    def _offset(self, lineno: int, byte_col: int) -> int:
        """Convert an ``ast`` position into an index into ``self.src``.

        ``ast`` reports columns as UTF-8 byte offsets, whereas ``self.src`` is
        indexed by character, so the column is converted before it is added
        to the character offset of the line start.
        """
        line = self.lines[lineno - 1]
        if line.isascii():
            char_col = byte_col
        else:
            char_col = len(line.encode("utf-8")[:byte_col].decode("utf-8"))
        return self.col_offset[lineno] + char_col

    def _span(self, node: ast.AST):
        """Return the ``(start, end)`` indices of ``node`` in ``self.src``.

        The end uses ``end_lineno`` so that nodes spanning several lines are
        measured correctly.
        """
        start = self._offset(node.lineno, node.col_offset)
        end = self._offset(node.end_lineno, node.end_col_offset)
        return start, end

    def _add_fix(self, start: int, end: int, text: str, extra_lines=None):
        """Queue the replacement of ``self.src[start:end]``.

        Without ``extra_lines`` the range is replaced by ``text``.  With
        ``extra_lines`` the replacement is ``text`` on a line of its own,
        followed by a comma and then one extra line each; a comma directly
        after the range is consumed so it is not duplicated.
        """
        replacement = text
        if extra_lines is not None:
            if self.src[end : end + 1] == ",":
                end += 1
            replacement = "\n" + text + ",\n" + "\n".join(extra_lines) + "\n"
        length = end - start
        self.fixes.append(
            Fix(
                start=start,
                end=end,
                length=length,
                replacement=replacement,
                diff=len(replacement) - length,
            )
        )

    def visit_Call(self, node: ast.Call):
        """
        Find `artifact_path` in `Model.log` calls and queue their rewrite
        """
        func = node.func
        if (
            isinstance(func, ast.Attribute)
            and func.attr == "log"
            and isinstance(func.value, ast.Name)
            and func.value.id == "Model"
        ):
            self._rewrite_model_log_call(node)
        self.generic_visit(node)

    def _rewrite_model_log_call(self, node: ast.Call):
        """Queue the keyword rename and the extra keywords for one call.

        NOTE(review): the extra keywords are appended to every ``Model.log``
        call that has a named keyword, whether or not ``artifact_path`` is
        present -- confirm this is the intended scope.
        """
        # ``**kwargs`` entries have ``arg is None`` and are never rewritten.
        named = [kw for kw in node.keywords if kw.arg]
        tail_index = len(named) - 1
        for index, kw in enumerate(named):
            renamed = kw.arg == "artifact_path"
            is_tail = index == tail_index
            if not (renamed or is_tail):
                continue
            start, end = self._span(kw)
            if renamed:
                self.report(kw)
                text = "name=name"
            else:
                text = self.src[start:end]
            # When the renamed keyword is also the last one, both changes go
            # into a single fix so that no two fixes overlap.
            self._add_fix(
                start, end, text, self.EXTRA_CALL_KWARGS if is_tail else None
            )

    def visit_FunctionDef(self, node: ast.FunctionDef):
        """Optionally rewrite module-level ``log_model`` signatures."""
        if (
            self.REWRITE_LOG_MODEL_SIGNATURES
            and not self.in_class(node)
            and node.name == "log_model"
        ):
            self._rewrite_log_model_signature(node)
        return self.generic_visit(node)

    def _rewrite_log_model_signature(self, node: ast.FunctionDef):
        """Queue the parameter rename and the extra parameters for one def."""
        params = node.args.args
        if not params:
            return
        # Defaults belong to the trailing positional parameters; pad the
        # front with None so each parameter lines up with its default.
        trailing = node.args.defaults[-len(params):]
        defaults = [None] * (len(params) - len(trailing)) + list(trailing)
        add_extras = not any(arg.arg == "model_id" for arg in params)
        tail_index = len(params) - 1
        for index, (arg, default) in enumerate(zip(params, defaults)):
            renamed = arg.arg == "artifact_path"
            is_tail = add_extras and index == tail_index
            if not (renamed or is_tail):
                continue
            start, end = self._span(arg)
            if default is not None:
                # Replace the default together with the parameter so that a
                # renamed parameter does not end up with two defaults.
                end = self._span(default)[1]
            if renamed:
                self.report(node)
                text = "name: Optional[str]=None"
            else:
                text = self.src[start:end]
            self._add_fix(
                start, end, text, self.EXTRA_DEF_PARAMS if is_tail else None
            )

    def fix(self) -> str:
        """Apply all queued fixes, write the file, and return the new text.

        Fixes are applied from the end of the file towards the start, so the
        offsets of fixes not yet applied stay valid regardless of the order
        in which they were recorded.

        Raises:
            ValueError: If two fixes overlap; nothing is written in that case.
        """
        ordered = sorted(self.fixes, key=lambda f: f.start)
        for earlier, later in zip(ordered, ordered[1:]):
            if later.start < earlier.end:
                raise ValueError(
                    f"{self.path}: overlapping fixes at offsets "
                    f"{earlier.start}-{earlier.end} and {later.start}-{later.end}"
                )
        src = self.src
        for f in reversed(ordered):
            src = src[: f.start] + f.replacement + src[f.end :]
        self.path.write_text(src, encoding="utf-8")
        return src
# Rewrite every git-tracked file matched by the ``mlflow/*.py`` pathspec.
# ``-z`` makes git emit NUL-terminated, unquoted paths, so file names that
# contain whitespace or other unusual characters are passed through intact.
tracked = subprocess.check_output(
    [
        "git",
        "ls-files",
        "-z",
        "mlflow/*.py",
    ]
).decode()
for p in map(Path, filter(None, tracked.split("\0"))):
    v = Visitor(p)
    # Parse the text the visitor already loaded, so the offsets it computes
    # always refer to exactly the source that was parsed.
    v.visit(ast.parse(v.src, filename=str(p)))
    if v.fixes:
        v.fix()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment