import inspect
import logging
import os
import re
import tokenize
import traceback

from colorama import Fore, Back, Style, init as colorama_init

from rich.console import Console
from rich.syntax import Syntax

# Let colorama prepare stdout/stderr so the ANSI colour codes emitted via
# Fore/Style further down are rendered by the terminal.
colorama_init()

# Absolute path of the directory two levels above this module's file
# (presumably the repository root -- confirm against the project layout).
PROJECT_ROOT = os.path.dirname(
  os.path.dirname(os.path.abspath(__file__))
)

import io
from rich.console import Console
from rich.syntax import Syntax
from rich.theme import Theme

# Console theme used by code_highlight(). Built once at import time instead of
# on every call: get_code_context() highlights source one line at a time, so
# this function runs many times per logged exception.
_CODE_HIGHLIGHT_THEME = Theme({
    "": "white on default",  # Default text
    "keyword": "bold cyan on default",
    "operator": "bold magenta on default",
    "string": "green on default",
    "comment": "dim white on default",
    "number": "bold yellow on default",
    "function": "bold blue on default"
})


def code_highlight(code, language="python", theme="monokai"):
  """Return ``code`` syntax-highlighted as a string containing ANSI escapes.

  Args:
    code: Source text to highlight.
    language: Lexer name passed to ``rich.syntax.Syntax``.
    theme: Currently ignored -- the Syntax theme is fixed to ``"ansi_dark"``.
      The parameter is kept so existing call sites keep working.

  Returns:
    The text rendered by a 100-column rich console, without a trailing newline
    added by ``print``.
  """
  output = io.StringIO()

  # force_terminal=True makes rich emit colour codes even though the target is
  # a StringIO rather than a real terminal.
  with Console(file=output, force_terminal=True, width=100,
               theme=_CODE_HIGHLIGHT_THEME) as console:
    # NOTE(review): ``theme`` is deliberately not passed through here;
    # "ansi_dark" is hard-coded. Confirm before honouring the argument, since
    # its default ("monokai") would change the rendered colours.
    console.print(Syntax(code, language, theme="ansi_dark"), end="")

  # The StringIO is not closed explicitly; it is simply garbage-collected.
  return output.getvalue()


