import sys
import time
import random
import asyncio
import argparse
from typing import Callable, Coroutine

import httpx
import aiohttp
import requests


# Key read from every JSON response body returned by the endpoint.
RESP_KEY = 'slept_for'

# Endpoint used when no URL is supplied on the command line.  It must stay a
# fixed string: deriving it from sys.argv[1] turned a lone option such as
# `-n5` into the "default URL".  The optional positional `url` argument of the
# argument parser already covers the explicit-URL case.
DEFAULT_URL = 'http://localhost:8080/sleep'

# Target of every request; main() rebinds it to the parsed `url` argument.
URL = DEFAULT_URL


class SingleFetch:
    """Send one POST to ``URL``, implemented once per HTTP library.

    Each method posts ``{'time': secs}`` and returns the value stored under
    ``RESP_KEY`` in the JSON reply.

    NOTE: create_parser() offers every non-dunder attribute of this class as
    a ``--mode`` choice, so helpers must not be added here.
    """

    @staticmethod
    async def httpx(client: httpx.AsyncClient, secs: float) -> float:
        """POST through an httpx client; raises on an HTTP error status."""
        reply = await client.post(URL, data={'time': secs})
        reply.raise_for_status()
        payload = reply.json()
        return payload[RESP_KEY]

    @staticmethod
    async def aiohttp(session: aiohttp.ClientSession, secs: float) -> float:
        """POST through an aiohttp session.

        Any status other than 200 raises ``Exception`` carrying the response
        body text.
        """
        async with session.post(URL, data={'time': secs}) as reply:
            if reply.status == 200:
                payload = await reply.json()
                return payload[RESP_KEY]
            raise Exception(await reply.text())

    @staticmethod
    async def requests(session: requests.Session, secs: float) -> float:
        """POST through a requests session; raises on an HTTP error status.

        The ``post`` call is synchronous and is not awaited, so this
        coroutine never yields to the event loop while the request runs.
        """
        # requests applies no timeout unless one is passed explicitly.
        deadline = secs + 3
        reply = session.post(URL, data={'time': secs}, timeout=deadline)
        reply.raise_for_status()
        payload = reply.json()
        return payload[RESP_KEY]


class ManyFetch:

    MIN = 1.3
    MAX = 5.7

    FloatsOrExceptions = tuple[float | Exception, ...]

    # First element is total "actual" time taken
    # Second element is a tuple of floats or exceptions
    # where each float is the time to be slept for in that
    # particular request. If an exception is raised, it is
    # stored in place of the float.
    ManyFetchReturnType = tuple[float, FloatsOrExceptions]

    def __init__(self, n: int = 100):
        self.n = n
        self.nums = tuple(self.randfloat() for _ in range(n))
        self.timeout = max(self.nums) + 3  # 3 seconds buffer

    @staticmethod
    async def perf_it(func: Callable[..., Coroutine[None, None, FloatsOrExceptions]]) -> ManyFetchReturnType:
        start = time.perf_counter()
        result = await func()
        return round(time.perf_counter() - start, 3), result

    @staticmethod
    def randfloat(a: float = MIN, b: float = MAX) -> float:
        return round(random.uniform(a, b), 3)

    async def _httpx(self) -> FloatsOrExceptions:
        async with httpx.AsyncClient(timeout=self.timeout) as client:
            tasks = [SingleFetch.httpx(client, num) for num in self.nums]
            return tuple(await asyncio.gather(*tasks, return_exceptions=True))

    async def _aiohttp(self) -> FloatsOrExceptions:
        async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(self.timeout)) as session:
            tasks = [SingleFetch.aiohttp(session, num) for num in self.nums]
            return tuple(await asyncio.gather(*tasks, return_exceptions=True))

    async def _requests(self) -> FloatsOrExceptions:
        with requests.Session() as session:
            tasks = [SingleFetch.requests(session, num) for num in self.nums]
            return tuple(await asyncio.gather(*tasks, return_exceptions=True))

    async def httpx(self) -> ManyFetchReturnType:
        return await self.perf_it(self._httpx)

    async def aiohttp(self) -> ManyFetchReturnType:
        return await self.perf_it(self._aiohttp)

    async def requests(self) -> ManyFetchReturnType:
        return await self.perf_it(self._requests)


def create_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script.

    The values accepted by ``--mode`` are the attribute names of
    ``SingleFetch`` that do not start with a double underscore, in the
    order ``dir`` returns them.
    """
    modes: tuple[str, ...] = tuple(
        name for name in dir(SingleFetch) if not name.startswith('__')
    )

    cli = argparse.ArgumentParser(description='Python asyncio fun')
    cli.add_argument('url', nargs='?', default=DEFAULT_URL, help='URL to send requests to')
    cli.add_argument('-m', '--mode', choices=modes, default='httpx', help='Library to use')
    cli.add_argument('-n', '--number', type=int, default=100, help='Number of requests to send')
    cli.add_argument('-s', '--seed', type=int, default=None, help='Seed for random number generation')
    return cli


def print_stats(results: 'ManyFetch.ManyFetchReturnType'):
    """Print the batch timing, one line per request, and a summary line.

    Args:
        results: ``(total_time, times)`` where ``times`` holds, per request,
            either the numeric value that request returned or the exception
            it raised.

    An entry counts as successful when it is a real number.  Integers are
    accepted as well as floats because a JSON body may encode a whole number
    of seconds without a decimal point.  Anything else is reported as an
    error.
    """
    # The annotation above is quoted so it is not evaluated at definition time.
    total_time, times = results
    total = len(times)
    successful = failed = 0

    print(f'Total time taken: {total_time}s')

    for idx, secs in enumerate(times, start=1):
        # bool is a subclass of int, so it is excluded explicitly.
        if isinstance(secs, (int, float)) and not isinstance(secs, bool):
            successful += 1
            print(f'[{idx}/{total}] {secs:.2f}s')
        else:
            failed += 1
            print(f'[{idx}/{total}] Error: {secs}')

    print(f'Successful: {successful} | Failed: {failed} | Total: {total}')

async def main() -> int:
    """Parse the command line, run the selected batch and print its stats.

    Returns:
        ``0`` once the statistics have been printed.  Invalid arguments end
        the process through ``parser.error`` instead of returning.
    """

    global URL

    parser = create_parser()
    args = parser.parse_args()

    # Seed before ManyFetch is built so its random durations are reproducible.
    if args.seed is not None:
        random.seed(args.seed)

    # The single-request helpers read the module-level URL at call time.
    URL = args.url

    fetcher = ManyFetch(n=args.number)
    func = getattr(fetcher, args.mode, None)

    if func is None:
        # The mode passed argparse's `choices` check but ManyFetch has no
        # method of that name.  parser.error prints the message and exits.
        # fmt: off
        parser.error(
            f'Invalid mode {args.mode} provided. '
            f'Add an async def {args.mode}(self) -> ManyFetchReturnType method to ManyFetch'
        )
        # fmt: on

    results: ManyFetch.ManyFetchReturnType = await func()
    print_stats(results)

    return 0


if __name__ == "__main__":
    # Run the async entry point to completion and exit with its return value.
    exit_code = asyncio.run(main())
    sys.exit(exit_code)