Created
August 18, 2023 09:54
-
-
Save harupy/7483a548d2033b74e89c3c64007ce98c to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations | |
import ast | |
import os | |
import random | |
import subprocess | |
import textwrap | |
import openai | |
class DocstringVisitor(ast.NodeVisitor): | |
def __init__(self): | |
self.docstring_nodes = [] | |
def visit_FunctionDef(self, node: ast.FunctionDef): | |
if ( | |
node.body | |
and isinstance(node.body[0], ast.Expr) | |
and isinstance(node.body[0].value, ast.Str) | |
and ":param" in node.body[0].value.s | |
and not node.name.startswith("_") | |
): | |
self.docstring_nodes.append(node.body[0].value) | |
def transform(docstring: str, model: str = "gpt-4") -> str:
    """Ask an OpenAI chat model to rewrite a ``:param``-style docstring in Google style.

    Args:
        docstring: Source text of the docstring, including its quotes, prefixed with
            the indentation it has in the file.
        model: Chat model name passed to ``openai.ChatCompletion.create``.

    Returns:
        The text content of the first choice in the model's reply. The prompt asks for
        the result inside a ```python fenced block, but that is not guaranteed.
    """
    prompt = f"""
Hi GPT4, I'd like you to rewrite python docstrings in a more readable format. Here's an example:

# Before

```python
\"\"\"
This is a docstring

:param artifact_path: The run-relative path to which to log model artifacts.
:param custom_objects: A Keras ``custom_objects`` dictionary mapping names (strings) to
    custom classes or functions associated with the Keras model. MLflow saves
    ...
:return: This is a return value.
    ...
\"\"\"
```

# After (similar to google docstrings, but no need to add types)

```python
\"\"\"
This is a docstring

Args:
    artifact_path: The run-relative path to which to log model artifacts.
    custom_objects: A Keras ``custom_objects`` dictionary mapping names (strings)
        to custom classes or functions associated with the Keras model. MLflow saves
        ...

Returns:
    This is a return value.
    ...
\"\"\"
```

# Transformation Rules:

- Be sure to preserve the indentation of the original docstring.
- Be sure to preserve the quotes of the original docstring.
- Be sure to avoid the line length exceeding 100 characters.
- Be sure to only update the parameters and returns sections.
- The Returns section is optional. If the original docstring doesn't have
  ':return:' or ':returns:' entries, then don't add a 'Returns' section.
- Be sure to use the following format for the new docstring:

```python
{{new_docstring}}
```

Given these rules, can you rewrite the following docstring? Thanks for your help!

```python
{docstring}
```
"""
    res = openai.ChatCompletion.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
    )
    return res.choices[0].message.content
def node_to_char_range(docstring_node: ast.Constant, line_lengths: list[int]) -> tuple[int, int]:
    """Convert a node's (line, column) span into a ``[start, end)`` offset range.

    Args:
        docstring_node: Node with ``lineno``, ``col_offset``, ``end_lineno`` and
            ``end_col_offset`` set, as produced by ``ast.parse``.
        line_lengths: Length of every source line, line terminators included. Must use
            the same unit as the node's column offsets, which ``ast`` reports as UTF-8
            byte offsets (character lengths only agree for pure-ASCII text).

    Returns:
        ``(start, end)`` such that ``source[start:end]`` is the node's source text.
    """
    # Each end is "offset of its line's first character + column". Computing both ends
    # the same way also handles nodes that start and end on the same line, where the
    # old "rest of first line + ... + end column" length overshot into the next line.
    start = sum(line_lengths[: docstring_node.lineno - 1]) + docstring_node.col_offset
    end = sum(line_lengths[: docstring_node.end_lineno - 1]) + docstring_node.end_col_offset
    return start, end
def extract_code(s: str) -> str | None: | |
import re | |
if m := re.search(r"```python\n(.*)```", s, re.DOTALL): | |
return m.group(1) | |
return None | |
def format_code(code: str, indent: str, opening_quote: str, closing_quote: str) -> str:
    """Turn a model-produced docstring into source text ready to splice back in.

    The model's own string prefix and quote delimiters are removed. The content is
    dedented and re-indented with ``indent``, and the given delimiters are put back
    around it, each on its own line.

    Args:
        code: Docstring as returned by the model, with or without quotes.
        indent: Indentation of the docstring in the original file.
        opening_quote: Opening delimiter (including any prefix such as ``r``) to emit.
        closing_quote: Closing delimiter to emit.

    Returns:
        ``opening_quote``, a newline, the re-indented content, a newline, then
        ``indent`` followed by ``closing_quote``.
    """
    body = code.strip()
    # Remove exactly one prefix + delimiter from each end. Stripping a character set
    # instead also ate real content, such as the "r" of a docstring starting with "run",
    # or a '"' ending the last line of text, and never removed ''' delimiters.
    unprefixed = body.lstrip("rRuU")
    for quote in ('"""', "'''", '"', "'"):
        if unprefixed.startswith(quote):
            body = unprefixed[len(quote) :]
            if body.endswith(quote):
                body = body[: -len(quote)]
            break
    # Drop the line breaks after the opening quote but keep the first content line's
    # indentation, so dedent sees the real common margin.
    body = body.lstrip("\r\n").rstrip()
    body = textwrap.dedent(body)
    body = textwrap.indent(body, indent)
    return f"{opening_quote}\n{body}\n{indent}{closing_quote}"
