Skip to content

Instantly share code, notes, and snippets.

@dopry
Created September 12, 2024 18:49
Show Gist options
  • Save dopry/9a83167f8a2af5785e44072e299ac8f2 to your computer and use it in GitHub Desktop.
Save dopry/9a83167f8a2af5785e44072e299ac8f2 to your computer and use it in GitHub Desktop.
Wagtail OIDC example
{% extends "wagtailadmin/login.html" %}
{# Replace Wagtail's username/password form with an OpenID Connect sign-in link. #}
{% block branding_login %}CMS{% endblock %}
{% block login_form %}
  {% if user.is_authenticated %}
    <p>Current user: {{ user.email }}</p>
    {# Post to the Wagtail logout URL, which urls.py overrides with WagtailOIDCLogoutView, #}
    {# so the identity-provider session is ended too and not just the Django session. #}
    <form action="{% url 'wagtailadmin_logout' %}" method="post">
      {% csrf_token %}
      <input type="submit" value="logout" />
    </form>
  {% else %}
    <a href="{% url 'oidc_authentication_init' %}" class="button">Sign In</a>
  {% endif %}
{% endblock %}
{# There is no local credential form to submit, so drop Wagtail's submit buttons. #}
{% block submit_buttons %}{% endblock %}
import logging
import requests
import time
import unicodedata
from django.contrib.auth import BACKEND_SESSION_KEY
from django.contrib.auth.models import Group
from django.core.exceptions import SuspiciousOperation
from django.http import HttpResponseRedirect
from django.urls import reverse
from django.utils.decorators import method_decorator
from django.utils.deprecation import MiddlewareMixin
from django.utils.module_loading import import_string
from django.utils.html import urlencode
from django.views.decorators.cache import never_cache
from mozilla_django_oidc.auth import OIDCAuthenticationBackend
from mozilla_django_oidc.utils import (absolutify,
import_from_settings)
from wagtail.admin.views.account import LogoutView
# Module-level logger used throughout this module (token exchange, token
# storage, claim mapping and the refresh middleware).
LOGGER = logging.getLogger(__name__)
class TokenRequestInvalidGrantException(Exception):
    """The token endpoint rejected the grant with an ``invalid_grant`` error.

    Raised by ``WagtailOIDCAuthenticationBackend.get_token`` when the provider
    answers HTTP 400 with ``error=invalid_grant``, for example because a
    refresh token was already used or revoked. The exception message is the
    provider's ``error_description`` (it may be ``None``).
    """
class WagtailOIDCAuthenticationBackend(OIDCAuthenticationBackend):
    """
    mozilla-django-oidc authentication backend adapted for the Wagtail admin.

    Compared with the stock backend it:

    * matches existing users by email, case-insensitively;
    * stores the access, ID and refresh tokens in the session so that
      ``RefreshTokenMiddleware`` can renew them before they expire;
    * only accepts users whose claims mark them as staff, superuser or a
      member of the ``Editors`` group, and rebuilds their Django group
      memberships from the claims on every login.
    """

    # Django group granted to staff/superusers and accepted by verify_claims().
    EDITORS_GROUP = 'Editors'

    @staticmethod
    def _claim_group_names(claims):
        """Return the group names carried in the ``groups`` claim.

        The claim may be a space-separated string (the format this backend
        originally assumed) or a list of names, as some providers send a JSON
        array. Empty names are dropped.
        """
        groups = claims.get('groups') or ''
        if isinstance(groups, str):
            return groups.split()
        return [str(name) for name in groups if name]

    def filter_users_by_claims(self, claims):
        """Return all users whose email matches the ``email`` claim (case-insensitive).

        Returns an empty queryset when the claim is missing or empty.
        """
        email = claims.get('email')
        if not email:
            return self.UserModel.objects.none()
        return self.UserModel.objects.filter(email__iexact=email)

    def create_user(self, claims):
        """Create the user via the parent backend, then apply claim-derived fields."""
        user = super().create_user(claims)
        return self.update_user(user, claims)

    def get_token(self, payload):
        """POST *payload* to the token endpoint and return the decoded JSON response.

        Used both for the authorization-code grant (login) and for the
        refresh-token grant (``RefreshTokenMiddleware``).

        Raises:
            TokenRequestInvalidGrantException: HTTP 400 with ``error=invalid_grant``.
            Exception: any other HTTP 400 OAuth error (message is the ``error`` code).
            requests.HTTPError: a 400 without a JSON body, or any other non-2xx status.
        """
        response = requests.post(
            self.OIDC_OP_TOKEN_ENDPOINT,
            data=payload,
            auth=None,
            verify=self.get_settings('OIDC_VERIFY_SSL', True),
            timeout=self.get_settings('OIDC_TIMEOUT', None),
            proxies=self.get_settings('OIDC_PROXY', None),
        )
        if response.status_code == 400:
            # The payload carries client_secret plus the auth code or refresh
            # token, so only the grant type is logged, never the payload.
            LOGGER.error(
                'Failed to get token (grant_type=%s): %s',
                payload.get('grant_type'),
                response.text,
            )
            try:
                error = response.json()
            except ValueError as exc:
                raise requests.HTTPError(
                    'Token endpoint returned 400 with a non-JSON body',
                    response=response,
                ) from exc
            if error.get('error') == 'invalid_grant':
                raise TokenRequestInvalidGrantException(error.get('error_description'))
            raise Exception(error.get('error'))
        response.raise_for_status()
        # The body contains live tokens; keep it out of the logs.
        LOGGER.debug('token response received (status %s)', response.status_code)
        return response.json()

    def authenticate(self, request, **kwargs):
        """Authenticate a user from the OIDC authorization-code callback.

        Returns the user, or ``None`` when there is no request, the callback
        lacks ``code``/``state``, the token response has no ID token, the ID
        token fails verification, or the user cannot be matched/created.
        """
        self.request = request
        if not self.request:
            return None

        state = self.request.GET.get('state')
        code = self.request.GET.get('code')
        nonce = kwargs.pop('nonce', None)
        if not code or not state:
            return None

        reverse_url = self.get_settings(
            'OIDC_AUTHENTICATION_CALLBACK_URL', 'oidc_authentication_callback'
        )
        token_payload = {
            'client_id': self.OIDC_RP_CLIENT_ID,
            'client_secret': self.OIDC_RP_CLIENT_SECRET,
            'grant_type': 'authorization_code',
            'code': code,
            'redirect_uri': absolutify(self.request, reverse(reverse_url)),
        }

        token_info = self.get_token(token_payload)
        id_token = token_info.get('id_token')
        access_token = token_info.get('access_token')
        if not id_token:
            # Nothing to verify; passing None to verify_token would only fail later.
            LOGGER.warning('token response did not contain an id_token')
            return None

        payload = self.verify_token(id_token, nonce=nonce)
        if not payload:
            return None

        self.store_tokens(request, token_info)
        try:
            return self.get_or_create_user(access_token, id_token, payload)
        except SuspiciousOperation as exc:
            LOGGER.warning('failed to get or create user: %s', exc)
            return None

    def store_tokens(self, request, token_info):
        """Save the tokens from a token-endpoint response in ``request.session``.

        Tokens missing from *token_info* leave any previously stored value in
        place. ``oidc_token_expiration`` is set to 80% of ``expires_in`` from
        now, so ``RefreshTokenMiddleware`` refreshes early. When the provider
        omits ``expires_in`` (RFC 6749 only recommends it), the
        ``OIDC_RENEW_ID_TOKEN_EXPIRY_SECONDS`` setting is used instead
        (default 15 minutes).
        """
        session = request.session
        access_token = token_info.get('access_token')
        id_token = token_info.get('id_token')
        refresh_token = token_info.get('refresh_token')
        expires_in = token_info.get('expires_in')

        if access_token:
            session['oidc_access_token'] = access_token
        if id_token:
            session['oidc_id_token'] = id_token
        if refresh_token:
            session['oidc_refresh_token'] = refresh_token

        if expires_in is None:
            lifetime = self.get_settings('OIDC_RENEW_ID_TOKEN_EXPIRY_SECONDS', 15 * 60)
        else:
            lifetime = 0.8 * float(expires_in)
        session['oidc_token_expiration'] = time.time() + lifetime
        # Never log the token itself, only whether one was received.
        LOGGER.debug('tokens stored in session (refresh token received: %s)', bool(refresh_token))

    def update_user(self, user, claims):
        """Sync names, staff/superuser flags and group memberships from *claims*.

        All group memberships are cleared and rebuilt. ``Editors`` is granted
        when ``is_staff`` or ``is_superuser`` is truthy (and must exist, else
        ``Group.DoesNotExist`` propagates). Every name in the ``groups`` claim
        is added when a Django group with that name exists; unknown names are
        logged and skipped. The user is saved and returned.
        """
        LOGGER.debug('update_user: %s %s', user, claims)
        given_name = claims.get('given_name', '')
        if given_name:
            user.first_name = given_name
        family_name = claims.get('family_name', '')
        if family_name:
            user.last_name = family_name

        # Memberships are derived entirely from the claims on each login.
        user.groups.clear()
        grant_editors = False

        is_staff = claims.get('is_staff', None)
        if is_staff is not None:
            user.is_staff = is_staff
            grant_editors = grant_editors or bool(is_staff)

        is_superuser = claims.get('is_superuser', None)
        if is_superuser is not None:
            user.is_superuser = is_superuser
            grant_editors = grant_editors or bool(is_superuser)

        if grant_editors:
            user.groups.add(Group.objects.get(name=self.EDITORS_GROUP))

        for name in self._claim_group_names(claims):
            try:
                user.groups.add(Group.objects.get(name=name))
            except Group.DoesNotExist:
                LOGGER.warning('OIDC group %r has no matching Django group; skipping', name)

        user.save()
        return user

    def verify_claims(self, claims=None):
        """Accept only claims that pass the parent check AND grant admin access.

        Admin access means a truthy ``is_staff`` or ``is_superuser`` claim, or
        ``Editors`` among the ``groups`` claim. ``None`` is treated as no claims.
        """
        if claims is None:
            claims = {}
        LOGGER.debug('verify_claims: %s', claims)
        verified = super().verify_claims(claims)
        is_superuser = claims.get('is_superuser', False)
        is_staff = claims.get('is_staff', False)
        is_editor = self.EDITORS_GROUP in self._claim_group_names(claims)
        return bool(verified and (is_editor or is_staff or is_superuser))
def generate_username(email):
    """Derive a Django username from an email address.

    Django usernames may contain ASCII/Unicode alphanumerics plus ``_``,
    ``@``, ``+``, ``.`` and ``-``. The email is NFKC-normalised, which folds
    compatibility characters such as full-width letters, and is then
    truncated to Django's 150-character username limit.
    """
    max_length = 150
    normalized = unicodedata.normalize('NFKC', email)
    return normalized[:max_length]
class RefreshTokenMiddleware(MiddlewareMixin):
    """
    Renew the OIDC tokens with the stored refresh token once they are due.

    Originally based on the session refresh middleware from mozilla-django-oidc,
    then stripped down to use the refresh-token grant. Tokens are read from and
    written to ``request.session``, so this middleware must run after Django's
    SessionMiddleware.
    """

    # Session keys written by WagtailOIDCAuthenticationBackend.store_tokens().
    _TOKEN_SESSION_KEYS = (
        'oidc_refresh_token',
        'oidc_id_token',
        'oidc_token_expiration',
        'oidc_access_token',
    )

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.OIDC_RP_CLIENT_ID = self.get_settings('OIDC_RP_CLIENT_ID')
        self.OIDC_RP_CLIENT_SECRET = self.get_settings('OIDC_RP_CLIENT_SECRET')
        self.OIDC_RP_SCOPES = self.get_settings('OIDC_RP_SCOPES', 'openid email')
        self.OIDC_USE_NONCE = self.get_settings('OIDC_USE_NONCE', True)
        self.OIDC_NONCE_SIZE = self.get_settings('OIDC_NONCE_SIZE', 32)

    @staticmethod
    def get_settings(attr, *args):
        return import_from_settings(attr, *args)

    def process_request(self, request):
        """Refresh the session's tokens when ``oidc_token_expiration`` has passed.

        Does nothing when the tokens are still fresh, when there is no refresh
        token, or when the session was not authenticated by
        ``WagtailOIDCAuthenticationBackend`` (or a subclass). On
        ``invalid_grant`` all stored tokens are removed so the refresh is not
        retried on every request; other errors are logged and re-raised.
        """
        LOGGER.debug('RefreshTokenMiddleware')
        session = request.session
        expiration = session.get('oidc_token_expiration', 0)
        now = time.time()
        if expiration > now:
            LOGGER.debug('token is not expired (%s > %s)', expiration, now)
            return
        LOGGER.debug('token has expired')

        refresh_token = session.get('oidc_refresh_token', None)
        if not refresh_token:
            # Without a refresh token there is nothing to refresh with.
            LOGGER.debug('no refresh token')
            return

        backend_class_path = session.get(BACKEND_SESSION_KEY)
        if not backend_class_path:
            # No authentication backend recorded for this session.
            LOGGER.debug('no auth backend for session')
            return

        auth_class = import_string(backend_class_path)
        if not issubclass(auth_class, WagtailOIDCAuthenticationBackend):
            # Logged in through another backend; its tokens are not ours to refresh.
            LOGGER.debug('Not our OIDC auth backend')
            return

        LOGGER.debug('auth_class: %s', auth_class)
        auth_backend = auth_class()
        token_payload = {
            'client_id': self.OIDC_RP_CLIENT_ID,
            'client_secret': self.OIDC_RP_CLIENT_SECRET,
            'grant_type': 'refresh_token',
            'refresh_token': refresh_token,
        }
        try:
            token_info = auth_backend.get_token(token_payload)
            # token_info holds live tokens, so it is deliberately not logged.
            auth_backend.store_tokens(request, token_info)
        except TokenRequestInvalidGrantException as exc:
            # pop() rather than del: a key may be absent, e.g. when the
            # provider never issued an ID token.
            # NOTE(review): the Django login itself stays active; only the
            # OIDC tokens are discarded.
            for key in self._TOKEN_SESSION_KEYS:
                session.pop(key, None)
            LOGGER.debug('invalid_grant, typically the refresh token was already used: %s', exc)
            return
        except Exception as exc:
            LOGGER.error('Error refreshing token: %s', exc)
            raise
class WagtailOIDCLogoutView(LogoutView):
    """Wagtail LogoutView that also ends the session at the OIDC provider.

    After the normal Django/Wagtail logout the browser is redirected to
    ``OIDC_RP_END_SESSION_ENDPOINT`` (RP-initiated logout), which then sends it
    back to the Wagtail admin home.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.OIDC_RP_CLIENT_ID = self.get_settings('OIDC_RP_CLIENT_ID')
        self.OIDC_RP_END_SESSION_ENDPOINT = self.get_settings('OIDC_RP_END_SESSION_ENDPOINT')

    @staticmethod
    def get_settings(attr, *args):
        return import_from_settings(attr, *args)

    def get_next_page(self):
        """
        Return the OIDC end-session URL the user is sent to after the Django logout.

        The query string asks the provider to redirect back to the Wagtail admin
        home. django-oauth-toolkit requires ``client_id`` and ``id_token_hint``
        for an unprompted post-logout redirect; ``id_token_hint`` is included
        whenever an ID token is still in the session.
        see: method validate_logout_request from
        https://github.com/jazzband/django-oauth-toolkit/blob/master/oauth2_provider/views/oidc.py#L213
        see: https://openid.net/specs/openid-connect-rpinitiated-1_0.html#RPLogout
        """
        end_session_endpoint = self.OIDC_RP_END_SESSION_ENDPOINT
        home_url = self.request.build_absolute_uri(reverse('wagtailadmin_home'))
        parameters = {
            'post_logout_redirect_uri': home_url,
            'client_id': self.OIDC_RP_CLIENT_ID,
        }
        if 'oidc_id_token' in self.request.session:
            parameters['id_token_hint'] = self.request.session['oidc_id_token']

        # Preserve any query string already present on the configured endpoint.
        if '?' not in end_session_endpoint:
            separator = '?'
        elif end_session_endpoint.endswith(('?', '&')):
            separator = ''
        else:
            separator = '&'
        return f'{end_session_endpoint}{separator}{urlencode(parameters)}'

    @method_decorator(never_cache)
    def dispatch(self, request, *args, **kwargs):
        """
        Log out locally, then redirect to the provider's end-session endpoint.

        The URL is built before calling the parent, because the logout flushes
        the session and with it ``oidc_id_token``. If the parent view did not
        log the user out (the request is still authenticated afterwards), its
        own response is returned. Otherwise the user would be logged out at the
        provider while the Django session stayed alive.
        """
        next_page = self.get_next_page()
        response = super().dispatch(request, *args, **kwargs)
        user = getattr(request, 'user', None)
        if user is not None and user.is_authenticated:
            return response
        return HttpResponseRedirect(next_page)
mozilla-django-oidc==2.0.0
INSTALLED_APPS = [
    # ...
    "mozilla_django_oidc",
]
MIDDLEWARE = [
    # ...
    # Reads and writes request.session, so keep it after SessionMiddleware.
    "oidc.RefreshTokenMiddleware",
]
AUTHENTICATION_BACKENDS = [
    # ...
    "oidc.WagtailOIDCAuthenticationBackend",
]
# OpenID Connect configuration.
# NOTE(review): env is assumed to be defined earlier in settings (e.g. a
# django-environ Env instance) - confirm.
OIDC_OP_JWKS_ENDPOINT = env("OIDC_OP_JWKS_ENDPOINT")
OIDC_RP_CLIENT_ID = env("OIDC_RP_CLIENT_ID")
OIDC_RP_CLIENT_SECRET = env("OIDC_RP_CLIENT_SECRET")
OIDC_RP_END_SESSION_ENDPOINT = env("OIDC_RP_END_SESSION_ENDPOINT")
OIDC_OP_AUTHORIZATION_ENDPOINT = env("OIDC_OP_AUTHORIZATION_ENDPOINT")
OIDC_OP_TOKEN_ENDPOINT = env("OIDC_OP_TOKEN_ENDPOINT")
OIDC_OP_USER_ENDPOINT = env("OIDC_OP_USER_ENDPOINT")
# Token renewal interval in seconds (mozilla-django-oidc's default is 15 minutes).
# Set to 80% of a one-hour token lifetime (2880 s) to avoid premature logouts.
OIDC_RENEW_ID_TOKEN_EXPIRY_SECONDS = 0.8 * 3600
OIDC_RP_SIGN_ALGO = "RS256"
OIDC_USERNAME_ALGO = "oidc.generate_username"
OIDC_RP_SCOPES = "openid email profile"
LOGIN_REDIRECT_URL = "/"
LOGOUT_REDIRECT_URL = "/"
# ...
from django.contrib import admin
from django.urls import include, path

from oidc import WagtailOIDCLogoutView

# Django only reads the module-level name `urlpatterns` from ROOT_URLCONF.
urlpatterns = [
    # Override the Wagtail logout view with an OIDC-specific version that also
    # ends the session at the identity provider.
    # NOTE(review): confirm reverse("wagtailadmin_logout") resolves to this
    # pattern rather than Wagtail's own logout URL in your URL layout.
    path("logout/", WagtailOIDCLogoutView.as_view(), name="wagtailadmin_logout"),
    path("admin/", admin.site.urls),
    path("oidc/", include("mozilla_django_oidc.urls")),
    # ...
]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment