Last active
February 18, 2025 07:18
-
-
Save harupy/bc281c07d1a2d0678f97cf519aa81131 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# /// script | |
# requires-python = "==3.10" | |
# dependencies = [ | |
# "pydantic", | |
# "openai", | |
# ] | |
# /// | |
# ruff: noqa: T201 | |
""" | |
How to run (https://gist.github.com/harupy/bc281c07d1a2d0678f97cf519aa81131) | |
uv run https://gist.githubusercontent.com/harupy/bc281c07d1a2d0678f97cf519aa81131/raw/auto_update.py \ | |
tests/xgboost/test_xgboost_model_export.py | |
""" | |
import ast | |
import asyncio | |
import sys | |
from dataclasses import dataclass | |
from pathlib import Path | |
from typing import Any, Coroutine, Optional | |
from openai import AsyncOpenAI | |
from pydantic import BaseModel | |
class Response(BaseModel):
    # Structured-output schema for the model's answer: `code` is the rewritten
    # function source, or None when the model says the rewrite does not apply.
    # No default is given, so pydantic treats the field as required.
    code: Optional[str]
class Visitor(ast.NodeVisitor):
    """Collect the ``test_*`` functions of a module that ``AnotherVisitor.check`` accepts.

    ``generic_visit`` is not called from the function handlers, so functions nested
    inside another function are never examined. Functions defined inside classes
    are still reached through the default traversal.
    """

    def __init__(self, path: Path):
        self.path = path
        # Accepted function nodes, in visiting (i.e. source) order.
        self.nodes: list[ast.AST] = []

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        if not node.name.startswith("test_"):
            return
        if AnotherVisitor.check(node):
            self.nodes.append(node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
        # Async test functions go through exactly the same filter as sync ones.
        self.visit_FunctionDef(node)

    def location(self, node: ast.AST) -> str:
        """Return ``"<path>:<lineno>"`` for *node*."""
        return f"{self.path}:{node.lineno}"
class AnotherVisitor(ast.NodeVisitor): | |
def __init__(self): | |
self.has_log_model = False | |
self.has_legacy_model_uri = False | |
@classmethod | |
def check(cls, node: ast.AST) -> bool: | |
visitor = cls() | |
visitor.visit(node) | |
return visitor.should_update | |
@property | |
def should_update(self) -> bool: | |
return self.has_log_model and self.has_legacy_model_uri | |
def _resolve_call(self, node: ast.AST) -> list[str]: | |
""" | |
Resolve a call to a function or method to a list of strings representing the call. | |
For example, `mlflow.get_artifact_uri(...)` will be resolved to `["mlflow", "get_artifact_uri"]`. | |
""" | |
if isinstance(node, ast.Call): | |
return self._resolve_call(node.func) | |
elif isinstance(node, ast.Attribute): | |
return self._resolve_call(node.value) + [node.attr] | |
elif isinstance(node, ast.Name): | |
return [node.id] | |
return [] | |
def visit_JoinedStr(self, node: ast.JoinedStr) -> None: | |
""" | |
An f-string that appears after `log_model` and starts with `f'runs:/{...}` is likely | |
string formatting to get the model URI. | |
""" | |
if self.has_log_model: | |
unparsed = ast.unparse(node) | |
if unparsed.startswith("f'runs:/{"): | |
self.has_legacy_model_uri = True | |
self.generic_visit(node) | |
def visit_Call(self, node: ast.Call) -> None: | |
call = self._resolve_call(node) | |
if len(call) == 3: | |
first, _, third = call | |
if first == "mlflow" and third == "log_model": | |
self.has_log_model = True | |
elif self.has_log_model and call == ["mlflow", "get_artifact_uri"]: | |
self.has_legacy_model_uri = True | |
self.generic_visit(node) | |
# Prompt sent once per candidate test function; filled in with `str.format(code=...)`,
# so literal braces in the text must be doubled (`{{` / `}}`).
PROMPT_TEMPLATE = """
The following code contains `log_model`, and `get_artifact_uri` or string formatting
to get the model URI (e.g., `f'runs:/{{...}}/model'`).
```python
{code}
```
This code can be simplified by using the returned value of `log_model`. Here's an example:
# Legacy
```python
with mlflow.start_run() as run:
    # This example uses sklearn but the flavor can be different (e.g., xgboost, pytorch).
    mlflow.sklearn.log_model(model, "model")
    model_uri = mlflow.get_artifact_uri("model")
    # OR
    model_uri = f"runs:/{{run.info.run_id}}/model"
    # This is just an example. `model_uri` might appear/be used in a different form.
    mlflow.sklearn.load_model(model_uri)
```
# Recommended
```python
with mlflow.start_run() as run:
    model_info = mlflow.sklearn.log_model(model, "model")
    # The `model_uri` variable can be removed if not used elsewhere or removing it won't affect
    # readability.
    # This is just an example. `model_uri` might appear/be used in a different form.
    mlflow.sklearn.load_model(model_info.model_uri)
```
Can you update the code to use the returned value of `log_model`?
If this change is not applicable, respond with `code=null`.
The response should only include the code **without any code block markers**.
No need to add any import statements. You can assume that `mlflow` is already imported.
No need to insert any comments. Make sure to preserve the existing comments.
No worries about formatting. I'll take care of that.
"""
@dataclass
class Code:
    """A test function's source before and after the requested rewrite."""

    # Source of the function exactly as it was sent to the model.
    old: str
    # Rewritten source, or None when the model answered `code=null`.
    new: Optional[str]
async def chat(client: AsyncOpenAI, code: str) -> Code:
    """Ask the model to rewrite *code* to use the return value of ``log_model``.

    Args:
        client: OpenAI client used to send the request.
        code: Source of a single test function.

    Returns:
        A ``Code`` with ``old=code``. ``new`` is the rewritten source, or ``None``
        when the model answered ``code=null`` or returned no parsed response
        (e.g. a refusal).
    """
    completion = await client.beta.chat.completions.parse(
        response_format=Response,
        model="gpt-4o",
        messages=[
            {"role": "developer", "content": "You are a helpful assistant."},
            {"role": "user", "content": PROMPT_TEMPLATE.format(code=code)},
        ],
    )
    message = completion.choices[0].message
    # Per the OpenAI structured-outputs docs, `parsed` is None when the model refuses.
    # Treat that as "no change" instead of raising AttributeError on `.code`.
    if message.parsed is None:
        print(f"Model returned no parsed response: {message.refusal}")
        return Code(old=code, new=None)
    return Code(old=code, new=message.parsed.code)
async def _update_file(client: AsyncOpenAI, file: Path) -> None:
    """Rewrite the flagged test functions of a single *file* in place.

    Every function ``Visitor`` collects is sent to the model concurrently. Each
    answer is validated with ``ast.parse`` and spliced back in place of the
    original function text. The file is written only if something changed.
    """
    import textwrap

    orig_code = file.read_text(encoding="utf-8")
    visitor = Visitor(file)
    visitor.visit(ast.parse(orig_code))
    if not visitor.nodes:
        print(f"No functions to update in {file}")
        return

    print(f"Found {len(visitor.nodes)} functions to update in {file}")
    all_lines = orig_code.splitlines()
    snippets: list[str] = []
    jobs: list[Coroutine[Any, Any, Code]] = []
    for idx, node in enumerate(visitor.nodes):
        print(f"{idx}. {visitor.location(node)}")
        snippet = "\n".join(all_lines[node.lineno - 1 : node.end_lineno])
        snippets.append(snippet)
        # Methods of test classes are indented; send them dedented so both the model
        # and the `ast.parse` validation below see a standalone function.
        jobs.append(chat(client, textwrap.dedent(snippet)))

    # `return_exceptions=True` so one failed request does not discard the other results.
    results = await asyncio.gather(*jobs, return_exceptions=True)

    new_code = orig_code
    for node, snippet, result in zip(visitor.nodes, snippets, results):
        location = visitor.location(node)
        if isinstance(result, BaseException):
            print(f"Request failed for {location}: {result!r}")
            continue
        if not result.new:
            continue
        new_func = textwrap.dedent(result.new)
        try:
            ast.parse(new_func)
        except SyntaxError as e:
            print(f"Syntax error in the rewrite of {location}: {e}")
            print(result.new)
            continue
        # Restore the original indentation (non-empty for methods inside a class).
        indent = snippet[: len(snippet) - len(snippet.lstrip())]
        new_code = new_code.replace(snippet, textwrap.indent(new_func, indent), 1)

    if new_code != orig_code:
        file.write_text(new_code, encoding="utf-8")


async def main():
    """Update every test file named on the command line (see the module docstring)."""
    client = AsyncOpenAI()
    for arg in sys.argv[1:]:
        await _update_file(client, Path(arg))
if __name__ == "__main__":
    # Script entry point: `asyncio.run` creates an event loop and runs `main` to completion.
    asyncio.run(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment