from contextlib import contextmanager
from curses.ascii import isalpha
from dataclasses import dataclass
from pathlib import Path
import ast
import re
from typing import Any, Optional, Sequence, Set


@dataclass
class ActionClassInfo:
    """Base record describing a class that was recognised as an action class.

    Subclasses distinguish how the class was registered; this base only keeps
    the class definition node itself.
    """

    # AST node of the class definition this record was built from.
    ast: ast.ClassDef


@dataclass
class ModuleActionClass(ActionClassInfo):
    """Action class registered through a ``mod.action_class`` decorator.

    Adds no fields of its own; the type alone marks the class as one that
    *defines* actions.
    """


@dataclass
class ContextActionClass(ActionClassInfo):
    """Action class registered through a ``ctx.action_class(scope)`` decorator."""

    # First argument passed to the decorator, e.g. "user"; used as the
    # namespace prefix of the actions the class refines.
    scope: str


@dataclass(eq=True, frozen=True)
class Node:
    """A graph vertex: a DOT identifier plus the human-readable label shown for it."""

    # Identifier used to refer to the node in edge statements.
    name: str
    # Label text displayed for the node.
    text: str

    def as_dot(self) -> str:
        """Return the DOT node statement ``name [label="text"];``.

        Backslashes and double quotes in the label are escaped so that labels
        such as Windows paths (``code\\new.py``) or names containing quotes do
        not terminate the quoted string or get read as DOT escape sequences.
        """
        label = str(self.text).replace("\\", "\\\\").replace('"', '\\"')
        return f'{self.name} [label="{label}"];'


@dataclass(eq=True, frozen=True)
class Edge:
    """A directed connection between two nodes, identified by node name."""

    # Name of the node the arrow starts from.
    tail: str
    # Name of the node the arrow points to.
    head: str

    def as_dot(self) -> str:
        """Return the DOT edge statement ``tail -> head``."""
        return "{} -> {}".format(self.tail, self.head)


@dataclass(eq=True, frozen=True)
class Graph:
    """A directed graph that can be rendered in Graphviz DOT syntax."""

    # Vertices of the graph.
    nodes: Set[Node]
    # Directed edges; each refers to nodes by name.
    edges: Set[Edge]

    def as_dot(self) -> str:
        """Render the graph as a DOT ``digraph`` block.

        Nodes and edges are emitted in sorted order. Both are stored in sets,
        whose iteration order varies between runs, so sorting keeps the
        generated file stable and diff-friendly.
        """
        ordered_nodes = sorted(self.nodes, key=lambda node: (node.name, node.text))
        ordered_edges = sorted(self.edges, key=lambda edge: (edge.tail, edge.head))
        lines = ["digraph {"]
        lines.extend(f"  {node.as_dot()}" for node in ordered_nodes)
        lines.extend(f"  {edge.as_dot()}" for edge in ordered_edges)
        lines.append("}")
        return "\n".join(lines)

    def without_unconnected_nodes(self) -> "Graph":
        """Return a copy that keeps only nodes named by at least one edge.

        The edge set is reused as-is; only the node set is filtered.
        """
        connected_names = {
            name for edge in self.edges for name in (edge.tail, edge.head)
        }
        kept = {node for node in self.nodes if node.name in connected_names}
        return Graph(nodes=kept, edges=self.edges)


@dataclass
class ActionInfo:
    """One action function found inside an action class.

    Its string form is the qualified action name: ``<scope>.<name>`` when the
    enclosing class is a ``ContextActionClass``, ``user.<name>`` otherwise.
    """

    # Name of the function that implements the action.
    name: str
    # Record describing the class the function was found in.
    action_class_info: ActionClassInfo
    # AST node of the function definition.
    ast: ast.FunctionDef
    # File the function was found in.
    path: Path
    # Prefix prepended to ``path`` when building source links.
    baseurl: str

    def __str__(self):
        if isinstance(self.action_class_info, ContextActionClass):
            return f"{self.action_class_info.scope}.{self.name}"
        return f"user.{self.name}"

    def url(self) -> str:
        """Return a link to the function's source lines below ``baseurl``.

        The path is rendered with forward slashes regardless of platform; a
        native Windows path would otherwise put backslashes into the URL.
        Values without an ``as_posix`` method (e.g. plain strings) are used
        as they are.
        """
        as_posix = getattr(self.path, "as_posix", None)
        location = as_posix() if callable(as_posix) else str(self.path)
        return f"{self.baseurl}/{location}#L{self.ast.lineno}-L{self.ast.end_lineno}"

    def anchor(self) -> str:
        """Return an identifier for this action with non-alphanumerics as ``_``.

        Definitions (module action classes) and refinements (everything else)
        get different anchors, so both can appear on the same page.
        """
        if isinstance(self.action_class_info, ModuleActionClass):
            anchor = f"{self.path}_define_{self}"
        else:
            anchor = f"{self.path}_refine_{self}"
        return "".join(c if c.isalnum() else "_" for c in anchor)

    def is_abstract(self):
        """Return True when the function body is nothing but a string literal.

        That is the shape of a function consisting solely of its docstring.
        """
        if len(self.ast.body) != 1:
            return False
        stmt = self.ast.body[0]
        return (
            isinstance(stmt, ast.Expr)
            and isinstance(stmt.value, ast.Constant)
            and isinstance(stmt.value.value, str)
        )

    def is_concrete(self):
        """Return True when the function has a body beyond a lone docstring."""
        return not self.is_abstract()

    def is_override(self):
        """Return True when the action comes from a context action class."""
        return isinstance(self.action_class_info, ContextActionClass)


class DependencyGraph(ast.NodeVisitor):
    @staticmethod
    def anchor(any: Any) -> str:
        text: str = any if type(any) == str else str(any)
        return "".join(c if c.isalnum() else "_" for c in text)

    @staticmethod
    def telescope(func: ast.Expr) -> Sequence[str]:
        if type(func) == ast.Name:
            return (func.id,)
        if type(func) == ast.Attribute:
            return (*DependencyGraph.telescope(func.value), func.attr)
        raise ValueError(func)

    @staticmethod
    def guess_action_class_info(cls: ast.ClassDef) -> Optional[ActionClassInfo]:
        for decorator in cls.decorator_list:
            try:  # mod.action_class
                if (
                    re.match("mod", decorator.value.id)
                    and decorator.attr == "action_class"
                ):
                    return ModuleActionClass(ast=cls)
            except AttributeError:
                pass
            try:  # ctx.action_class(scope)
                if (
                    re.match("ctx", decorator.func.value.id)
                    and decorator.func.attr == "action_class"
                ):
                    return ContextActionClass(ast=cls, scope=decorator.args[0].value)
            except AttributeError:
                pass
        return None

    def __init__(self, baseurl: str):
        self.baseurl: str = baseurl
        self.path: Optional[Path] = None
        self.action_class_type: Optional[ActionClassInfo] = None
        self.action_defs: list[ActionInfo] = []
        self.action_uses: list[str] = []
        self.list_uses: list[str] = []
        self.capture_uses: list[str] = []
        self.file_to_action_infos: dict[str, Sequence[ActionInfo]] = {}
        self.file_to_action_uses: dict[str, Sequence[str]] = {}
        self.action_name_to_define: dict[str, ActionInfo] = {}
        self.action_name_to_refines: dict[str, Sequence[ActionInfo]] = {}

    def files(self):
        return self.file_to_action_infos.items()

    def uses(self, file) -> Sequence[str]:
        return self.file_to_action_uses.get(file, ())

    def reset(self):
        self.path = None
        self.action_class_type: Optional[ActionClassInfo] = None
        self.action_defs.clear()
        self.action_uses.clear()
        self.list_uses.clear()
        self.capture_uses.clear()

    def refines(self, any: Any) -> Sequence[ActionInfo]:
        action_name: str = any if type(any) == str else str(any)
        return tuple(self.action_name_to_refines.get(action_name, []))

    def define(self, any: Any) -> Optional[ActionInfo]:
        action_name: str = any if type(any) == str else str(any)
        return self.action_name_to_define.get(action_name, None)

    @contextmanager
    def open(self, path: Path):
        self.path = path
        yield
        self.file_to_action_infos[str(path)] = tuple(self.action_defs)
        self.file_to_action_uses[str(path)] = tuple(self.action_uses)
        self.reset()

    def process_python(self, path: Path):
        with path.open("r") as f:
            tree = ast.parse(f.read(), filename=str(path))
        with self.open(path):
            self.visit(tree)

    # A dotted name is two or more identifier segments joined by "."; a
    # segment may contain underscores (e.g. ``user.file_manager_open``).
    #
    # An action use is a dotted name directly followed by "(". The argument
    # list is deliberately not consumed, so calls nested in the arguments of
    # another call are found as well.
    ACTION_NAME_PATTERN = re.compile(
        r"(?P<action_name>(?:[a-z_][A-Za-z0-9_]*\.)+[a-z_][A-Za-z0-9_]*)\("
    )
    # A list reference is a dotted name in braces; the named group holds the
    # name only, without the braces.
    LIST_NAME_PATTERN = re.compile(
        r"\{(?P<list_name>(?:[a-z_][A-Za-z0-9_]*\.)+[a-z_][A-Za-z0-9_]*)\}"
    )
    # A capture reference is a dotted name in angle brackets; the named group
    # holds the name only, without the brackets.
    CAPTURE_NAME_PATTERN = re.compile(
        r"<(?P<capture_name>(?:[a-z_][A-Za-z0-9_]*\.)+[a-z_][A-Za-z0-9_]*)>"
    )

    def process_talon(self, path: Path):
        with self.open(path):
            with path.open("r") as f:
                talon_script = f.read()
            for match in DependencyGraph.ACTION_NAME_PATTERN.finditer(talon_script):
                self.action_uses.append(match.group('action_name'))
            for match in DependencyGraph.LIST_NAME_PATTERN.finditer(talon_script):
                self.list_uses.append(match.group('list_name'))
            for match in DependencyGraph.CAPTURE_NAME_PATTERN.finditer(talon_script):
                self.capture_uses.append(match.group('capture_name'))

    def process(self, *paths: Path):
        for path in paths:
            if path.match("**.py"):
                self.process_python(path)
            if path.match("**.talon"):
                self.process_talon(path)

    def visit_Call(self, call: ast.Call):
        try:
            telescope = DependencyGraph.telescope(call.func)
            if telescope[0] == "actions":
                action_name = ".".join(telescope[1:])
                self.action_uses.append(action_name)
        except ValueError:
            pass

    def visit_ClassDef(self, cls: ast.ClassDef):
        self.action_class_type = DependencyGraph.guess_action_class_info(cls)
        self.generic_visit(cls)
        self.action_class_type = None

    def visit_FunctionDef(self, func: ast.FunctionDef):
        if self.action_class_type:
            action = ActionInfo(
                name=func.name,
                action_class_info=self.action_class_type,
                ast=func,
                path=self.path,
                baseurl=self.baseurl,
            )
            action_name = str(action)
            self.action_defs.append(action)
            if action.is_override():
                if not action_name in self.action_name_to_refines:
                    self.action_name_to_refines[action_name] = []
                self.action_name_to_refines[action_name].append(action)
            else:
                self.action_name_to_define[action_name] = action
        self.generic_visit(func)

    def usage_graph(self, unconnected_nodes: bool = True) -> Graph:
        nodes = []
        edges = []
        for file, _ in self.files():
            nodes.append(Node(name=DependencyGraph.anchor(file), text=file))
            for action_name in self.uses(file):
                action_define_info = self.define(action_name)
                if action_define_info:
                    head = DependencyGraph.anchor(file)
                    tail = DependencyGraph.anchor(action_define_info.path)
                    edges.append(Edge(head, tail))
        nodes = set(nodes)
        edges = set(edges)
        graph = Graph(nodes=set(nodes), edges=set(edges))
        return graph if unconnected_nodes else graph.without_unconnected_nodes()

    def context_graph(self, unconnected_nodes: bool = True) -> Graph:
        nodes = []
        edges = []
        for file, action_infos in self.files():
            nodes.append(Node(name=DependencyGraph.anchor(file), text=file))
            for action_info in action_infos:
                if action_info.is_override():
                    action_define_info = self.define(action_info)
                    if action_define_info:
                        head = DependencyGraph.anchor(action_define_info.path)
                        tail = DependencyGraph.anchor(action_info.path)
                        edges.append(Edge(head, tail))
        nodes = set(nodes)
        edges = set(edges)
        graph = Graph(nodes=set(nodes), edges=set(edges))
        return graph if unconnected_nodes else graph.without_unconnected_nodes()

    def context_html(self):
        lines = []
        for file, action_infos in self.files():
            if action_infos:
                lines.append(f'<h1 id="{file}">File <tt>{file}</tt></h1>')
                lines.append(f"<ul>")
                for action_info in sorted(action_infos, key=str):
                    lines.append(f"<li>")
                    if not (action_info.is_override()):
                        lines.append(f'<a id="{action_info.anchor()}">')
                        lines.append(f"Defines <tt>{action_info}</tt>")
                        lines.append(f'(<a href="{action_info.url()}">Source</a>)')
                        lines.append(f"</a>")
                        lines.append(f"<br />")
                        lines.append(f'<p class="doc_string">')
                        doc_string = (
                            ast.get_docstring(action_info.ast)
                            .splitlines()[0]
                            .strip()
                            .rstrip(".")
                        )
                        lines.append(f"<i>{doc_string}.</i>")
                        lines.append(f"</p>")
                        if not action_info.is_override():
                            action_refine_infos = self.refines(str(action_info))
                            if action_refine_infos:
                                lines.append(f"<p>")
                                lines.append("Refined in:")
                                lines.append(f"<ul>")
                                for action_refine_info in action_refine_infos:
                                    lines.append(f"<li>")
                                    href = f"#{action_refine_info.anchor()}"
                                    lines.append(f'<a href="{href}">')
                                    lines.append(f"<tt>{action_refine_info.path}</tt>")
                                    lines.append(f"</a>")
                                    lines.append(f"</li>")
                                lines.append(f"</ul>")
                                lines.append(f"</p>")
                    else:
                        action_define_info = self.define(str(action_info))
                        lines.append(f'<a id="{action_info.anchor()}">')
                        name = f"<tt>{action_define_info}</tt>"
                        if action_define_info:
                            href = f"#{action_define_info.anchor()}"
                            lines.append(f'Refines <a href="{href}">{name}</a>')
                        else:
                            lines.append(f"Refines {name}")
                        lines.append(f'(<a href="{action_info.url()}">Source</a>)')
                        lines.append(f"</a>")
                    lines.append(f"</li>")
                lines.append(f"</ul>")
        return "\n".join(lines)


# Scan every Python and Talon file below the current directory, then write
# the action listing and the two dependency graphs next to it.
dg = DependencyGraph(baseurl="https://github.com/knausj85/knausj_talon/blob/main")
# glob() yields files in an OS-dependent order; sorting keeps the generated
# reports (and which duplicate definition ends up recorded) reproducible.
dg.process(
    *sorted(Path(".").glob("**/*.py")),
    *sorted(Path(".").glob("**/*.talon")),
)

# Write with an explicit encoding so non-ASCII docstrings or file names do not
# depend on, or fail under, the platform's locale encoding.
with open("actions.md", "w", encoding="utf-8") as f:
    f.write(dg.context_html())

with open("context_graph.dot", "w", encoding="utf-8") as f:
    f.write(dg.context_graph(unconnected_nodes=False).as_dot())

with open("usage_graph.dot", "w", encoding="utf-8") as f:
    f.write(dg.usage_graph(unconnected_nodes=False).as_dot())