Skip to content

Instantly share code, notes, and snippets.

@clintonb
Last active May 2, 2022 16:26
Show Gist options
  • Save clintonb/6ee13e39ca6cc5c56c49 to your computer and use it in GitHub Desktop.
Save clintonb/6ee13e39ca6cc5c56c49 to your computer and use it in GitHub Desktop.
python-social-auth OpenID Connect Backend
"""
This file contains Django authentication backends. For more information visit
https://docs.djangoproject.com/en/dev/topics/auth/customizing/.
"""
from calendar import timegm
from datetime import datetime
from django.conf import settings
import jwt
from social.backends.oauth import BaseOAuth2
# pylint: disable=abstract-method
from social.exceptions import AuthTokenError
class EdXOAuth2(BaseOAuth2):  # pylint: disable=abstract-method
    """
    OAuth 2.0 backend for the edX authorization server.

    Both endpoint URLs are built from ``settings.SOCIAL_AUTH_EDX_OAUTH2_URL_ROOT``.
    """
    name = 'edx-oauth2'

    AUTHORIZATION_URL = '{}/authorize/'.format(settings.SOCIAL_AUTH_EDX_OAUTH2_URL_ROOT)
    ACCESS_TOKEN_URL = '{}/access_token/'.format(settings.SOCIAL_AUTH_EDX_OAUTH2_URL_ROOT)
    ACCESS_TOKEN_METHOD = 'POST'
    REDIRECT_STATE = False

    # The provider's ``username`` value identifies the user.
    ID_KEY = 'username'

    # (response key, stored key[, flag]) tuples copied into extra data.
    # NOTE(review): the third element presumably tells social-auth to skip
    # the entry when the value is empty -- confirm against the library.
    EXTRA_DATA = [
        ('username', 'id'),
        ('code', 'code'),
        ('expires_in', 'expires'),
        ('refresh_token', 'refresh_token', True),
    ]

    def get_user_details(self, response):
        """
        Return user details from edX account.

        Only ``username`` is read from ``response``; every other detail is
        returned as an empty string.
        """
        blank = ''
        return dict(
            username=response.get('username'),
            email=blank,
            fullname=blank,
            first_name=blank,
            last_name=blank,
        )
class OpenIdConnectAssociation(object):
    """
    Minimal association-shaped record used to persist a nonce.

    Instances are handed to the association storage, so only the attributes
    it expects are provided. ``handle`` carries the nonce and ``assoc_type``
    carries the OAuth ``state``. The remaining fields exist only to satisfy
    the storage and are not used by this module.
    """

    def __init__(self, handle, secret='', issued=0, lifetime=0, assoc_type=''):
        # Meaningful fields: the nonce and the state it belongs to.
        self.handle = handle
        self.assoc_type = assoc_type
        # Placeholder fields required by the association storage (unused here).
        self.secret = secret.encode()
        self.issued, self.lifetime = issued, lifetime
