import datetime
from typing import Iterable

import os
import uuid

import jwt
import pytz
import time
import calendar

from Cryptodome.PublicKey import RSA


# Paths of the RSA key pair read by load_private_key()/load_public_key() and
# written by create_and_save_key(). They are relative paths, so they resolve
# against the current working directory.
PRIVATE_KEY_FILE = "jwt-test.key"
PUBLIC_KEY_FILE = "jwt-test.pub"

# Default JWS algorithm for generate()/validate(): RSA signature with SHA-512.
ALGORITHM = "RS512"


class Timer:
    """Context manager that measures how long its ``with`` body takes.

    After the block exits, ``start`` and ``end`` hold the raw counter
    readings and ``interval`` holds the elapsed time in seconds.
    """

    def __enter__(self):
        # time.clock() was removed in Python 3.8; perf_counter() is the
        # high-resolution clock intended for measuring short durations.
        self.start = time.perf_counter()
        return self

    def __exit__(self, *args):
        self.end = time.perf_counter()
        self.interval = self.end - self.start
        # Implicitly returns None, so exceptions raised in the body propagate.

    @property
    def interval_rounded(self):
        """Elapsed seconds rounded to 6 decimal places."""
        return round(self.interval, 6)


def now():
    """Return the current time as a timezone-aware datetime in UTC (pytz.utc)."""
    # datetime.utcnow() is deprecated since Python 3.12; asking now() for an
    # aware datetime directly yields the same instant with tzinfo=pytz.utc.
    return datetime.datetime.now(tz=pytz.utc)


def create_key():
    """Generate a fresh 2048-bit RSA key pair.

    Returns:
        A ``(private_key, public_key)`` tuple of key strings, exported with
        PyCryptodome's default ``export_key()`` format and decoded as UTF-8.
    """
    keypair = RSA.generate(2048)
    exported = (keypair.export_key(), keypair.publickey().export_key())
    return tuple(blob.decode("utf8") for blob in exported)


def create_and_save_key():
    """Generate a new RSA key pair and write it to disk.

    The private key is written to PRIVATE_KEY_FILE and the public key to
    PUBLIC_KEY_FILE; any existing file at either path is truncated first.
    """
    private_text, public_text = create_key()
    destinations = (
        (PRIVATE_KEY_FILE, private_text),
        (PUBLIC_KEY_FILE, public_text),
    )
    for path, contents in destinations:
        with open(path, "w+") as handle:
            handle.write(contents)


def load_public_key():
    """Read and return the public key text stored in PUBLIC_KEY_FILE.

    Raises:
        OSError: (e.g. FileNotFoundError) if the file cannot be opened.
    """
    # ``with`` closes the handle deterministically instead of leaking it
    # until garbage collection.
    with open(PUBLIC_KEY_FILE, "r", encoding="utf8") as fobj:
        return fobj.read()


def load_private_key():
    """Read and return the private key text stored in PRIVATE_KEY_FILE.

    Raises:
        OSError: (e.g. FileNotFoundError) if the file cannot be opened.
    """
    # ``with`` closes the handle deterministically instead of leaking it
    # until garbage collection.
    with open(PRIVATE_KEY_FILE, "r", encoding="utf8") as fobj:
        return fobj.read()


def key_file_exists():
    """Return True only if both the private and the public key files exist."""
    required = (PRIVATE_KEY_FILE, PUBLIC_KEY_FILE)
    return all(os.path.isfile(path) for path in required)


def generate(
    user_id: str = None,
    jwt_id: str = None,
    scope: str = "sherpany forsta",
    issued_at: datetime.datetime = None,
    expire_in: datetime.timedelta = None,
    private_key: str = None,
    algorithm: str = ALGORITHM,
):
    """Build and sign an access token with an RSA algorithm.

    Args:
        user_id: Value of the ``uid`` claim; a random UUID4 string if falsy.
        jwt_id: Value of the ``jti`` claim; a random UUID4 string if falsy.
        scope: Value of the ``scope`` claim.
        issued_at: Time used for the ``iat`` and ``nbf`` claims; defaults to
            ``now()``. A naive datetime is treated as UTC, because
            ``utctimetuple()`` does no conversion for naive values.
        expire_in: Lifetime added to ``issued_at`` to form ``exp``;
            defaults to 10 minutes.
        private_key: Private key text; read via ``load_private_key()`` if None.
        algorithm: JWS algorithm name; must start with ``"RS"``.

    Returns:
        The encoded token as a ``str``.

    Raises:
        ValueError: If ``algorithm`` is not an RS* algorithm.
    """
    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if not algorithm.startswith("RS"):
        raise ValueError(
            f"unsupported algorithm {algorithm!r}; expected an RS* algorithm"
        )
    if private_key is None:
        private_key = load_private_key()
    if issued_at is None:
        issued_at = now()
    if expire_in is None:
        expire_in = datetime.timedelta(minutes=10)
    expires_at = issued_at + expire_in
    # iat and nbf are the same instant: the token is usable from issue time.
    issued_ts = calendar.timegm(issued_at.utctimetuple())
    payload = {
        "jti": jwt_id or str(uuid.uuid4()),
        "typ": "access",
        "ver": "2.2",
        "uid": user_id or str(uuid.uuid4()),
        "iat": issued_ts,
        "nbf": issued_ts,
        "exp": calendar.timegm(expires_at.utctimetuple()),
        "scope": scope,
    }
    token = jwt.encode(payload=payload, algorithm=algorithm, key=private_key)
    # PyJWT < 2.0 returns bytes while PyJWT >= 2.0 already returns str;
    # calling .decode() unconditionally raises AttributeError on 2.x.
    if isinstance(token, bytes):
        token = token.decode("ascii")
    return token


def validate(raw, public_key: str = None, algorithms: Iterable[str] = (ALGORITHM,)):
    """Return whether ``raw`` decodes as a valid token under ``public_key``.

    Args:
        raw: The encoded token.
        public_key: Public key text; read via ``load_public_key()`` if None.
        algorithms: Algorithms ``jwt.decode`` may accept for this token.

    Returns:
        True if ``jwt.decode`` accepts the token, False if decoding fails
        for any reason. That covers a bad signature, a disallowed algorithm
        and PyJWT's default claim checks such as an expired ``exp``.
    """
    if public_key is None:
        public_key = load_public_key()
    try:
        # Signature verification is on by default in both PyJWT 1.x and 2.x;
        # the ``verify=`` keyword is not a supported parameter in PyJWT 2.x.
        jwt.decode(raw, public_key, algorithms=list(algorithms))
    except Exception:
        # Best-effort: any decode failure means "invalid". Unlike a bare
        # ``except:``, this does not swallow KeyboardInterrupt/SystemExit.
        return False
    return True


if __name__ == "__main__":
    # Demo: make sure a key pair exists, then generate and validate sample
    # tokens, reporting timings and token lengths.
    if not key_file_exists():
        create_and_save_key()

    private_key = load_private_key()
    public_key = load_public_key()

    print(private_key)
    print()
    print(public_key)

    print()
    print("expired JWT:")
    raw_expired = generate(
        issued_at=datetime.datetime(2018, 1, 1),
        expire_in=datetime.timedelta(minutes=10),
        private_key=private_key,
    )
    print(raw_expired)
    print(f"valid: {validate(raw_expired, public_key=public_key)}")

    # Each sample is validated with the same algorithm it was signed with.
    # Signing with the default (RS512) but validating against RS256 only
    # would always report "valid: False".
    samples = (("valid JWT:", "RS256"), ("valid large JWT:", "RS512"))
    for title, algorithm in samples:
        print()
        print(title)
        with Timer() as t:
            raw_valid = generate(
                issued_at=now() - datetime.timedelta(minutes=5),
                expire_in=datetime.timedelta(days=255),
                private_key=private_key,
                algorithm=algorithm,
            )
        print(raw_valid)
        print(f"generation time: {t.interval_rounded}s")
        # Time only the validation itself, not the print that reports it.
        with Timer() as t:
            is_valid = validate(
                raw_valid, public_key=public_key, algorithms=[algorithm]
            )
        print(f"valid: {is_valid}")
        print(f"validation time: {t.interval_rounded}s")
        print(f"length: {len(raw_valid)}")