Skip to content

Instantly share code, notes, and snippets.

@jdalegonzalez
Last active January 31, 2023 13:03
Show Gist options
  • Save jdalegonzalez/0ca2272c015266e667924836364349c3 to your computer and use it in GitHub Desktop.
Save jdalegonzalez/0ca2272c015266e667924836364349c3 to your computer and use it in GitHub Desktop.
Uses the python-jose package to decode and validate an Amazon Cognito identity or access token.
#!/usr/bin/env python3
"""
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
This software consists of voluntary contributions made by many individuals
and is licensed under the MIT license. For more information, see
<http://www.doctrine-project.org>.
"""
""" Handles validation of a JWT web-token passed by the client
"""
import os
import sys
import argparse
import requests
import simplejson as json
from jose import jwt, JWTError
def pool_url(aws_region, aws_user_pool):
    """ Build the Amazon Cognito issuer URL for a user pool.

    Args:
        aws_region (string): The region the pool was created in.
        aws_user_pool (string): The ID of the user pool.

    Returns:
        string: the issuer URL, "https://cognito-idp.<region>.amazonaws.com/<pool>"
    """
    template = "https://cognito-idp.{}.amazonaws.com/{}"
    issuer = template.format(aws_region, aws_user_pool)
    return issuer
# def pool_url
def get_client_id_from_access_token(aws_region, aws_user_pool, token):
    """ Return the client ID carried by an access token.

    The token is decoded through get_claims() without an audience, and its
    'token_use' claim must be 'access'.

    Args:
        aws_region (string): The region the pool was created in.
        aws_user_pool (string): The ID of the user pool.
        token (string): The encoded access token.

    Returns:
        string: the 'client_id' claim, or None when the claim is absent

    Raises:
        ValueError: if the token's 'token_use' claim is not 'access'.
    """
    access_claims = get_claims(aws_region, aws_user_pool, token)
    token_use = access_claims.get('token_use')
    if token_use == 'access':
        return access_claims.get('client_id')
    raise ValueError('Not an access token')
# def get_client_id_from_access_token
def get_client_id_from_id_token(token):
    """ Return the audience (client id) claim of an id_token.

    The claims are read WITHOUT verifying the token, so the returned value
    is only as trustworthy as the token itself.

    Args:
        token (string): The encoded ID token.

    Returns:
        string: the 'aud' claim, or None when the claim is absent
    """
    unverified_claims = jwt.get_unverified_claims(token)
    audience = unverified_claims.get('aud')
    return audience
# def get_client_id_from_id_token
def get_user_email(aws_region, aws_user_pool, client_id, id_token):
    """ Return the email claim of an ID token.

    The audience used to decode the token is resolved in this order: the
    client_id argument, the AWS_CLIENT_ID environment variable, and finally
    the 'aud' claim read from the (unverified) token itself.
    NOTE(review): with that last fallback the audience is taken from the
    token being checked, so the audience comparison cannot fail.

    Args:
        aws_region (string): The region the pool was created in.
        aws_user_pool (string): The ID of the user pool.
        client_id (string): The expected audience, or None to look it up.
        id_token (string): The encoded ID token.

    Returns:
        string: the 'email' claim, or None when the claim is absent

    Raises:
        ValueError: if the token's 'token_use' claim is not 'id'.
    """
    audience = client_id
    if audience is None:
        audience = os.environ.get('AWS_CLIENT_ID')
        if audience is None:
            audience = get_client_id_from_id_token(id_token)

    id_claims = get_claims(aws_region, aws_user_pool, id_token, audience)
    if id_claims.get('token_use') == 'id':
        return id_claims.get('email')
    raise ValueError('Not an ID Token')
# def get_user_email
def get_claims(aws_region, aws_user_pool, token, audience=None):
    """ Given a token (and optionally an audience), validate and
    return the claims for the token

    The signing key is selected by the 'kid' in the token header from the
    pool's key set (see aws_key_dict), and the token is decoded with the
    pool URL as the expected issuer.

    Args:
        aws_region (string): The region the pool was created in.
        aws_user_pool (string): The ID of the user pool.
        token (string): The encoded token.
        audience (string): Optional expected audience; only passed to the
            decoder when it is not None.

    Returns:
        dict: the claims of the token

    Raises:
        JWTError: if the header has no 'kid', the 'kid' is not in the pool's
            key set, or the decoder rejects the token.
    """
    header = jwt.get_unverified_header(token)
    kid = header.get('kid')
    if kid is None:
        # Raise the error type callers already handle, not a bare KeyError.
        raise JWTError('Token header does not contain a key id (kid)')

    issuer = pool_url(aws_region, aws_user_pool)
    key = aws_key_dict(aws_region, aws_user_pool).get(kid)
    if key is None:
        # Never hand a missing key to the decoder: fail with a clear message.
        raise JWTError(
            'Token was signed with an unknown key id: {}'.format(kid)
        )

    kargs = {"issuer": issuer}
    if audience is not None:
        kargs["audience"] = audience
    # Pin the accepted algorithm to the one the key itself declares, so the
    # token header alone cannot choose how the signature is verified.
    algorithm = key.get('alg')
    if algorithm is not None:
        kargs["algorithms"] = [algorithm]

    return jwt.decode(token, key, **kargs)
# def get_claims
def aws_key_dict(aws_region, aws_user_pool):
    """ Fetches the AWS JWT validation file (if necessary) and then converts
    this file into a keyed dictionary that can be used to validate a web-token
    we've been passed

    The key file is cached as 'aws_<pool>.json' next to the running script
    (the directory of sys.argv[0]). It is downloaded only when that file
    does not exist, and an existing cache file is never refreshed.

    Args:
        aws_region (string): the region the pool was created in
        aws_user_pool (string): the ID for the user pool

    Returns:
        dict: the entries of the file's 'keys' list, keyed by their 'kid'

    Raises:
        requests.HTTPError: if the download answers with an error status.
        KeyError: if the document has no 'keys' list or an entry has no 'kid'.
    """
    filename = os.path.abspath(
        os.path.join(
            os.path.dirname(sys.argv[0]), 'aws_{}.json'.format(aws_user_pool)
        )
    )

    downloaded = None
    if os.path.isfile(filename):
        with open(filename) as json_data:
            aws_jwt = json.load(json_data)
    else:
        # If we can't find the file already, try to download it.
        response = requests.get(
            pool_url(aws_region, aws_user_pool) + '/.well-known/jwks.json',
            timeout=10  # don't hang forever on an unresponsive endpoint
        )
        response.raise_for_status()
        downloaded = response.text
        aws_jwt = json.loads(downloaded)

    # We want a dictionary keyed by the kid, not a list.
    result = {item['kid']: item for item in aws_jwt['keys']}

    if downloaded is not None:
        # Cache only after the document proved usable, so an error response
        # can never be stored and then reused on every later run.
        with open(filename, 'w') as json_data:
            json_data.write(downloaded)

    return result
# def aws_key_dict
def env_with_error(val, message, default=None):
    """ Looks a value up in the environment and raises an error when it is
    missing and no default was supplied. Used so that we can report a more
    helpful error message.

    Args:
        val (string): The name of the environment variable to fetch
        message (string): The message to raise
        default (string): An optional value used when the variable is unset.

    Returns:
        string: The value from the environment, otherwise the default

    Raises:
        KeyError: carrying message, when neither the environment variable
            nor a default is available.
    """
    found = os.environ.get(val, default)
    if found is not None:
        return found
    raise KeyError(message)
# def env_with_error
def run_test():
    """ Validates an identity token passed in as an argument.

    We can get the client_id from one of three places. If we're
    passed a client_id, we'll use it. Otherwise, if we're passed an
    access token, we can get the client_id from that. If we're given
    neither an access token nor a client_id as an argument, we'll
    look for something in the environment. If that isn't set,
    we'll use the client ID passed as the audience in the identity
    token itself.

    The region and pool come from -r / -p when given, otherwise from the
    AWS_REGION / AWS_USER_POOL environment variables.

    Prints "SUCCESS: <email>" when the token validates and "FAILED: <reason>"
    when either token is rejected.

    Raises:
        KeyError: if the region or pool is given neither as an argument nor
            in the environment.
    """
    # pylint: disable=missing-docstring,too-few-public-methods
    class Bcolors:
        HEADER = '\033[95m'
        OKBLUE = '\033[94m'
        OKGREEN = '\033[92m'
        WARNING = '\033[93m'
        FAIL = '\033[91m'
        ENDC = '\033[0m'
        BOLD = '\033[1m'
        UNDERLINE = '\033[4m'

    parser = argparse.ArgumentParser(
        description='Validates an AWS id token and prints the email address.')
    parser.add_argument(
        '-c', '--client_id',
        help='The application id to which the id_token applies',
        metavar='<client_id>',
        dest='client_id'
    )
    parser.add_argument(
        '-a', '--access_token',
        help='An access token returned from the authentication authority' +
        ' used to retrieve an client_id',
        metavar='<access_token>',
    )
    parser.add_argument(
        '-r', '--aws_region',
        help='The AWS Region that the pool is defined for. (ie. us-west-2)',
        metavar='<aws_region>',
        dest='aws_region'
    )
    parser.add_argument(
        '-p', '--aws_pool',
        help='The AWS Pool ID that token comes from.',
        metavar='<aws_pool_id>',
        dest='aws_pool'
    )
    parser.add_argument(
        'id_token',
        nargs=1,
        help='The ID token to be validated and have the email address printed',
        metavar='<id_token>'
    )
    args = parser.parse_args()

    # An explicit command-line value wins over the environment (the same
    # precedence -c already has over AWS_CLIENT_ID); the environment is only
    # consulted when the option was not given.
    aws_region = args.aws_region
    if aws_region is None:
        aws_region = env_with_error(
            "AWS_REGION", 'Missing AWS_REGION environment variable'
        )
    aws_pool = args.aws_pool
    if aws_pool is None:
        aws_pool = env_with_error(
            "AWS_USER_POOL", 'Missing AWS_USER_POOL environment variable'
        )

    try:
        # Decoding the access token can fail just like the ID token, so it
        # belongs inside the same error handling.
        client_id = None
        if args.client_id is not None:
            client_id = args.client_id
        elif args.access_token is not None:
            client_id = get_client_id_from_access_token(
                aws_region, aws_pool, args.access_token)
        email = get_user_email(
            aws_region, aws_pool, client_id, args.id_token[0])
    except (JWTError, ValueError) as error:
        # ValueError covers the "Not an ID Token" / "Not an access token"
        # rejections raised for a wrong token_use claim.
        print(
            Bcolors.BOLD + Bcolors.FAIL + "FAILED: " + Bcolors.ENDC +
            str(error))
    else:
        if email is None:
            # A valid token without an email claim must not crash the
            # string concatenation below.
            email = "(token is valid but has no email claim)"
        print(
            Bcolors.OKGREEN + Bcolors.BOLD + "SUCCESS: " + Bcolors.ENDC +
            email)
# def run_test
# Run the command-line check only when executed as a script, not on import.
if __name__ == '__main__':
    run_test()
@nk9
Copy link

nk9 commented Oct 4, 2022

It would be great if this code could have an open license attached, MIT or similar.

@jdalegonzalez
Copy link
Author

It would be great if this code could have an open license attached, MIT or similar.

Done @nk9

@nk9
Copy link

nk9 commented Oct 5, 2022

Much appreciated, thank you!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment