Skip to content

Instantly share code, notes, and snippets.

@harupy
Last active February 18, 2025 07:18
Show Gist options
  • Save harupy/bc281c07d1a2d0678f97cf519aa81131 to your computer and use it in GitHub Desktop.
Save harupy/bc281c07d1a2d0678f97cf519aa81131 to your computer and use it in GitHub Desktop.
# /// script
# requires-python = "==3.10"
# dependencies = [
# "pydantic",
# "openai",
# ]
# ///
# ruff: noqa: T201
"""
How to run (https://gist.github.com/harupy/bc281c07d1a2d0678f97cf519aa81131)
uv run https://gist.githubusercontent.com/harupy/bc281c07d1a2d0678f97cf519aa81131/raw/auto_update.py \
tests/xgboost/test_xgboost_model_export.py
"""
import ast
import asyncio
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Coroutine, Optional
from openai import AsyncOpenAI
from pydantic import BaseModel
class Response(BaseModel):
    """Structured-output schema passed as `response_format` to the chat completion call."""

    # Rewritten function source; the prompt asks for `code=null` when no change applies.
    code: Optional[str]
class Visitor(ast.NodeVisitor):
    """Collect the `test_*` functions of one file that `AnotherVisitor.check` accepts."""

    def __init__(self, path: Path):
        # `path` is only used to render locations; `nodes` accumulates the matches.
        self.path = path
        self.nodes: list[ast.AST] = []

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        """Record `node` when it is a test function that should be updated."""
        # The function body is deliberately not traversed any further here.
        is_candidate = node.name.startswith("test_") and AnotherVisitor.check(node)
        if is_candidate:
            self.nodes.append(node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
        """Treat async test functions exactly like synchronous ones."""
        self.visit_FunctionDef(node)

    def location(self, node: ast.AST) -> str:
        """Return `<path>:<line>` for `node`."""
        return "{}:{}".format(self.path, node.lineno)
class AnotherVisitor(ast.NodeVisitor):
def __init__(self):
self.has_log_model = False
self.has_legacy_model_uri = False
@classmethod
def check(cls, node: ast.AST) -> bool:
visitor = cls()
visitor.visit(node)
return visitor.should_update
@property
def should_update(self) -> bool:
return self.has_log_model and self.has_legacy_model_uri
def _resolve_call(self, node: ast.AST) -> list[str]:
"""
Resolve a call to a function or method to a list of strings representing the call.
For example, `mlflow.get_artifact_uri(...)` will be resolved to `["mlflow", "get_artifact_uri"]`.
"""
if isinstance(node, ast.Call):
return self._resolve_call(node.func)
elif isinstance(node, ast.Attribute):
return self._resolve_call(node.value) + [node.attr]
elif isinstance(node, ast.Name):
return [node.id]
return []
def visit_JoinedStr(self, node: ast.JoinedStr) -> None:
"""
An f-string that appears after `log_model` and starts with `f'runs:/{...}` is likely
string formatting to get the model URI.
"""
if self.has_log_model:
unparsed = ast.unparse(node)
if unparsed.startswith("f'runs:/{"):
self.has_legacy_model_uri = True
self.generic_visit(node)
def visit_Call(self, node: ast.Call) -> None:
call = self._resolve_call(node)
if len(call) == 3:
first, _, third = call
if first == "mlflow" and third == "log_model":
self.has_log_model = True
elif self.has_log_model and call == ["mlflow", "get_artifact_uri"]:
self.has_legacy_model_uri = True
self.generic_visit(node)
# Prompt sent to the model for each candidate function. It is filled in with
# `PROMPT_TEMPLATE.format(code=...)`, which is why literal braces in the examples are
# doubled (`{{` / `}}`). The model is asked to return the rewritten code, or `code=null`
# when the rewrite does not apply.
PROMPT_TEMPLATE = """
The following code contains `log_model`, and `get_artifact_uri` or string formatting
to get the model URI (e.g., `f'runs:/{{...}}/model'`).
```python
{code}
```
This code can be simplified by using the returned value of `log_model`. Here's an example:
# Legacy
```python
with mlflow.start_run() as run:
# This example uses sklearn but the flavor can be different (e.g., xgboost, pytorch).
mlflow.sklearn.log_model(model, "model")
model_uri = mlflow.get_artifact_uri("model")
# OR
model_uri = f"runs:/{{run.info.run_id}}/model"
# This is just an example. `model_uri` might appear/be used in a different form.
mlflow.sklearn.load_model(model_uri)
```
# Recommended
```python
with mlflow.start_run() as run:
model_info = mlflow.sklearn.log_model(model, "model")
# The `model_uri` variable can be removed if not used elsewhere or removing it won't affect
# readability.
# This is just an example. `model_uri` might appear/be used in a different form.
mlflow.sklearn.load_model(model_info.model_uri)
```
Can you update the code to use the returned value of `log_model`?
If this change is not applicable, respond with `code=null`.
The response should only include the code **without any code block markers**.
No need to add any import statements. You can assume that `mlflow` is already imported.
No need to insert any comments. Make sure to preserve the existing comments.
No worries about formatting. I'll take care of that.
"""
@dataclass
class Code:
    """Pair of a function's original source and the replacement proposed by the model."""

    old: str  # source text of the function as sliced from the file
    new: Optional[str]  # rewritten source, or None when the model returned no code
async def chat(client: AsyncOpenAI, code: str) -> Code:
    """
    Ask the model to rewrite `code` so that it uses the return value of `log_model`.

    Args:
        client: OpenAI async client used to issue the structured-output request.
        code: Source text of a single function to rewrite.

    Returns:
        A `Code` whose `old` is the input and whose `new` is the model's rewrite, or
        `None` when the model returned `code=null` or produced no parsed response.
    """
    completion = await client.beta.chat.completions.parse(
        response_format=Response,
        model="gpt-4o",
        messages=[
            {"role": "developer", "content": "You are a helpful assistant."},
            {
                "role": "user",
                "content": PROMPT_TEMPLATE.format(code=code),
            },
        ],
    )
    # `parsed` is None when the model refuses or the content cannot be parsed into
    # `Response`; treat that as "no change" instead of raising AttributeError, which
    # would abort every other request awaited together with this one.
    parsed = completion.choices[0].message.parsed
    new = parsed.code if parsed is not None else None
    return Code(old=code, new=new)
async def main() -> None:
    """
    Rewrite legacy model-URI usages in every file given on the command line.

    For each file: find the candidate test functions, send each one to the model
    concurrently, validate every returned snippet, and substitute it for the original
    function source. The file is written back only if at least one function changed.
    """
    import textwrap

    client = AsyncOpenAI()
    for arg in sys.argv[1:]:
        file = Path(arg)
        # Read once so the AST line numbers and the sliced text come from the same snapshot.
        orig_code = file.read_text()
        tree = ast.parse(orig_code)
        visitor = Visitor(file)
        visitor.visit(tree)
        if not visitor.nodes:
            print(f"No functions to update in {file}")
            continue

        print(f"Found {len(visitor.nodes)} functions to update in {file}")
        orig_lines = orig_code.splitlines()
        jobs: list[Coroutine[Any, Any, Code]] = []
        for idx, node in enumerate(visitor.nodes):
            print(f"{idx}. {visitor.location(node)}")
            # `lineno` is 1-based and `end_lineno` is inclusive.
            code = "\n".join(orig_lines[node.lineno - 1 : node.end_lineno])
            jobs.append(chat(client, code))

        results = await asyncio.gather(*jobs)
        new_code = orig_code
        for c in results:
            if not c.new:
                continue
            try:
                # Test methods inside classes are sliced with their indentation; dedent
                # for validation only so they are not rejected with IndentationError.
                ast.parse(textwrap.dedent(c.new))
            except SyntaxError as e:
                print(f"Syntax error: {e}")
                print(c.new)
                continue
            new_code = new_code.replace(c.old, c.new, 1)

        # Skip the write when every rewrite was declined or rejected.
        if new_code != orig_code:
            file.write_text(new_code)
# Script entry point: run the async `main` coroutine when the file is executed directly.
if __name__ == "__main__":
    asyncio.run(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment