import sys
from enum import Enum
from abc import ABC


class AST(ABC):
    """Abstract base for syntax-tree nodes.

    NOTE(review): none of the classes in this file subclass it yet.
    """


class KnownSymbol(Enum):
    """Reserved non-operator tokens; each member's value is its source spelling."""

    # Delimiters of a parenthesised form.
    LPAREN = "("
    RPAREN = ")"
    # Special-form keywords.
    IF = "if"
    DEFINE = "define"


class Operator(Enum):
    """Binary operators; each member's value is its source spelling."""

    # Arithmetic.
    ADD = "+"
    SUB = "-"
    MULT = "*"
    DIV = "/"
    # Comparison.
    LESS = "<"
    GREATER = ">"
    LEQ = "<="
    GEQ = ">="
    EQ = "=="
    NEQ = "!="


class Terminal:
    """Base class for leaf tokens.

    Subclasses are expected to set ``self.spelling`` in ``__init__``.
    """

    def __str__(self):
        value = self.spelling
        return str(value)

    def __repr__(self):
        kind = type(self)
        return "({}, {})".format(kind, self.spelling)


class Number(Terminal):
    """A numeric literal token: an ``int`` when possible, otherwise a ``float``."""

    spelling: int | float

    def __init__(self, spelling: str):
        """Parse *spelling* as an ``int``, falling back to ``float``.

        Raises:
            ValueError: if *spelling* is neither an int nor a float literal.
                ``Atom`` catches this to fall back to ``Symbol``.
        """
        # NOTE(review): float() also accepts spellings such as "inf", "nan" and
        # "1e3", so tokens like `inf` become numbers rather than symbols.
        try:
            self.spelling = int(spelling)
        except ValueError:
            self.spelling = float(spelling)


class Symbol(Terminal):
    """A non-numeric token.

    A token spelled like a ``KnownSymbol`` or ``Operator`` value is stored as
    that enum member. Any other token keeps its raw string spelling.
    """

    spelling: KnownSymbol | Operator | str

    # Reverse lookup from source spelling to enum member. It covers every
    # KnownSymbol and every Operator, including the comparison operators.
    # Built from the enums themselves, so a new member is recognised
    # automatically and no hand-written case can reference a wrong name.
    _RESERVED = {
        member.value: member
        for enum_cls in (KnownSymbol, Operator)
        for member in enum_cls
    }

    def __init__(self, spelling: str):
        self.spelling = self._RESERVED.get(spelling, spelling)


class Atom:
    """Leaf of a parsed expression, wrapping either a ``Number`` or a ``Symbol``."""

    atom: Symbol | Number

    def __init__(self, spelling: str):
        """Wrap *spelling* as a ``Number`` if it parses as one, else as a ``Symbol``."""
        try:
            self.atom = Number(spelling)
        except ValueError:
            # Number raises ValueError for text that is not a numeric literal.
            self.atom = Symbol(spelling)

    def __str__(self):
        return str(self.atom)

    def __repr__(self):
        # The wrapped terminal's repr is already "(<class>, <spelling>)".
        return repr(self.atom)


Expression = list | Atom
# Expression grammar (produced by Parser.parse, consumed by Interpreter.eval).
# A parenthesised form is a Python list whose elements are Expressions.
#   variable reference: Symbol (an Atom wrapping a Symbol)
#   constant literal:   Number (an Atom wrapping a Number)
#   conditional:        (if test consequence alternative), each part an Expression
#   definition:         (define Symbol Expression)
#   operator call:      (Operator lhs rhs), e.g. (+ 1 2)
#   procedure call:     (Symbol Expression...), zero or more arguments; covers
#                       builtin procedures (list, first) and, eventually,
#                       user-defined procedures


class Parser:
    """Recursive-descent parser that turns a token list into an ``Expression``."""

    def parse(self, tokens: list[str]) -> Expression:
        """Consume and return one expression from the front of *tokens*.

        *tokens* is mutated: the tokens of the parsed expression are popped
        off, and any tokens after it are left in place for the caller.
        A parenthesised form becomes a Python list of sub-expressions. Any
        other token becomes an ``Atom``.

        Raises:
            SyntaxError: on running out of tokens (including an unclosed
                ``(``) or on an unmatched ``)``.
        """
        if not tokens:
            raise SyntaxError("Unexpected EOF")
        token = tokens.pop(0)
        if token == KnownSymbol.LPAREN.value:  # nested Expression
            nested_expression = []
            # Check for exhaustion before peeking at tokens[0], so an unclosed
            # "(" raises SyntaxError instead of IndexError.
            while tokens and tokens[0] != KnownSymbol.RPAREN.value:
                nested_expression.append(self.parse(tokens))
            if not tokens:
                raise SyntaxError("Unexpected EOF")
            tokens.pop(0)  # pop matching RPAREN
            return nested_expression
        if token == KnownSymbol.RPAREN.value:
            raise SyntaxError("Unexpected )")
        return Atom(token)


class Interpreter:
    @classmethod
    def eval(cls, x: Expression):
        if isinstance(x, Atom):
            return x.atom.spelling
        if isinstance(x, list):
            if x[0].atom.spelling in [o for o in Operator]:
                (operator, lhs, rhs) = x
                match operator.spelling:
                    case Operator.ADD:
                        return cls.eval(lhs) + cls.eval(rhs)
                    case _:
                        print("Not supported yet!")
            if x[0].atom.spelling in [KnownSymbol.IF, KnownSymbol.DEFINE]:
                # Special expressions
                match x[0].spelling:
                    case KnownSymbol.IF:
                        print("Not supported yet!")
                        (_, test, consequence, alt) = x
            else:
                procedure = x[0].spelling
                args = [cls.eval(arg) for arg in x[1:]]
                match procedure:
                    # could generalize procedure calls with Env
                    # generalized procedure calls unpack args, naively leaving them as lists
                    case "list":
                        return args
                    case "first":
                        return list(*args)[0]


class Scanner:
    """Split Lisp source text into tokens."""

    # Surround each parenthesis with spaces so str.split() isolates it.
    _PAREN_PADDING = str.maketrans({"(": " ( ", ")": " ) "})

    @classmethod
    def scan(cls, lisp: str):
        """Return *lisp* as whitespace-separated tokens, with every paren as its own token."""
        padded = lisp.translate(cls._PAREN_PADDING)
        return padded.split()


def main():
    """Parse the Lisp source in ``sys.argv[1]`` and print the resulting parse tree.

    Exits with a usage message if no source argument is given. Raises
    ``SyntaxError`` for malformed input, including tokens left over after
    the first complete expression.
    """
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} '<lisp expression>'")
    lisp = sys.argv[1]
    tokens = Scanner.scan(lisp)
    parsed = Parser().parse(tokens)
    # parse() pops the tokens it consumes. Anything left is trailing input
    # (e.g. "1 2" or an extra ")") that would otherwise be silently ignored.
    if tokens:
        raise SyntaxError(f"Unexpected trailing tokens: {' '.join(tokens)}")
    print(parsed)


if __name__ == "__main__":
    main()