class AdvancedTraceFormatter(logging.Formatter):
  """Logging formatter that renders exceptions as a compact, colourised trace.

  Only traceback entries whose filename contains ``path_filter`` are listed,
  innermost call first. The first ``context_frames`` of those entries also get
  a syntax-highlighted source excerpt. Under the failing line, the excerpt
  shows the values of local variables whose names appear on that line. The
  exception block is placed before the formatted log message.

  Args:
    fmt, datefmt, style: Forwarded to ``logging.Formatter``.
    project_root: Directory that displayed filenames are made relative to
      (``''`` resolves to the current working directory).
    path_filter: Substring a frame's filename must contain to be listed.
    context_frames: How many of the listed frames get a code excerpt.
    context_lines: Source lines shown before and after each failing line.
  """

  def __init__(self, fmt=None, datefmt=None, style='%', project_root='',
               *, path_filter='/app/', context_frames=2, context_lines=3):
    super().__init__(fmt, datefmt, style)
    self.project_root = project_root
    self.path_filter = path_filter
    self.context_frames = context_frames
    self.context_lines = context_lines

  def formatException(self, exc_info):
    """Return the colourised trace for ``exc_info`` (a ``sys.exc_info()`` triple).

    Falls back to the standard ``logging.Formatter`` rendering when the triple
    has no exception type, e.g. ``(None, None, None)`` from ``exc_info=True``
    outside an ``except`` block.
    """
    exc_type, exc_value, exc_traceback = exc_info
    if exc_type is None:
      return super().formatException(exc_info)

    # Pair each traceback entry with the live frame it came from, so the locals
    # shown later belong to exactly that frame. Matching by line number alone
    # can pick a frame from another file, or the wrong one under recursion.
    # extract_tb() may stop early (sys.tracebacklimit); walk_tb() yields frames
    # in the same outermost-first order, so zip() keeps them aligned.
    frames = [frame for frame, _ in traceback.walk_tb(exc_traceback)]
    entries = [
      (summary, frame)
      for summary, frame in zip(traceback.extract_tb(exc_traceback), frames)
      if self.path_filter in summary.filename
    ]
    # Show the innermost call (where the exception surfaced) first.
    entries.reverse()

    formatted_lines = []
    for i, (summary, frame) in enumerate(entries):
      filename, line_no, func_name = summary.filename, summary.lineno, summary.name
      try:
        # Shorten the path relative to the project root.
        rel_filename = os.path.relpath(filename, start=self.project_root)
      except ValueError:
        # e.g. path and root on different drives on Windows: keep the full path.
        rel_filename = filename
      formatted_lines.append(
        f"  {Fore.CYAN}{rel_filename}:{line_no}{Style.RESET_ALL}"
        f" in {Fore.GREEN}{func_name}{Style.RESET_ALL}"
      )

      # Only the innermost few frames get a source excerpt, to keep logs short.
      if i < self.context_frames:
        formatted_lines.extend(self.get_code_context(
          exc_traceback, filename, line_no, self.context_lines, frame=frame))
        formatted_lines.append(" ")  # Spacer line for readability.

    if formatted_lines:
      trace_output = "\n".join(formatted_lines)
    else:
      trace_output = f"No traceback available from the {self.path_filter} directory."
    error_message = f"\n{Fore.RED}  {exc_type.__name__}:\n  {exc_value}{Style.RESET_ALL}"

    return f"{error_message}\n\n{trace_output}\n"

  def get_code_context(self, exc_traceback, filename, line_no, context_lines, frame=None):
    """Return a one-element list holding a highlighted excerpt around ``line_no``.

    Args:
      exc_traceback: Traceback the line came from. It is searched for a
        matching frame only when ``frame`` is not given.
      filename: Source file to read.
      line_no: 1-based number of the failing line.
      context_lines: Lines shown before and after ``line_no``.
      frame: Frame whose locals are listed under the failing line (optional,
        for backward compatibility).

    Returns:
      ``[excerpt]`` on success. On any error it returns
      ``["  Could not retrieve code context: ..."]`` instead of raising,
      because this runs inside a log formatter.
    """
    try:
      # tokenize.open() honours PEP 263 encoding declarations (UTF-8 by
      # default) instead of plain open()'s locale-dependent encoding.
      with tokenize.open(filename) as file:
        lines = file.readlines()
      if not 1 <= line_no <= len(lines):
        raise IndexError(f"line {line_no} is outside {filename}")

      start = max(line_no - context_lines - 1, 0)
      end = min(line_no + context_lines, len(lines))
      error_line = lines[line_no - 1].strip()

      if frame is None:
        frame = self._find_frame(exc_traceback, filename, line_no)
      locals_str = self._format_locals(frame, error_line)

      out_lines = []
      for idx, line in enumerate(lines[start:end], start + 1):
        highlight = Fore.YELLOW if idx == line_no else Fore.RESET
        line_with_syntax = code_highlight(line.rstrip()).rstrip()
        out_lines.append(f"\n{highlight}{Style.DIM}{idx:>6}:{Style.RESET_ALL}{highlight}{line_with_syntax}")
        if idx == line_no and locals_str:
          out_lines.append(f"\n{locals_str}")

      return ["".join(out_lines)]
    except Exception as e:
      return [f"  Could not retrieve code context: {e}"]

  @staticmethod
  def _find_frame(exc_traceback, filename, line_no):
    """Return the outermost traceback frame executing ``filename:line_no``, or None."""
    for frame, tb_line_no in traceback.walk_tb(exc_traceback):
      if tb_line_no == line_no and frame.f_code.co_filename == filename:
        return frame
    return None

  @staticmethod
  def _format_locals(frame, error_line):
    """Return comment-style lines for locals named on ``error_line`` ('' if none)."""
    if frame is None:
      return ""
    # Compare whole identifiers, not substrings: a local named ``i`` should not
    # be listed just because the line contains ``if``.
    names_on_line = set(re.findall(r"\w+", error_line))
    rendered = []
    for name, value in frame.f_locals.items():
      if name == 'self' or name not in names_on_line:
        continue
      try:
        value_repr = repr(value)
      except Exception as exc:
        # A broken __repr__ must not hide the rest of the excerpt.
        value_repr = f"<repr() failed: {exc!r}>"
      rendered.append(f"               {Fore.MAGENTA}{Style.DIM}# {name}:{Style.NORMAL} {value_repr}{Style.RESET_ALL}")
    return "\n".join(rendered)

  def format(self, record):
    """Format ``record``, placing the colourised exception block before the message.

    ``logging.Formatter.format`` would itself append the output of
    ``formatException`` and cache it in ``record.exc_text``. Calling it and
    then prepending the trace printed the exception twice, and leaked the
    ANSI-coloured text to other handlers. The message, any pre-rendered
    ``exc_text`` and stack info are therefore assembled here without that
    step.
    """
    record.message = record.getMessage()
    if self.usesTime():
      record.asctime = self.formatTime(record, self.datefmt)
    result = self.formatMessage(record)
    if not record.exc_info and record.exc_text:
      # Only pre-rendered text is present (e.g. a record rebuilt from a
      # SocketHandler pickle): append it the way the base class does.
      if result[-1:] != "\n":
        result += "\n"
      result += record.exc_text
    if record.stack_info:
      if result[-1:] != "\n":
        result += "\n"
      result += self.formatStack(record.stack_info)
    if record.exc_info:
      result = self.formatException(record.exc_info) + "\n" + result
    return result