class OpenIdConnectAuth(BaseOAuth2):  # pylint: disable=abstract-method
    """
    Base OAuth 2.0 backend that adds OpenID Connect id_token validation.

    A random nonce is sent with both the authorization request and the
    access-token request. Each nonce is persisted through
    ``self.strategy.storage.association``. The id_token returned by the
    provider is then decoded, and its ``iss``, ``iat``, ``aud`` and ``nonce``
    claims are checked.

    Subclasses must set ``ID_TOKEN_ISSUER``, ``AUTHORIZATION_URL`` and
    ``ACCESS_TOKEN_URL``.
    """
    ID_TOKEN_ISSUER = None
    # Maximum accepted age, in seconds, of the id_token ``iat`` claim.
    ID_TOKEN_MAX_AGE = 600
    DEFAULT_SCOPE = ['openid']
    RESPONSE_TYPE = 'code'
    EXTRA_DATA = [
        ('access_token', 'access_token'),
        ('token_type', 'token_type'),
        ('expires_in', 'expires_in'),
        ('refresh_token', 'refresh_token'),
        ('id_token', 'id_token')
    ]

    def auth_params(self, state=None):
        """Return extra arguments needed on auth process, including a fresh nonce."""
        params = super(OpenIdConnectAuth, self).auth_params(state)
        params['nonce'] = self._get_and_store_nonce(self.AUTHORIZATION_URL, state)
        return params

    def auth_complete_params(self, state=None):
        """Return the access-token request parameters, including a fresh nonce."""
        params = super(OpenIdConnectAuth, self).auth_complete_params(state)
        # Add a nonce to the request to help counter CSRF. It is checked
        # against the id_token's ``nonce`` claim during validation.
        params['nonce'] = self._get_and_store_nonce(self.ACCESS_TOKEN_URL, state)
        return params

    def _get_and_store_nonce(self, url, state):
        """Generate a 64-character random nonce, store it under ``url`` and return it."""
        nonce = self.strategy.random_string(64)
        association = OpenIdConnectAssociation(nonce, assoc_type=state)
        self.strategy.storage.association.store(url, association)
        return nonce

    def _get_nonce(self, nonce):
        """
        Return the stored association whose handle is ``nonce``, or None.

        Only nonces stored under ``ACCESS_TOKEN_URL`` are searched.
        """
        server_url = self.ACCESS_TOKEN_URL
        try:
            return self.strategy.storage.association.get(server_url=server_url, handle=nonce)[0]
        except Exception:  # pylint: disable=broad-except
            # An empty result (IndexError) or a storage failure both mean
            # "unknown nonce". Unlike a bare except, this lets
            # SystemExit/KeyboardInterrupt propagate.
            return None

    def _remove_nonce(self, nonce_id):
        """Delete a consumed nonce. This is best effort: failures are ignored."""
        try:
            self.strategy.storage.association.remove([nonce_id])
        except Exception:  # pylint: disable=broad-except
            # Failing to delete the record must not abort an otherwise valid login.
            pass

    def _validate_and_return_id_token(self, id_token):
        """
        Decode ``id_token`` and verify its claims.

        Arguments:
            id_token (str): Encoded JWT taken from the access-token response.

        Returns:
            dict: The decoded id_token claims.

        Raises:
            AuthTokenError: If the token is missing, cannot be decoded or has
                expired, or fails the iss/iat/aud/nonce checks.
        """
        # Local import: this avoids depending on the module-level name
        # ``datetime``, which a later ``import datetime`` in this source
        # rebinds to the module.
        import time

        if not id_token:
            raise AuthTokenError(self, 'Missing id_token')

        client_id, _client_secret = self.get_key_and_secret()
        try:
            # Decode the JWT and raise an error if the secret is invalid or
            # the response has expired.
            # NOTE(review): no algorithm whitelist is passed; whether the
            # installed PyJWT restricts accepted algorithms is version-specific.
            decryption_key = self.setting('ID_TOKEN_DECRYPTION_KEY')
            claims = jwt.decode(id_token, decryption_key)
        except (jwt.DecodeError, jwt.ExpiredSignature) as de:
            raise AuthTokenError(self, de)

        # Claims are read with .get(), so a token missing a claim is rejected
        # with AuthTokenError instead of crashing with KeyError.

        # Verify the issuer of the id_token is correct.
        issuer = claims.get('iss')
        if issuer is None or issuer != self.ID_TOKEN_ISSUER:
            raise AuthTokenError(self, 'Incorrect id_token: iss')

        # Verify the token was issued recently (by default, in the last 10 minutes).
        utc_timestamp = int(time.time())
        issued_at = claims.get('iat')
        if issued_at is None or issued_at < (utc_timestamp - self.ID_TOKEN_MAX_AGE):
            raise AuthTokenError(self, 'Incorrect id_token: iat')

        # Verify this client is the correct recipient of the id_token.
        aud = claims.get('aud')
        if aud is None or aud != client_id:
            raise AuthTokenError(self, 'Incorrect id_token: aud')

        # Validate the nonce to ensure the request was not modified.
        nonce = claims.get('nonce')
        if not nonce:
            raise AuthTokenError(self, 'Incorrect id_token: nonce')
        nonce_obj = self._get_nonce(nonce)
        if not nonce_obj:
            raise AuthTokenError(self, 'Incorrect id_token: nonce')
        # Consume the nonce so the same id_token cannot be accepted twice.
        self._remove_nonce(nonce_obj.id)

        return claims
class EdXOpenIdConnect(OpenIdConnectAuth):  # pylint: disable=abstract-method
    """OpenID Connect backend for the edX identity provider."""
    name = 'edx-oidc'
    # Identify users by the id_token ``username`` claim, as EdXOAuth2 does.
    # Without this, the uid comes from the inherited ID_KEY (``'id'`` in
    # python-social-auth's OAuth backends -- verify). The id_token claims do
    # not appear to carry that key, so every user would share one empty uid.
    ID_KEY = 'username'
    ID_TOKEN_ISSUER = settings.SOCIAL_AUTH_EDX_OIDC_URL_ROOT
    AUTHORIZATION_URL = '{0}/authorize/'.format(settings.SOCIAL_AUTH_EDX_OIDC_URL_ROOT)
    ACCESS_TOKEN_URL = '{0}/access_token/'.format(settings.SOCIAL_AUTH_EDX_OIDC_URL_ROOT)

    def user_data(self, access_token, *args, **kwargs):
        """Return the validated claims of the ``id_token`` in the token response."""
        return self._validate_and_return_id_token(kwargs['response'].get('id_token'))

    def get_user_details(self, response):
        """
        Map id_token claims onto user details.

        ``username`` is required and raises KeyError when absent. Profile
        claims are optional and default to an empty string, matching the
        blank values EdXOAuth2 returns.
        """
        return {
            u'username': response['username'],
            u'email': response.get('email', ''),
            u'full_name': response.get('name', ''),
            u'first_name': response.get('given_name', ''),
            u'last_name': response.get('family_name', ''),
        }
from calendar import timegm
import json
import datetime
from django.conf import settings
import jwt
from social.exceptions import AuthTokenError
from social.tests.backends.oauth import OAuth2Test
class EdXOAuth2Tests(OAuth2Test):
    """Login-flow tests for the ``EdXOAuth2`` backend."""
    backend_path = 'analytics_dashboard.backends.EdXOAuth2'
    expected_username = 'edx'

    # Canned access-token response served by the mocked provider.
    access_token_body = json.dumps({
        'access_token': 'foobar',
        'token_type': 'bearer',
        'username': expected_username,
    })

    def test_login(self):
        """Run the complete login flow provided by OAuth2Test."""
        self.do_login()

    def test_partial_pipeline(self):
        """Run the partial-pipeline flow provided by OAuth2Test."""
        self.do_partial_pipeline()
class OpenIdConnectTestMixin(object):
    """
    Mixin to test OpenID Connect consumers. Inheriting classes should also inherit OAuth2Test
    and set ``issuer`` to the value the backend expects in the id_token ``iss`` claim.
    """
    expected_username = u'edx'
    client_key = 'a-key'
    client_secret = 'a-secret-key'
    issuer = None  # id_token issuer

    def setUp(self):
        super(OpenIdConnectTestMixin, self).setUp()
        # Serve the access-token response through a callback, so the nonce the
        # backend sends can be echoed back inside the id_token.
        self.access_token_body = self._parse_nonce_and_return_access_token_body

    def extra_settings(self):
        """Add the client key, client secret and id_token decryption key to the settings."""
        xs = super(OpenIdConnectTestMixin, self).extra_settings()
        xs.update({
            'SOCIAL_AUTH_{}_KEY'.format(self.name): self.client_key,
            'SOCIAL_AUTH_{}_SECRET'.format(self.name): self.client_secret,
            # Test id_tokens are signed with the client secret, so that value also decodes them.
            'SOCIAL_AUTH_{}_ID_TOKEN_DECRYPTION_KEY'.format(self.name): self.client_secret
        })
        return xs

    def _parse_nonce_and_return_access_token_body(self, request, _url, headers):
        """
        Get the nonce from the request parameters, add it to the id_token, and return the complete response.

        Returns a ``(status, headers, body)`` tuple for the mocked HTTP layer.
        """
        body = self._prepare_access_token_body(nonce=request.parsed_body[u'nonce'][0])
        return 200, headers, body

    def _prepare_access_token_body(self, client_key=None, client_secret=None, expiration_datetime=None,
                                   issue_datetime=None, nonce=None):
        """
        Prepare a provider access-token response.

        Arguments:
            client_key (str): OAuth client ID placed in the ``aud`` and ``azp`` claims.
                Defaults to ``self.client_key``.
            client_secret (str): Key used to sign the id_token. Defaults to ``self.client_secret``.
            expiration_datetime (datetime): Naive UTC time used for the ``exp`` claim.
                Defaults to 30 seconds from now.
            issue_datetime (datetime): Naive UTC time used for the ``iat`` claim. Defaults to now.
            nonce (str): Value of the ``nonce`` claim. Empty values are stored as None.

        Returns:
            str: JSON containing ``access_token``, ``token_type`` and the encoded ``id_token``.
        """
        now = datetime.datetime.utcnow()
        client_key = client_key or self.client_key
        client_secret = client_secret or self.client_secret
        expiration_datetime = expiration_datetime or (now + datetime.timedelta(seconds=30))
        issue_datetime = issue_datetime or now

        id_token = {
            u'iss': self.issuer,
            u'nonce': nonce or None,
            u'aud': client_key,
            u'azp': client_key,
            u'exp': timegm(expiration_datetime.utctimetuple()),
            u'iat': timegm(issue_datetime.utctimetuple()),
            u'username': self.expected_username,
            u'name': u'Ed Xavier',
            u'given_name': u'Ed',
            u'family_name': u'Xavier',
            # NOTE(review): this literal looks like a web-scrape obfuscation of
            # the original address; restore it from version control.
            u'email': u'[email protected]'
        }
        encoded_id_token = jwt.encode(id_token, client_secret)
        # Depending on the PyJWT version, encode() returns bytes or text;
        # json.dumps() only accepts text on Python 3.
        if isinstance(encoded_id_token, bytes):
            encoded_id_token = encoded_id_token.decode('utf-8')

        body = {'access_token': 'foobar', 'token_type': 'bearer', u'id_token': encoded_id_token}
        return json.dumps(body)

    def test_login(self):
        self.do_login()

    def test_partial_pipeline(self):
        self.do_partial_pipeline()

    def assertAutTokenErrorRaised(self, expected_message, **access_token_kwargs):
        """
        Assert that logging in raises AuthTokenError matching the regex ``expected_message``.

        ``access_token_kwargs`` are forwarded to ``_prepare_access_token_body``
        to build the (bad) token response.
        """
        self.access_token_body = self._prepare_access_token_body(**access_token_kwargs)
        # assertRaisesRegexp is a deprecated alias removed in Python 3.12,
        # while Python 2.7 only provides that spelling.
        assert_raises_regex = getattr(self, 'assertRaisesRegex', None) or self.assertRaisesRegexp
        assert_raises_regex(AuthTokenError, expected_message, self.do_login)

    def test_invalid_secret(self):
        self.assertAutTokenErrorRaised('Token error: Signature verification failed', client_secret='wrong!')

    def test_expired_signature(self):
        expiration_datetime = datetime.datetime.utcnow() - datetime.timedelta(seconds=30)
        self.assertAutTokenErrorRaised('Token error: Signature has expired', expiration_datetime=expiration_datetime)

    def test_invalid_audience(self):
        self.assertAutTokenErrorRaised('Token error: Incorrect id_token: aud', client_key='someone-else')

    def test_invalid_issue_time(self):
        issue_datetime = datetime.datetime.utcnow() - datetime.timedelta(hours=1)
        self.assertAutTokenErrorRaised('Token error: Incorrect id_token: iat', issue_datetime=issue_datetime)

    def test_invalid_nonce(self):
        self.assertAutTokenErrorRaised('Token error: Incorrect id_token: nonce', nonce='something-wrong')
class EdXOpenIdConnectTests(OpenIdConnectTestMixin, OAuth2Test):
    """Run the OpenID Connect consumer tests against the ``EdXOpenIdConnect`` backend."""
    # Test id_tokens are issued by the configured OIDC root URL.
    issuer = settings.SOCIAL_AUTH_EDX_OIDC_URL_ROOT
    backend_path = 'analytics_dashboard.backends.EdXOpenIdConnect'
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment