"""Generate and validate RSA-signed JWT access tokens (demo / benchmark script)."""

import calendar
import datetime
import os
import time
import uuid
from typing import Iterable, Optional

import jwt
import pytz
from Cryptodome.PublicKey import RSA

PRIVATE_KEY_FILE = "jwt-test.key"
PUBLIC_KEY_FILE = "jwt-test.pub"
ALGORITHM = "RS512"


class Timer:
    """Context manager that records the elapsed time of its body in seconds.

    After the ``with`` block exits, ``start``, ``end`` and ``interval`` are set.
    """

    def __enter__(self):
        # time.clock() was removed in Python 3.8; perf_counter() is the
        # documented high-resolution replacement for timing code.
        self.start = time.perf_counter()
        return self

    def __exit__(self, *args):
        self.end = time.perf_counter()
        self.interval = self.end - self.start

    @property
    def interval_rounded(self):
        """The measured interval rounded to 6 decimal places (microseconds)."""
        return round(self.interval, 6)


def now() -> datetime.datetime:
    """Return the current time as a timezone-aware UTC datetime."""
    return datetime.datetime.now(tz=pytz.utc)


def create_key():
    """Generate a fresh 2048-bit RSA key pair.

    Returns:
        ``(private_key, public_key)`` as PEM-encoded strings.
    """
    key = RSA.generate(2048)
    private_key = key.export_key().decode("utf8")
    public_key = key.publickey().export_key().decode("utf8")
    return private_key, public_key


def create_and_save_key():
    """Generate a key pair and write it to PRIVATE_KEY_FILE / PUBLIC_KEY_FILE."""
    private_key, public_key = create_key()
    # Create the private key readable by the owner only. The mode is applied
    # only when the file is created and is still subject to the umask.
    fd = os.open(PRIVATE_KEY_FILE, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, "w") as fobj:
        fobj.write(private_key)
    with open(PUBLIC_KEY_FILE, "w") as fobj:
        fobj.write(public_key)


def load_public_key() -> str:
    """Return the PEM public key stored in PUBLIC_KEY_FILE."""
    with open(PUBLIC_KEY_FILE, "r", encoding="utf8") as fobj:
        return fobj.read()


def load_private_key() -> str:
    """Return the PEM private key stored in PRIVATE_KEY_FILE."""
    with open(PRIVATE_KEY_FILE, "r", encoding="utf8") as fobj:
        return fobj.read()


def key_file_exists() -> bool:
    """Return True when both the private and the public key files exist."""
    return os.path.isfile(PRIVATE_KEY_FILE) and os.path.isfile(PUBLIC_KEY_FILE)


def generate(
    user_id: Optional[str] = None,
    jwt_id: Optional[str] = None,
    scope: str = "sherpany forsta",
    issued_at: Optional[datetime.datetime] = None,
    expire_in: Optional[datetime.timedelta] = None,
    private_key: Optional[str] = None,
    algorithm: str = ALGORITHM,
) -> str:
    """Create a signed JWT access token.

    Args:
        user_id: value of the ``uid`` claim; a random UUID4 when falsy.
        jwt_id: value of the ``jti`` claim; a random UUID4 when falsy.
        scope: value of the ``scope`` claim.
        issued_at: issue time used for ``iat`` and ``nbf``; defaults to now().
            A naive datetime is interpreted as UTC.
        expire_in: lifetime added to ``issued_at`` for ``exp``; 10 minutes
            by default.
        private_key: PEM private key; loaded from PRIVATE_KEY_FILE when None.
        algorithm: RSA signing algorithm (must start with ``"RS"``).

    Returns:
        The encoded token as an ASCII ``str``.

    Raises:
        ValueError: if ``algorithm`` is not an RS* algorithm.
    """
    # Explicit check instead of ``assert``, which is stripped under ``python -O``.
    if not algorithm.startswith("RS"):
        raise ValueError(f"only RS* algorithms are supported, got {algorithm!r}")
    if private_key is None:
        private_key = load_private_key()
    if issued_at is None:
        issued_at = now()
    if expire_in is None:
        expire_in = datetime.timedelta(minutes=10)
    expires_at = issued_at + expire_in

    # utctimetuple() + timegm() yields a POSIX timestamp, treating naive
    # datetimes as UTC.
    issued_ts = calendar.timegm(issued_at.utctimetuple())
    payload = {
        "jti": jwt_id or str(uuid.uuid4()),
        "typ": "access",
        "ver": "2.2",
        "uid": user_id or str(uuid.uuid4()),
        "iat": issued_ts,
        "nbf": issued_ts,
        "exp": calendar.timegm(expires_at.utctimetuple()),
        "scope": scope,
    }
    raw = jwt.encode(payload=payload, algorithm=algorithm, key=private_key)
    # PyJWT < 2.0 returns bytes, PyJWT >= 2.0 returns str; normalize to str.
    if isinstance(raw, bytes):
        raw = raw.decode("ascii")
    return raw


def validate(
    raw: str,
    public_key: Optional[str] = None,
    algorithms: Iterable[str] = (ALGORITHM,),
) -> bool:
    """Return True if ``raw`` is a correctly signed, currently valid JWT.

    Args:
        raw: the encoded token.
        public_key: PEM public key; loaded from PUBLIC_KEY_FILE when None.
        algorithms: algorithms the token is allowed to be signed with.

    Returns:
        False when PyJWT rejects the token (bad signature, disallowed
        algorithm, expired, not yet valid, malformed); True otherwise.
    """
    if public_key is None:
        public_key = load_public_key()
    try:
        # Signature verification is on by default in every PyJWT version;
        # the old ``verify=True`` kwarg is a deprecated no-op in PyJWT >= 2.
        jwt.decode(raw, public_key, algorithms=list(algorithms))
    except jwt.InvalidTokenError:
        return False
    return True


def _benchmark(title: str, algorithm: str, private_key: str, public_key: str) -> None:
    """Print a freshly generated token signed with ``algorithm`` and time it.

    Prints the token, its generation time, whether it validates, its
    validation time, and its length.
    """
    print(title)
    with Timer() as t:
        raw = generate(
            issued_at=now() - datetime.timedelta(minutes=5),
            expire_in=datetime.timedelta(days=255),
            private_key=private_key,
            algorithm=algorithm,
        )
    print(raw)
    print(f"generation time: {t.interval_rounded}s")
    # Time only the validation itself, not the print call.
    with Timer() as t:
        is_valid = validate(raw, public_key=public_key, algorithms=[algorithm])
    print(f"valid: {is_valid}")
    print(f"validation time: {t.interval_rounded}s")
    print(f"length: {len(raw)}")


def main() -> None:
    """Create keys if needed, then demo an expired token and two valid ones."""
    if not key_file_exists():
        create_and_save_key()
    private_key = load_private_key()
    public_key = load_public_key()
    print(private_key)
    print()
    print(public_key)
    print()

    print("expired JWT:")
    raw_expired = generate(
        issued_at=datetime.datetime(2018, 1, 1, tzinfo=pytz.utc),
        expire_in=datetime.timedelta(minutes=10),
        private_key=private_key,
    )
    print(raw_expired)
    print(f"valid: {validate(raw_expired, public_key=public_key)}")
    print()

    # Sign with the same algorithm we validate with; previously this section
    # signed with the default RS512 but validated against RS256 only, so it
    # always reported an invalid token.
    _benchmark("valid JWT:", "RS256", private_key, public_key)
    print()
    _benchmark("valid large JWT:", "RS512", private_key, public_key)


if __name__ == "__main__":
    main()