import os
import sys
import time
from collections.abc import Callable
from typing import TextIO


class Element:
    """One METHOD or FIELD row of a tiny mapping file, already split on tabs.

    The row layout read here is
    ``[kind, containing_class, signature, official_name, intermediary_name, ...]``;
    column 0 and anything past column 4 are ignored.
    """

    containing_class: str  # column 1: name of the class that declares the member
    signature: str  # column 2: the member's descriptor
    official_name: str  # column 3
    intermediary_name: str  # column 4

    def __init__(self, definition: list[str]):
        owner, descriptor, official, intermediary = (
            definition[1],
            definition[2],
            definition[3],
            definition[4],
        )
        self.containing_class = owner
        self.signature = descriptor
        self.official_name = official
        self.intermediary_name = intermediary

    def __str__(self) -> str:
        # e.g. "method_1, a: (I)V in abc"
        return "{}, {}: {} in {}".format(
            self.intermediary_name,
            self.official_name,
            self.signature,
            self.containing_class,
        )


def find_element(signature: str, official_name: str, candidates: list["Element"]) -> "Element":
    """Return the single candidate with the given official name and descriptor.

    :param signature: descriptor to compare with each candidate's ``signature``.
    :param official_name: name to compare with each candidate's ``official_name``.
    :param candidates: the members to search (typically those of one class).
    :return: the unique matching candidate.
    :raises AssertionError: if zero or several candidates match. The query and all
        candidates are printed first so the failing lookup can be diagnosed.
    """
    results = [
        element
        for element in candidates
        if element.signature == signature and element.official_name == official_name
    ]
    if len(results) != 1:
        print(signature, official_name, [str(element) for element in candidates])
        # distinguish "nothing matched" from "ambiguous" in the exception itself
        problem = "no" if not results else f"{len(results)} ambiguous"
        raise AssertionError(f"{problem} matches for {official_name} {signature}")
    return results[0]


def rename_class(clas: str) -> str:
    """Return the name ``classes`` maps ``clas`` to, or ``clas`` itself when unmapped."""
    mapped = classes.get(clas)
    return clas if mapped is None else mapped


# some documentation on signatures
# https://jenkins.liteloader.com/view/Other/job/Mixin/javadoc/index.html?org/spongepowered/asm/mixin/injection/InjectionPoint.Selector.html
def remap_signature(signature: str) -> str:
    is_method = signature[0] == "("
    method_part = []
    field_part = signature
    if is_method:
        sections = signature[1:].split(")")
        # method_part = sections[0].split(";")
        started = False
        method_part = list()
        name = list()
        for c in sections[0]:
            if c == "L" and not started:
                started = True
                name.append(c)
            elif c == ";" and started:
                started = False
                method_part.append("".join(name))
                name = list()
            elif started:
                name.append(c)
            else:
                method_part.append(c)
        field_part = sections[1]
        for i, method_name in enumerate(method_part):
            if method_name != "" and method_name[0] == "L":
                method_part[i] = "L" + rename_class(method_name[1:]) + ";"
    if ";" in field_part:
        field_part = "L" + rename_class(field_part[1:-1]) + ";"
    result = ""
    if is_method:
        result += "(" + "".join(method_part) + ")"
    return result + field_part


def remap_file(file: TextIO) -> str:
    output = ""

    # fifo queue for inner classes
    current_class: list[str] = list()
    indent = 0

    for line in file:
        current_line = line.strip().split(" ")

        # go back to outer class after finishing inner class
        if indent >= 1 and current_line[0] != "ARG" and indent != len(line) - len(line.lstrip("\t")):
            indent = len(line) - len(line.lstrip("\t"))
            while len(current_class) > indent:
                current_class.pop()

        match (current_line[0]):
            case "CLASS":
                indent += 1
                current_class.append(current_line[1])
                # already mapped
                if len(current_line) == 2:
                    output += line
                    continue
                parts = line.split(" ")
                parts[1] = classes.get(parts[1])
                output += " ".join(parts)
            case "METHOD" | "FIELD":
                if current_line[1] == "<init>":
                    parts = line.split(" ")
                    parts[2] = remap_signature(parts[2])
                    output += " ".join(parts)
                    continue

                if len(current_line) == 3:
                    output += line
                    continue

                element = find_element(
                    current_line[3],
                    current_line[1],
                    methods_by_class[current_class[indent - 1]] if current_line[0] == "METHOD" else
                    fields_by_class[current_class[indent - 1]]
                )
                parts = line.split(" ")
                parts[1] = element.intermediary_name
                parts[3] = remap_signature(element.signature)
                output += " ".join(parts) + "\n"
            # args and comments
            case _:
                output += line
    return output


classes: dict[str, str] = dict()
methods: set[Element] = set()
fields: set[Element] = set()

with open("intermediary.tiny") as intermediary:
    for line in intermediary.readlines():
        parts = line.strip().split("\t")
        match (parts[0]):
            case "CLASS":
                classes[parts[1]] = parts[2]
            case "METHOD":
                methods.add(Element(parts))
            case "FIELD":
                fields.add(Element(parts))

methods_by_class: dict[str, list[Element]] = {clazz: [] for clazz in classes}
fields_by_class: dict[str, list[Element]] = {clazz: [] for clazz in classes}

for method in methods:
    methods_by_class[method.containing_class].append(method)

for field in fields:
    fields_by_class[field.containing_class].append(field)


def main(src_root: str = "mappings", dst_root: str = "remapped") -> None:
    """Remap every file under ``src_root`` into the same directory layout under ``dst_root``.

    Prints the elapsed wall-clock seconds when finished. If ``remap_file`` raises
    AssertionError for a file, that file's path is printed and the process exits
    with status 1; no output file is written for it.

    :param src_root: directory tree of input mapping files.
    :param dst_root: directory that receives the remapped tree.
    """
    start = time.perf_counter()
    for subdir, _dirs, files in os.walk(src_root):
        # Mirror the sub-directory below src_root. os.path keeps this portable
        # (no hard-coded "\\") and only strips the leading root component.
        out_dir = os.path.normpath(os.path.join(dst_root, os.path.relpath(subdir, src_root)))
        for file in files:
            src_path = os.path.join(subdir, file)
            # assumes UTF-8 mapping files; read and write with the same encoding
            with open(src_path, "r", encoding="utf-8") as src:
                try:
                    remapped = remap_file(src)
                except AssertionError:
                    print(src_path)
                    sys.exit(1)
            # create/write only after a successful remap, so a failure leaves no empty file
            os.makedirs(out_dir, exist_ok=True)
            with open(os.path.join(out_dir, file), "w", encoding="utf-8") as dst:
                dst.write(remapped)
    end = time.perf_counter()
    print(end - start)


# Run the batch remap only when executed as a script, not when imported.
if __name__ == "__main__":
    main()