"""Benchmark CLIP attention implementations (eager, SDPA, FlashAttention-2) across batch sizes."""

import itertools
import sys
import time

import requests
import numpy as np
import pandas as pd

from tqdm import tqdm
from PIL import Image

import torch
import transformers

print("\n## Environment:\n")
print("Python version:", sys.version)
print("Transformers version:", transformers.__version__)
print("Torch version:", torch.__version__)
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))

@torch.no_grad()
def get_model_iteration_time(model, inputs, device, min_iterations=100, min_benchmark_time=4, warm_up_steps=10):
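    """Time `model(**inputs)` forward passes on a CUDA device.

    Runs `warm_up_steps` untimed iterations first, then keeps timing until both
    `min_iterations` iterations and `min_benchmark_time` seconds have accumulated.
    Returns the median per-iteration time (seconds) and an approximate 95% CI.
    """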
    
    # autocast handles dtype casting between processor outputs and model weights
    # (e.g. float32 pixel values against float16 weights)
    with torch.autocast(device_type=device):

        # warm-up: early iterations include kernel selection/compilation overhead
        for _ in range(warm_up_steps):
            model(**inputs)
        
        timings = []
        iterations = 0
        benchmark_time = 0
        
        # make sure warm-up work has finished before starting the clock
        torch.cuda.synchronize()

        # keep timing until both the iteration count and the time budget are satisfied
        while benchmark_time < min_benchmark_time or iterations < min_iterations:
            for _ in range(10):

                start_time = time.perf_counter()

                _ = model(**inputs)
                # wait for queued GPU work to finish before reading the clock
                torch.cuda.synchronize()

                end_time = time.perf_counter()
                elapsed_time = end_time - start_time

                # store the time
                timings.append(elapsed_time)

                # update the benchmark time and iterations
                benchmark_time += elapsed_time
                iterations += 1

    # median is robust to outlier iterations; the CI is a normal approximation
    # (1.96 standard errors) over all recorded timings
    median_time = np.median(timings)
    ci = 1.96 * np.std(timings) / np.sqrt(len(timings))

    return median_time, ci


def prepare_inputs(processor, image_batch_size=None, text_batch_size=None):
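    """Build processor inputs, replicating a single image and/or caption to the requested batch sizes."""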
    
    # a sample COCO image (two cats on a couch)
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image = Image.open(requests.get(url, stream=True, timeout=60).raw)

    images = [image] * image_batch_size if image_batch_size is not None else None
    texts = ["a photo of 2 cats"] * text_batch_size if text_batch_size is not None else None

    # pad text to the tokenizer's max length so timings do not depend on caption length
    inputs = processor(text=texts, images=images, padding="max_length", return_tensors="pt")

    return inputs


def format_df(df) -> str:
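    """Format numeric columns, rename headers, and render the dataframe as a markdown table."""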
    
    df = df.copy()

    # format float columns; keep the integer batch-size columns as-is
    for column in df.columns:
        if column in ("image_batch_size", "text_batch_size"):
            continue
        if "CI" in column and "%" in column:
            df[column] = df[column].apply(lambda x: f"±{x:.1f}%")
        else:
            df[column] = df[column].apply(lambda x: f"{x:.3f}")
    
    # rename columns
    columns_mapping = {
        "image_batch_size": "Image batch size",
        "text_batch_size": "Num text labels",
        "Eager": "Eager (s/iter)",
        "FA2": "FA2 (s/iter)",
        "SDPA": "SDPA (s/iter)",
    }
    for column_name, new_column_name in columns_mapping.items():
        if column_name in df.columns:
            df = df.rename(columns={column_name: new_column_name})

    # render as a markdown table (pandas' to_markdown requires the tabulate package)
    markdown = df.to_markdown(index=False)
    
    return markdown


def benchmark(models_dict, processor, device, image_batch_sizes=None, text_batch_sizes=None, n_iterations=100):
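    """Benchmark every model in `models_dict` over the grid of batch-size combinations; returns a markdown table."""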
    
    image_batch_sizes = image_batch_sizes or [None]
    text_batch_sizes = text_batch_sizes or [None]

    # every (image batch size, text batch size) combination
    cases = list(itertools.product(image_batch_sizes, text_batch_sizes))

    results = []

    for image_batch_size, text_batch_size in tqdm(cases):

        inputs = prepare_inputs(processor, image_batch_size, text_batch_size).to(device)
        
        step_results = {}
        if image_batch_size is not None:
            step_results["image_batch_size"] = image_batch_size
        if text_batch_size is not None:
            step_results["text_batch_size"] = text_batch_size

        for attn_name, model in models_dict.items():

            median_time, confidence_interval = get_model_iteration_time(
                model, inputs, device, min_iterations=n_iterations, min_benchmark_time=4
            )
            step_results[attn_name] = median_time
            confidence_interval_percent = (confidence_interval / median_time) * 100
            step_results[f"{attn_name} CI, %"] = confidence_interval_percent

            # "Eager" is inserted first in load_models, so its timing is already available here
            if attn_name != "Eager":
                step_results[f"{attn_name} speedup"] = step_results["Eager"] / median_time
        
        results.append(step_results)
    
    df = pd.DataFrame(results)
    markdown = format_df(df)

    return markdown


def load_models(model_class, checkpoint, dtype, device):
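    """Load one copy of the model per supported attention implementation ("Eager" first, so speedups are relative to it)."""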
    models_dict = {
        "Eager": model_class.from_pretrained(
            checkpoint, attn_implementation="eager", torch_dtype=dtype, device_map=device
        ).eval()
    }
    # `_supports_flash_attn_2` / `_supports_sdpa` are private transformers class attributes
    if model_class._supports_flash_attn_2:
        models_dict["FA2"] = model_class.from_pretrained(
            checkpoint, attn_implementation="flash_attention_2", torch_dtype=dtype, device_map=device
        ).eval()
    if model_class._supports_sdpa:
        models_dict["SDPA"] = model_class.from_pretrained(
            checkpoint, attn_implementation="sdpa", torch_dtype=dtype, device_map=device
        ).eval()
    return models_dict


if __name__ == "__main__":
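
    # Example invocation (filename is illustrative; adjust to wherever this script lives):
    #   python clip_attn_benchmark.py --checkpoint openai/clip-vit-large-patch14 --dtype float16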

    import argparse
    from transformers import AutoProcessor, CLIPModel, CLIPTextModel, CLIPVisionModel

    parser = argparse.ArgumentParser()
    parser.add_argument("--n_iterations", type=int, default=100)
    parser.add_argument("--checkpoint", type=str, default="openai/clip-vit-large-patch14")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--dtype", type=str, default="float16")
    args = parser.parse_args()

    # toggle which model variants to benchmark
    benchmark_multimodal = True
    benchmark_text = True
    benchmark_vision = True

    dtype = {
        "float16": torch.float16,
        "float32": torch.float32,
        "bfloat16": torch.bfloat16,
    }[args.dtype]

    processor = AutoProcessor.from_pretrained(args.checkpoint)

    print("\n## Benchmark results\n")

    # ---------------------------
    # Multi-modal model
    # ---------------------------

    if benchmark_multimodal:
        models_dict = load_models(CLIPModel, args.checkpoint, dtype, args.device)
        result = benchmark(
            models_dict,
            processor, 
            image_batch_sizes=[1, 4, 16, 32],
            text_batch_sizes=[4, 16, 32, 64],
            device=args.device,
            n_iterations=args.n_iterations,
        )
        print(f"\n### {CLIPModel.__name__}\n")
        print(result)
        print()

    # ---------------------------
    # Text model
    # ---------------------------

    if benchmark_text:
        models_dict = load_models(CLIPTextModel, args.checkpoint, dtype, args.device)
        result = benchmark(
            models_dict,
            processor, 
            text_batch_sizes=[4, 16, 32, 64, 128],
            device=args.device,
            n_iterations=args.n_iterations,
        )

        print(f"\n### {CLIPTextModel.__name__}\n")
        print(result)
        print()

    # ---------------------------
    # Vision model
    # ---------------------------
    if benchmark_vision:
        models_dict = load_models(CLIPVisionModel, args.checkpoint, dtype, args.device)
        result = benchmark(
            models_dict,
            processor, 
            image_batch_sizes=[1, 4, 16, 32],
            device=args.device,
            n_iterations=args.n_iterations,
        )

        print(f"\n### {CLIPVisionModel.__name__}\n")
        print(result)
        print()