def leading_quote(s: str) -> str:
    """Return the string prefix and opening delimiter of a string literal.

    Args:
        s: Source text of a string literal.

    Returns:
        The optional prefix (letters from ``rRuUbBfF``) followed by the opening
        delimiter. That is a triple quote for triple-quoted literals, otherwise a
        single quote character.

    Raises:
        ValueError: If ``s`` does not start with an (optionally prefixed) quote.
    """
    unprefixed = s.lstrip("rRuUbBfF")
    prefix = s[: len(s) - len(unprefixed)]
    # Take exactly one delimiter, so letters or quotes that begin the string's content
    # (e.g. the "r" of "return") are not treated as part of the opening quote.
    for quote in ('"""', "'''", '"', "'"):
        if unprefixed.startswith(quote):
            return prefix + quote
    raise ValueError("No leading quote found")
def trailing_quote(s: str) -> str:
    """Return the closing delimiter of a string literal.

    Args:
        s: Source text of a string literal.

    Returns:
        A triple quote if ``s`` ends with one, otherwise the single final quote
        character.

    Raises:
        ValueError: If ``s`` does not end with a quote character.
    """
    # Take exactly one delimiter, so a quote that ends the content, as in
    # '''say "hi"''', is not mistaken for part of the closing quote.
    for quote in ('"""', "'''", '"', "'"):
        if s.endswith(quote):
            return quote
    raise ValueError("No trailing quote found")
def main():
    """Rewrite ``:param``-style docstrings in git-tracked ``mlflow/*.py`` files.

    Files are visited in random order. Each docstring collected by ``DocstringVisitor``
    is sent to ``transform``. When the reply contains a ```python block, that block
    replaces the original docstring in place. A file is rewritten only if at least one
    docstring changed.

    Raises:
        SystemExit: If the ``OPENAI_API_KEY`` environment variable is not set.
    """
    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if "OPENAI_API_KEY" not in os.environ:
        raise SystemExit("OPENAI_API_KEY environment variable is not set")
    py_files = subprocess.check_output(["git", "ls-files", "mlflow/*.py"]).decode().splitlines()
    random.shuffle(py_files)
    for py_file in py_files:
        # Work on raw bytes. ``ast`` reports column offsets as UTF-8 byte offsets, so
        # slicing decoded text with them breaks on any line containing non-ASCII
        # characters. Binary mode also keeps the file's newlines untouched.
        # NOTE: assumes the sources are UTF-8 encoded.
        with open(py_file, "rb") as f:
            src = f.read()
        visitor = DocstringVisitor()
        visitor.visit(ast.parse(src))
        if not visitor.docstring_nodes:
            continue
        # ``bytes.splitlines`` only splits on \n, \r\n and \r, matching how ``ast``
        # numbers lines (``str.splitlines`` would also split on form feeds etc.).
        line_lengths = [len(line) for line in src.splitlines(keepends=True)]
        new_src = src
        # Running size difference between ``new_src`` and ``src``. This relies on the
        # visitor reporting docstrings in increasing source order.
        offset = 0
        for node in visitor.docstring_nodes:
            print(f"Transforming {py_file}:{node.lineno}:{node.col_offset + 1}")
            start, end = node_to_char_range(node, line_lengths)
            indent = " " * node.col_offset
            original = src[start:end].decode("utf-8")
            code = extract_code(transform(indent + original))
            if code is None:
                continue
            replacement = format_code(
                code,
                indent,
                leading_quote(original),
                trailing_quote(original),
            ).encode("utf-8")
            new_src = new_src[: start + offset] + replacement + new_src[end + offset :]
            offset += len(replacement) - (end - start)
        # Avoid touching files where every transformation was skipped.
        if new_src != src:
            with open(py_file, "wb") as f:
                f.write(new_src)


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment