Last active: December 30, 2022 10:10
-
-
Save linuskohl/024a487c2435ba1287e2d1c9d7406aea to your computer and use it in GitHub Desktop.
Helper functions to validate JSON Web Tokens (JWTs) for Flask RESTful APIs by fetching JSON Web Keys (JWKs) from the provider's discovery metadata. Used with Okta.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json
import logging
from functools import wraps
from typing import List, Union

import jwt
import requests
from flask import abort, g, request

from ..config import cache
from ..env import JWT_AUDIENCE, JWT_CLIENTID, JWT_ISSUER
# Well-known path appended to the issuer URL to fetch the authorization
# server metadata document (RFC 8414); its ``jwks_uri`` entry points at the
# signing keys used to verify tokens.
DISCOVERY_URL = "/.well-known/oauth-authorization-server"
def login_required(f):
    """
    Decorator that authenticates the request with a JWT bearer token.

    Reads the ``Authorization: Bearer <token>`` header, fetches the signing
    key matching the token's ``kid`` header, and verifies the token's
    signature, issuer and audience (RS256 only). On success it sets
    ``g.user`` (the ``sub`` claim) and ``g.user_id`` (the ``uid`` claim)
    and then calls the wrapped view.

    Aborts with HTTP 403 when the header is missing or malformed, or when
    key lookup or token verification fails for any reason.
    """
    @wraps(f)
    def wrap(*args, **kwargs):
        authorization = request.headers.get("authorization", None)
        if not authorization:
            abort(403)

        # Expect "Bearer <token>" (RFC 6750). partition() avoids the
        # IndexError / empty-token cases of split(' ')[1].
        scheme, _, token_raw = authorization.strip().partition(' ')
        token_raw = token_raw.strip()
        if scheme.lower() != 'bearer' or not token_raw:
            abort(403)

        try:
            key_id = jwt.get_unverified_header(token_raw)['kid']
            jwk = get_jwk(JWT_ISSUER, JWT_CLIENTID, key_id, cache)
            # Signature verification is on by default in PyJWT; the old
            # ``verify=`` keyword was removed in PyJWT 2 and is not passed.
            token = jwt.decode(token_raw,
                               jwk,
                               issuer=JWT_ISSUER,
                               audience=JWT_AUDIENCE,
                               algorithms=['RS256'])
            user = token['sub']
            user_id = token['uid']
        except Exception as exc:
            # Fail closed: any lookup/decoding/claim error rejects the
            # request, but record why so failures are diagnosable.
            logging.getLogger(__name__).info(
                "Rejecting request, JWT validation failed: %r", exc)
            abort(403)

        # Only publish identity once the token has fully validated.
        g.user = user
        g.user_id = user_id
        return f(*args, **kwargs)
    return wrap
def get_jwk(issuer: str, client_id: str, kid: str, cache=None):
    """
    Return the public verification key for key id ``kid``.

    The raw JWK is looked up in ``cache`` first. On a miss, the issuer's
    JWKS is fetched, every key in it is stored in the cache under its own
    ``kid``, and the key whose ``kid`` matches is used.

    Args:
        issuer(str): JWT Issuer
        client_id(str): JWT Client ID
        kid(str): Key ID taken from the token header
        cache(Cache): Optional cache exposing ``get(key)`` / ``set(key, value)``;
            ``None`` disables caching.
    Returns:
        The RSA public key built from the matching JWK by
        ``jwt.algorithms.RSAAlgorithm.from_jwk``.
    Raises:
        LookupError: no key with id ``kid`` was found.
    """
    key = cache.get(kid) if cache is not None else None

    if key is None:
        # fetch_jwks_for may return None when the response has no "keys".
        for candidate in fetch_jwks_for(issuer, client_id) or []:
            candidate_kid = candidate.get('kid')
            # Persist each key under its OWN id; storing under the requested
            # kid would overwrite it with whichever key came last.
            if cache is not None and candidate_kid is not None:
                cache.set(candidate_kid, candidate)
            if candidate_kid == kid:
                key = candidate

    if key is None:
        raise LookupError("No JWK found for kid %r" % (kid,))
    return jwt.algorithms.RSAAlgorithm.from_jwk(json.dumps(key))
def fetch_jwks_for(issuer: str, client_id: str, *,
                   timeout: float = 10.0) -> Union[None, List]:
    """
    Get the JWKs published by the issuer.

    Reads ``jwks_uri`` from the issuer's discovery metadata and downloads
    the key set from it.

    Args:
        issuer(str): JWT Issuer
        client_id(str): JWT Client ID
        timeout(float): Seconds to wait for the JWKS request before giving up.
    Returns:
        List: List of key objects, or None if the response has no "keys".
    Raises:
        ValueError: the metadata document does not contain a ``jwks_uri``.
        requests.HTTPError: the JWKS endpoint returned an error status.
    """
    oidp_metadata = fetch_metadata_for(issuer, client_id)
    jwks_uri = oidp_metadata.get('jwks_uri')
    if not jwks_uri:
        raise ValueError("Provider metadata for issuer %r has no jwks_uri" % (issuer,))
    # Without a timeout, requests can block forever on an unresponsive server.
    response = requests.get(jwks_uri, timeout=timeout)
    response.raise_for_status()
    return response.json().get('keys')
def fetch_metadata_for(issuer: str, client_id: str, *,
                       timeout: float = 10.0) -> dict:
    """
    Get the issuer's authorization server metadata document.

    Requests ``issuer + DISCOVERY_URL`` (the RFC 8414 well-known path) with
    the client id as a query parameter.

    Args:
        issuer(str): JWT Issuer; a trailing slash is tolerated.
        client_id(str): JWT Client ID
        timeout(float): Seconds to wait for the request before giving up.
    Returns:
        dict: Provider metadata (expected to include ``jwks_uri``).
    Raises:
        requests.HTTPError: the discovery endpoint returned an error status.
    """
    # Strip a trailing slash so we don't build "https://host//.well-known/...".
    url = issuer.rstrip('/') + DISCOVERY_URL
    params = {'client_id': client_id}
    response = requests.get(url, params=params, timeout=timeout)
    response.raise_for_status()
    return response.json()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.