#!/usr/bin/env python3
"""
This script uses libcst to automatically update docstrings that might
raise an "invalid escape sequence" syntax warning – converting them to raw
docstrings (adding an r prefix) if they are not already and if they contain
backslashes. It processes module-level, function, and class docstrings.
"""

import re
import sys
import libcst as cst
from libcst import MetadataWrapper, CSTTransformer, FunctionDef, ClassDef, Module, SimpleStatementLine, Expr, SimpleString
from typing import Sequence

def has_raw_prefix(text: str) -> bool:
    """Report whether a string literal's prefix marks it as raw.

    Only the leading run of prefix letters (``u``, ``r``, ``b``, ``f`` in
    either case) is inspected. The literal counts as raw when that run
    contains ``r`` or ``R``. For example, ``r'...'``, ``Rb"..."`` and
    ``br'''...'''`` are raw, while ``'...'`` and ``b"..."`` are not.
    """
    prefix_letters = "urbfURBF"
    for char in text:
        if char not in prefix_letters:
            # Reached the opening quote (or something else) without an r.
            return False
        if char in "rR":
            return True
    return False

def make_raw_string(node: SimpleString) -> SimpleString:
    r"""Return a copy of *node* whose literal carries an ``r`` prefix.

    The quote style (single/double, one or three quotes) and the text between
    the quotes are kept verbatim; only the prefix changes. A node that is
    already raw is returned unchanged.

    Adding ``r`` preserves the string's value only when the body contains no
    *valid* escapes such as ``\n`` or ``\\``. Checking that is the caller's job.
    """
    old_val = node.value
    if has_raw_prefix(old_val):
        return node
    # Split off the prefix letters and the opening quote. Only a *matching*
    # triple quote counts as one, so a literal like "''\d" is read as a
    # double-quoted string that begins with two apostrophes.
    m = re.match(r"^(?P<prefix>[rubfRUBF]*)(?P<quote>'''|\"\"\"|['\"])", old_val)
    if not m:
        # Not expected for a SimpleString; fall back to prepending "r".
        return node.with_changes(value="r" + old_val)
    prefix = m.group("prefix")
    quote = m.group("quote")
    # "u" is a no-op in Python 3 and cannot be combined with "r" ("ru'...'" is
    # a SyntaxError), so drop it. Any remaining letters (e.g. "b") follow "r".
    others = sorted(ch for ch in prefix if ch.lower() not in ("r", "u"))
    inner = old_val[m.end():-len(quote)]
    return node.with_changes(value="r" + "".join(others) + quote + inner + quote)

# Characters that form a *valid* escape in a non-raw str literal when they
# follow a backslash (including an escaped line break). Adding an "r" prefix
# changes what these mean, whereas invalid escapes (e.g. "\d") read the same
# either way.
_VALID_STR_ESCAPE_CHARS = frozenset("\n\r\\'\"abfnrtv01234567xNuU")


def _scan_escapes(raw_body: str):
    """Return ``(has_valid, has_invalid)`` for the backslash escapes in *raw_body*.

    *raw_body* is the literal's source text between the quotes. Each backslash
    consumes the character after it, so an escaped backslash counts as one
    valid escape rather than two separate backslashes.
    """
    has_valid = has_invalid = False
    pos = raw_body.find("\\")
    while pos != -1 and pos + 1 < len(raw_body):
        if raw_body[pos + 1] in _VALID_STR_ESCAPE_CHARS:
            has_valid = True
        else:
            has_invalid = True
        pos = raw_body.find("\\", pos + 2)
    return has_valid, has_invalid


def update_docstring_in_body(body: Sequence[cst.BaseStatement]) -> Sequence[cst.BaseStatement]:
    """Raw-prefix the docstring at the head of *body* when doing so is safe.

    *body* is a statement list from a Module, an IndentedBlock, or a one-line
    SimpleStatementSuite. The docstring is converted only when it contains at
    least one invalid escape sequence (the source of the warning) and no valid
    ones. For invalid escapes the raw and non-raw readings are identical, so the
    docstring's value is unchanged. Docstrings that mix both kinds are left as
    they are, because adding "r" would alter their text.

    Returns *body* itself when nothing changes; otherwise returns a new list
    whose first statement carries the rewritten docstring.
    """
    if not body:
        return body

    first_stmt = body[0]
    # A docstring is the first small statement. It is either the head of a
    # SimpleStatementLine (module / indented bodies) or, for one-line suites
    # such as `def f(): "doc"`, the small statement itself.
    if isinstance(first_stmt, SimpleStatementLine):
        maybe_expr = first_stmt.body[0] if first_stmt.body else None
    else:
        maybe_expr = first_stmt
    if not (isinstance(maybe_expr, Expr) and isinstance(maybe_expr.value, SimpleString)):
        return body

    string_node = maybe_expr.value
    if has_raw_prefix(string_node.value):
        return body
    # Bytes literals are never docstrings; leave them alone.
    if "b" in string_node.prefix.lower():
        return body
    has_valid, has_invalid = _scan_escapes(string_node.raw_value)
    if not has_invalid or has_valid:
        return body

    new_expr = maybe_expr.with_changes(value=make_raw_string(string_node))
    if isinstance(first_stmt, SimpleStatementLine):
        # Keep any statements that follow the docstring on the same line.
        new_first_stmt = first_stmt.with_changes(body=[new_expr] + list(first_stmt.body[1:]))
    else:
        new_first_stmt = new_expr
    return [new_first_stmt] + list(body[1:])

class DocstringRawTransformer(cst.CSTTransformer):
    """libcst transformer that rewrites module, class and function docstrings.

    Every ``leave_*`` hook passes the node's statement list to
    ``update_docstring_in_body`` and rebuilds the node around whatever that
    helper returns.
    """

    @staticmethod
    def _with_processed_suite(node):
        # FunctionDef and ClassDef keep their statements one level down, inside
        # the suite object stored on ``node.body``.
        suite = node.body
        fresh_suite = suite.with_changes(body=update_docstring_in_body(suite.body))
        return node.with_changes(body=fresh_suite)

    def leave_Module(self, original_node: Module, updated_node: Module) -> Module:
        # A module holds its statements directly on ``body``.
        return updated_node.with_changes(body=update_docstring_in_body(updated_node.body))

    def leave_FunctionDef(self, original_node: FunctionDef, updated_node: FunctionDef) -> FunctionDef:
        return self._with_processed_suite(updated_node)

    def leave_ClassDef(self, original_node: ClassDef, updated_node: ClassDef) -> ClassDef:
        return self._with_processed_suite(updated_node)

def main():
    """Command-line entry point: fix docstrings in one file, in place.

    Usage: ``<script> <python_file_to_fix.py>``. Exits with a message on a
    usage error, when the file cannot be read as UTF-8, or when libcst cannot
    parse it. The file is rewritten only if its content actually changed.
    """
    if len(sys.argv) != 2:
        sys.exit("Usage: {} <python_file_to_fix.py>".format(sys.argv[0]))

    filename = sys.argv[1]

    # newline="" turns off newline translation, so CRLF / mixed line endings
    # reach libcst untouched and are written back exactly as they were.
    try:
        with open(filename, "r", encoding="utf-8", newline="") as f:
            source = f.read()
    except (OSError, UnicodeDecodeError) as e:
        sys.exit("Error reading {}: {}".format(filename, e))

    try:
        module = cst.parse_module(source)
    except Exception as e:  # CLI boundary: report any parse failure and stop.
        sys.exit("Error parsing {}: {}".format(filename, e))

    # The transformer requests no metadata, so visit the module directly. This
    # avoids MetadataWrapper, which deep-copies the whole tree.
    new_code = module.visit(DocstringRawTransformer()).code
    if new_code == source:
        return  # Nothing changed; leave the file (and its mtime) untouched.

    with open(filename, "w", encoding="utf-8", newline="") as f:
        f.write(new_code)

if __name__ == "__main__":
    # Script entry point; main() returns None, so the exit status is 0 on success.
    sys.exit(main())