Created
February 13, 2020 19:56
-
-
Save kevlarr/937b12b02295b631010f9401b8c2160c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
See: | |
https://aws.amazon.com/premiumsupport/knowledge-center/decode-verify-cognito-json-token/ | |
https://medium.com/datadriveninvestor/jwt-authentication-with-fastapi-and-aws-cognito-1333f7f2729e | |
""" | |
from datetime import datetime | |
import logging | |
from os import environ | |
import re | |
from typing import Any, Dict, List, Optional | |
from fastapi import HTTPException | |
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials | |
from jose import jwt, jwk, JWTError | |
from jose.utils import base64url_decode | |
from mypy_extensions import TypedDict | |
from pydantic import BaseModel | |
import requests | |
from starlette.requests import Request | |
from starlette.status import HTTP_401_UNAUTHORIZED | |
LOG = logging.getLogger(__name__)

# Cognito user-pool location, read from the environment once at import time.
# ``environ.get`` yields None for either variable when it is unset.
AWS_REGION, USER_POOL_ID = environ.get("AWS_REGION"), environ.get("USER_POOL_ID")

# Canonical 401 response for credentials that fail validation; the
# ``WWW-Authenticate`` header tells the client which auth scheme is expected.
UNAUTHORIZED = HTTPException(
    headers={"WWW-Authenticate": "Bearer"},
    detail="Could not validate credentials",
    status_code=HTTP_401_UNAUTHORIZED,
)
# Alias for the ``kid`` (key id) value used to pick a verification key.
KeyId = str


# JOSE's ``jwk.construct`` consumes the key as a plain mapping, so this stays a
# TypedDict (an ordinary dict at runtime) rather than becoming a pydantic model.
class JWKey(TypedDict):
    """One public key entry from the user pool's JWKS document."""

    alg: str  # Algorithm
    e: str
    kid: KeyId
    kty: str
    n: str
    use: str
class JWKeys(BaseModel):
    """JWKS document served at the user pool's ``.well-known/jwks.json`` URL.

    Built with ``JWKeys.parse_obj`` from the decoded JSON response.
    """

    keys: List[JWKey]  # one entry per public signing key; looked up by ``kid``
# Seconds to wait on Cognito's JWKS endpoint before giving up.
JWKS_REQUEST_TIMEOUT = 10

# Fetch the user pool's public signing keys (JWKS) once, at import time.
# NOTE(review): if AWS_REGION / USER_POOL_ID are unset, the URL contains "None"
# and this request will almost certainly fail during import.
_jwks_response = requests.get(
    f"https://cognito-idp.{AWS_REGION}.amazonaws.com/"
    f"{USER_POOL_ID}/.well-known/jwks.json",
    # Without a timeout, requests can block forever and hang application startup.
    timeout=JWKS_REQUEST_TIMEOUT,
)
# Fail loudly on an HTTP error instead of trying to parse an error body as a JWKS.
_jwks_response.raise_for_status()
__JWKS: JWKeys = JWKeys.parse_obj(_jwks_response.json())

# kid -> public key; used to select the key named in a token's header.
KEY_MAP: Dict[KeyId, JWKey] = {key["kid"]: key for key in __JWKS.keys}
# Claims carried in a Cognito ID token payload. The functional TypedDict form is
# required because keys such as "cognito:groups" are not valid identifiers, and
# the name string must match the variable name for type checkers.
# NOTE(review): not every token necessarily carries every claim (e.g. a user in
# no group presumably has no "cognito:groups") — confirm before relying on it.
JWtClaims = TypedDict("JWtClaims", {
    "aud": str,
    "auth_time": int,  # NumericDate: seconds since the epoch
    "cognito:groups": List[str],  # a list; callers index it
    "cognito:username": str,
    "email": str,
    "email_verified": bool,
    "event_id": str,
    "exp": int,  # NumericDate; compared against the current timestamp
    "iat": int,  # NumericDate
    "iss": str,
    "sub": str,
    "token_use": str,
})
class JWTCredentials(BaseModel):
    """Pieces of a bearer token, extracted WITHOUT any signature verification.

    ``message`` and ``signature`` are the two halves of the token split at its
    last "." — for a compact JWS that is the signing input ("header.payload")
    and the base64url-encoded signature that must verify over it.
    """

    jwt_token: str  # the complete token string as received
    header: Dict[str, str]  # unverified JOSE header (from ``jwt.get_unverified_header``)
    claims: Dict[str, Any]  # unverified payload claims (from ``jwt.get_unverified_claims``)
    signature: str  # text after the last "."; still base64url-encoded
    message: str  # text before the last "."; the bytes the signature covers
class JWTBearer(HTTPBearer):
    """Custom HTTPBearer that can be used as a FastAPI dependency.

    A request is accepted only if its bearer token is an unexpired JWT whose
    signature verifies against one of the user pool keys in ``KEY_MAP``.
    """

    @staticmethod
    def _unauthorized() -> HTTPException:
        """Build a fresh 401 identical in content to the module-level ``UNAUTHORIZED``.

        Re-raising one shared exception instance makes its ``__traceback__``
        grow with every raise and shares that mutable state across concurrent
        requests, so each rejection gets its own instance.
        """
        return HTTPException(
            status_code=UNAUTHORIZED.status_code,
            detail=UNAUTHORIZED.detail,
            headers=UNAUTHORIZED.headers,
        )

    async def __call__(self, req: Request) -> Dict:
        """Validate the request's bearer token and return the caller's identity.

        Args:
            req: the incoming request.

        Returns:
            ``{"username": <"cognito:username" claim>,
               "groups": <last entry of "cognito:groups", or None if absent/empty>}``

        Raises:
            HTTPException: 401 when the credentials are missing, not a Bearer
                token, malformed, expired, signed by an unknown key, badly
                signed, or lack a "cognito:username" claim. The parent
                ``HTTPBearer.__call__`` runs first and may raise its own error.
        """
        http_creds: Optional[HTTPAuthorizationCredentials] = await super().__call__(req)
        # Auth schemes are case-insensitive (RFC 7235 §2.1).
        if not http_creds or http_creds.scheme.lower() != "bearer":
            raise self._unauthorized()

        token = http_creds.credentials
        try:
            message, signature = token.rsplit(".", 1)
            jwt_creds = JWTCredentials(
                jwt_token=token,
                header=jwt.get_unverified_header(token),
                claims=jwt.get_unverified_claims(token),
                signature=signature,
                message=message,
            )
        # ValueError covers a token with no "." (unpacking) and pydantic
        # validation failures; JWTError covers undecodable header/claims.
        except (JWTError, ValueError):
            raise self._unauthorized() from None

        # Explicit check, not ``assert``: asserts are stripped under
        # ``python -O``, which would silently skip token verification.
        if not self.token_valid(jwt_creds):
            raise self._unauthorized()

        username = jwt_creds.claims.get("cognito:username")
        if username is None:
            raise self._unauthorized()
        # A user in no group has no usable groups claim; report None rather
        # than failing the request with a KeyError/IndexError.
        groups = jwt_creds.claims.get("cognito:groups")
        return {
            "username": username,
            "groups": groups[-1] if groups else None,
        }

    def token_valid(self, creds: JWTCredentials) -> bool:
        """Check the token's expiry and signature.

        Returns True only when ``exp`` is later than now and the signature over
        ``creds.message`` verifies with the ``KEY_MAP`` key named by the
        header's ``kid``. Issuer and audience claims are not checked here.
        """
        try:
            expires_at = creds.claims["exp"]
            if not expires_at > datetime.now().timestamp():
                return False
            public_key = KEY_MAP[creds.header["kid"]]
        except (KeyError, TypeError):
            # Missing exp/kid, a non-numeric exp, or a kid not in the pool's JWKS.
            return False

        key = jwk.construct(public_key)
        signature = base64url_decode(creds.signature.encode())
        return bool(key.verify(creds.message.encode(), signature))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment