Created
September 12, 2024 18:49
-
-
Save dopry/9a83167f8a2af5785e44072e299ac8f2 to your computer and use it in GitHub Desktop.
Wagtail OIDC example
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{% extends "wagtailadmin/login.html" %}
{# Overrides the login_form block with an OIDC sign-in link (or a logout form when a user is already authenticated) and empties the submit_buttons block. #}
{% block branding_login %}CMS{% endblock %}
{% block login_form %}
    {% if user.is_authenticated %}
        {# Already signed in: show who is logged in and offer a POST logout. #}
        <p>Current user: {{ user.email }}</p>
        <form action="{% url 'oidc_logout' %}" method="post">
            {% csrf_token %}
            <input type="submit" value="logout" />
        </form>
    {% else %}
        {# Not signed in: link to the URL named oidc_authentication_init. #}
        <a href="{% url 'oidc_authentication_init' %}" class="button">Sign In</a>
    {% endif %}
{% endblock %}
{# Rendered empty so the parent template's submit buttons are not shown. #}
{% block submit_buttons %}{% endblock %}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging | |
import requests | |
import time | |
import unicodedata | |
from django.contrib.auth import BACKEND_SESSION_KEY | |
from django.contrib.auth.models import Group | |
from django.core.exceptions import SuspiciousOperation | |
from django.http import HttpResponseRedirect | |
from django.urls import reverse | |
from django.utils.decorators import method_decorator | |
from django.utils.deprecation import MiddlewareMixin | |
from django.utils.module_loading import import_string | |
from django.utils.html import urlencode | |
from django.views.decorators.cache import never_cache | |
from mozilla_django_oidc.auth import OIDCAuthenticationBackend | |
from mozilla_django_oidc.utils import (absolutify, | |
import_from_settings) | |
from wagtail.admin.views.account import LogoutView | |
LOGGER = logging.getLogger(__name__) | |
class TokenRequestInvalidGrantException(Exception):
    """Raised when the token endpoint answers HTTP 400 with ``error == 'invalid_grant'``.

    The exception message is the ``error_description`` from the response.
    RefreshTokenMiddleware catches it to drop the stored tokens from the
    session instead of failing the request.
    """
class WagtailOIDCAuthenticationBackend(OIDCAuthenticationBackend):
    """
    OIDC authentication backend for the Wagtail admin.

    We've liberally overridden this backend to use refresh tokens
    and opaque access tokens and to set up our group memberships
    for editors.
    """

    # Keys whose values must never be written to the logs.
    _SENSITIVE_KEYS = frozenset({
        'client_secret',
        'code',
        'access_token',
        'id_token',
        'refresh_token',
    })

    @classmethod
    def _redacted(cls, mapping):
        """Return a copy of ``mapping`` with secret/token values masked for logging."""
        return {
            key: ('***' if key in cls._SENSITIVE_KEYS else value)
            for key, value in mapping.items()
        }

    @staticmethod
    def _group_names_from_claims(claims):
        """Return the group names in the ``groups`` claim as a list.

        Accepts either a whitespace-separated string or an iterable of
        names; empty entries are dropped.
        """
        groups = claims.get('groups') or ()
        if isinstance(groups, str):
            groups = groups.split()
        return [name for name in groups if name]

    def filter_users_by_claims(self, claims):
        """Return all users whose email matches the ``email`` claim (case-insensitive).

        Returns an empty queryset when the claim is missing or empty.
        """
        email = claims.get('email')
        if not email:
            return self.UserModel.objects.none()
        return self.UserModel.objects.filter(email__iexact=email)

    def create_user(self, claims):
        """Create the user via the parent backend, then apply the claims to it."""
        user = super().create_user(claims)
        return self.update_user(user, claims)

    def get_token(self, payload):
        """POST ``payload`` to the token endpoint and return the JSON response as a dict.

        Raises:
            TokenRequestInvalidGrantException: on HTTP 400 with
                ``error == 'invalid_grant'``.
            Exception: on any other HTTP 400 response (message is the
                ``error`` field).
            requests.HTTPError: on any other non-success status.
        """
        response = requests.post(
            self.OIDC_OP_TOKEN_ENDPOINT,
            data=payload,
            auth=None,
            verify=self.get_settings('OIDC_VERIFY_SSL', True),
            timeout=self.get_settings('OIDC_TIMEOUT', None),
            proxies=self.get_settings('OIDC_PROXY', None))

        if response.status_code == 400:
            # Catch OIDC errors. The payload carries the client secret and
            # the code / refresh token, so it is masked before logging.
            LOGGER.error('Failed to get token: %s, payload: %s',
                         response.text, self._redacted(payload))
            obj = response.json()
            if obj.get('error') == 'invalid_grant':
                raise TokenRequestInvalidGrantException(obj.get('error_description'))
            raise Exception(obj.get('error'))

        response.raise_for_status()
        token_info = response.json()
        # Never log raw tokens; only the masked response.
        LOGGER.debug('token response: %s', self._redacted(token_info))
        return token_info

    def authenticate(self, request, **kwargs):
        """Authenticate a user based on the OIDC code flow.

        Reads ``state`` and ``code`` from the request query string,
        exchanges the code for tokens, verifies the ID token, stores the
        tokens in the session and returns the matching (or newly created)
        user. Returns ``None`` when the request is missing, ``code`` or
        ``state`` is absent, token verification yields nothing, or user
        lookup raises ``SuspiciousOperation``.
        """
        self.request = request
        if not self.request:
            return None

        state = self.request.GET.get('state')
        code = self.request.GET.get('code')
        nonce = kwargs.pop('nonce', None)

        if not code or not state:
            return None

        reverse_url = self.get_settings('OIDC_AUTHENTICATION_CALLBACK_URL',
                                        'oidc_authentication_callback')

        token_payload = {
            'client_id': self.OIDC_RP_CLIENT_ID,
            'client_secret': self.OIDC_RP_CLIENT_SECRET,
            'grant_type': 'authorization_code',
            'code': code,
            'redirect_uri': absolutify(
                self.request,
                reverse(reverse_url)
            ),
        }

        # Get the token
        token_info = self.get_token(token_payload)
        id_token = token_info.get('id_token')
        access_token = token_info.get('access_token')

        # Validate the token
        payload = self.verify_token(id_token, nonce=nonce)

        if payload:
            self.store_tokens(request, token_info)
            try:
                return self.get_or_create_user(access_token, id_token, payload)
            except SuspiciousOperation as exc:
                LOGGER.warning('failed to get or create user: %s', exc)
                return None

        return None

    def store_tokens(self, request, token_info):
        """Save the tokens from ``token_info`` into ``request.session``.

        Writes ``oidc_access_token``, ``oidc_id_token`` and
        ``oidc_refresh_token`` for each token present, plus
        ``oidc_token_expiration`` when ``expires_in`` is present.
        """
        id_token = token_info.get('id_token')
        access_token = token_info.get('access_token')
        refresh_token = token_info.get('refresh_token')
        expires_in = token_info.get('expires_in')
        session = request.session

        if access_token:
            session['oidc_access_token'] = access_token

        if id_token:
            session['oidc_id_token'] = id_token

        if refresh_token:
            session['oidc_refresh_token'] = refresh_token

        if expires_in is not None:
            # Refresh at 80% of the expiration time so we should generally
            # refresh early. Primarily used by the RefreshTokenMiddleware.
            # Skipped when the provider omits expires_in, which would
            # otherwise raise a TypeError here.
            session['oidc_token_expiration'] = time.time() + (0.8 * float(expires_in))

        # Log only whether a refresh token was stored, never its value.
        LOGGER.debug('refresh token stored in session: %s', bool(refresh_token))

    def update_user(self, user, claims):
        """Copy names, staff/superuser flags and group memberships from ``claims``.

        Group memberships are cleared and rebuilt on every call. Saves and
        returns the user.

        Raises:
            Group.DoesNotExist: if an ``is_staff`` or ``is_superuser`` claim
                is present and no group named ``Editors`` exists.
        """
        LOGGER.debug('update_user: %s %s', user, claims)

        given_name = claims.get('given_name', '')
        if given_name:
            user.first_name = given_name

        family_name = claims.get('family_name', '')
        if family_name:
            user.last_name = family_name

        # Reset groups to blank, and re-add any required permissions.
        user.groups.clear()

        # NOTE(review): the Editors group is added whenever the claim is
        # present, even if its value is False — confirm this is intended.
        is_staff = claims.get('is_staff', None)
        if is_staff is not None:
            user.is_staff = is_staff
            user.groups.add(Group.objects.get(name='Editors'))

        is_superuser = claims.get('is_superuser', None)
        if is_superuser is not None:
            user.is_superuser = is_superuser
            user.groups.add(Group.objects.get(name='Editors'))

        requested = self._group_names_from_claims(claims)
        if requested:
            # One query for all claimed groups; names without a matching
            # Group are logged and skipped rather than silently ignored.
            found = list(Group.objects.filter(name__in=requested))
            if found:
                user.groups.add(*found)
            missing = set(requested) - {group.name for group in found}
            if missing:
                LOGGER.warning('groups from OIDC claims not found: %s', sorted(missing))

        user.save()
        return user

    def verify_claims(self, claims=None):
        """Accept the claims only for editors, staff or superusers.

        Returns the parent's verification result combined with the
        requirement that ``is_staff`` or ``is_superuser`` is truthy, or that
        ``Editors`` is among the ``groups`` claim.
        """
        claims = claims or {}
        LOGGER.debug('verify_claims: %s', claims)
        verified = super().verify_claims(claims)
        is_superuser = claims.get('is_superuser', False)
        is_staff = claims.get('is_staff', False)
        is_editor = 'Editors' in self._group_names_from_claims(claims)
        return verified and (is_editor or is_staff or is_superuser)
def generate_username(email):
    """Build a username from ``email``.

    With Python 3 and Django 1.11+, usernames may contain alphanumeric
    characters (ASCII and Unicode) plus ``_``, ``@``, ``+``, ``.`` and ``-``,
    so the email is NFKC-normalized and then cut to 150 characters.
    """
    normalized = unicodedata.normalize('NFKC', email)
    return normalized[:150]
class RefreshTokenMiddleware(MiddlewareMixin):
    """
    Refresh the OIDC tokens stored in the session using the refresh token.

    Originally based on the session refresh middleware from
    mozilla-django-oidc, but then super stripped down to work with the
    refresh token. The interface to store tokens was changed.
    """

    # Session keys removed when the provider rejects the refresh token.
    _TOKEN_SESSION_KEYS = (
        'oidc_refresh_token',
        'oidc_id_token',
        'oidc_token_expiration',
        'oidc_access_token',
    )

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.OIDC_RP_CLIENT_ID = self.get_settings('OIDC_RP_CLIENT_ID')
        self.OIDC_RP_CLIENT_SECRET = self.get_settings('OIDC_RP_CLIENT_SECRET')
        self.OIDC_RP_SCOPES = self.get_settings('OIDC_RP_SCOPES', 'openid email')
        self.OIDC_USE_NONCE = self.get_settings('OIDC_USE_NONCE', True)
        self.OIDC_NONCE_SIZE = self.get_settings('OIDC_NONCE_SIZE', 32)

    @staticmethod
    def get_settings(attr, *args):
        """Look up ``attr`` in the Django settings, with an optional default."""
        return import_from_settings(attr, *args)

    def process_request(self, request):
        """Refresh the session's tokens once ``oidc_token_expiration`` has passed.

        Does nothing when the stored expiration is still in the future, when
        the session holds no refresh token or auth backend, or when the
        backend is not a WagtailOIDCAuthenticationBackend. If the provider
        answers ``invalid_grant`` the stored tokens are dropped from the
        session; any other error is logged and re-raised. Always returns
        ``None``.
        """
        LOGGER.debug('RefreshTokenMiddleware')
        expiration = request.session.get('oidc_token_expiration', 0)
        now = time.time()
        if expiration > now:
            # The stored tokens are still valid, so we don't have to do anything.
            LOGGER.debug('token is not expired (%s > %s)', expiration, now)
            return

        LOGGER.debug('token has expired')

        refresh_token = request.session.get('oidc_refresh_token', None)
        if not refresh_token:
            # No refresh token, so we couldn't refresh the access token if we wanted to.
            LOGGER.debug('no refresh token')
            return

        backend_class_path = request.session.get(BACKEND_SESSION_KEY)
        if not backend_class_path:
            # No backend session, so we couldn't refresh the access token if we wanted to.
            LOGGER.debug('no auth backend for session')
            return

        auth_class = import_string(backend_class_path)
        if not issubclass(auth_class, WagtailOIDCAuthenticationBackend):
            # The backend is not our OIDC backend, so we won't try to refresh the access token.
            LOGGER.debug('Not our OIDC auth backend')
            return

        LOGGER.debug("auth_class: %s", auth_class)
        auth_backend = auth_class()

        token_payload = {
            'client_id': self.OIDC_RP_CLIENT_ID,
            'client_secret': self.OIDC_RP_CLIENT_SECRET,
            'grant_type': 'refresh_token',
            'refresh_token': refresh_token,
        }

        try:
            # Get the token
            token_info = auth_backend.get_token(token_payload)
            # Log only the field names; the values are live credentials.
            LOGGER.debug("token_info keys: %s", sorted(token_info))
            auth_backend.store_tokens(request, token_info)
        except TokenRequestInvalidGrantException as e:
            # pop() rather than del: not every key is guaranteed to have
            # been stored (e.g. a response without an id_token). Dropping
            # the expiration too means the next request sees the token as
            # expired, finds no refresh token and returns early.
            for key in self._TOKEN_SESSION_KEYS:
                request.session.pop(key, None)
            LOGGER.debug('invalid_grant, typically the refresh token was already used: %s', e)
            return
        except Exception as e:
            LOGGER.error("Error refreshing token: %s", e)
            raise
class WagtailOIDCLogoutView(LogoutView):
    """Extend the Wagtail LogoutView to redirect the user to the OIDC end_session_endpoint."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.OIDC_RP_CLIENT_ID = self.get_settings("OIDC_RP_CLIENT_ID")
        self.OIDC_RP_END_SESSION_ENDPOINT = self.get_settings("OIDC_RP_END_SESSION_ENDPOINT")

    @staticmethod
    def get_settings(attr, *args):
        """Look up ``attr`` in the Django settings, with an optional default."""
        return import_from_settings(attr, *args)

    def get_next_page(self):
        """
        Return the page the user will be redirected to after the Django logout.

        The page is the OIDC end session endpoint, with a querystring that
        informs it to redirect the client back to Wagtail.

        django-oauth-toolkit requires the client_id and id_token_hint to be
        set when doing an unprompted post-logout redirect.
        see: method validate_logout_request from
        https://github.com/jazzband/django-oauth-toolkit/blob/master/oauth2_provider/views/oidc.py#L213
        see: https://openid.net/specs/openid-connect-rpinitiated-1_0.html#RPLogout
        """
        end_session_endpoint = self.OIDC_RP_END_SESSION_ENDPOINT
        home_url = self.request.build_absolute_uri(reverse("wagtailadmin_home"))
        parameters = {
            "post_logout_redirect_uri": home_url,
            "client_id": self.OIDC_RP_CLIENT_ID,
        }
        # The hint is only available while the session still holds the ID token.
        if "oidc_id_token" in self.request.session:
            parameters["id_token_hint"] = self.request.session["oidc_id_token"]
        return f"{end_session_endpoint}?{urlencode(parameters)}"

    @method_decorator(never_cache)
    def dispatch(self, request, *args, **kwargs):
        """
        Log the user out and redirect to the OIDC end session endpoint.

        The parent class calls get_next_page after logout, when the session
        information has been cleared. We need to pull oidc_id_token out of
        the session first, before the logout clears it.
        """
        next_page = self.get_next_page()
        response = super().dispatch(request, *args, **kwargs)
        if 300 <= response.status_code < 400 and response.has_header("Location"):
            # Reuse the parent's redirect so any cookies or headers it set
            # during logout are kept; only the target is swapped for the
            # URL computed before the session was cleared.
            response["Location"] = next_page
            return response
        return HttpResponseRedirect(next_page)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
mozilla-django-oidc==2.0.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
INSTALLED_APPS = [
    # ...
    "mozilla_django_oidc",
]

MIDDLEWARE = [
    # ...
    # Reads request.session, so it must come after the session middleware.
    "oidc.RefreshTokenMiddleware",
]

AUTHENTICATION_BACKENDS = [
    # ...
    "oidc.WagtailOIDCAuthenticationBackend",
]

# OpenID Connect configuration
OIDC_OP_JWKS_ENDPOINT = env("OIDC_OP_JWKS_ENDPOINT")
OIDC_RP_CLIENT_ID = env("OIDC_RP_CLIENT_ID")
OIDC_RP_CLIENT_SECRET = env("OIDC_RP_CLIENT_SECRET")
OIDC_RP_END_SESSION_ENDPOINT = env("OIDC_RP_END_SESSION_ENDPOINT")
OIDC_OP_AUTHORIZATION_ENDPOINT = env("OIDC_OP_AUTHORIZATION_ENDPOINT")
OIDC_OP_TOKEN_ENDPOINT = env("OIDC_OP_TOKEN_ENDPOINT")
OIDC_OP_USER_ENDPOINT = env("OIDC_OP_USER_ENDPOINT")

# Renew the ID token after 48 minutes (80% of one hour) instead of the
# 15-minute default, to see if it solves the premature logout issue.
OIDC_RENEW_ID_TOKEN_EXPIRY_SECONDS = 0.8 * 3600
OIDC_RP_SIGN_ALGO = "RS256"
OIDC_USERNAME_ALGO = "oidc.generate_username"
OIDC_RP_SCOPES = "openid email profile"

LOGIN_REDIRECT_URL = "/"
LOGOUT_REDIRECT_URL = "/"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
... | |
from oidc import WagtailOIDCLogoutView

# Django only looks for a module-level list named ``urlpatterns``.
urlpatterns = [
    # override the wagtail logout view with an OIDC specific version
    path("logout/", WagtailOIDCLogoutView.as_view(), name="wagtailadmin_logout"),
    path("admin/", admin.site.urls),
    path("oidc/", include("mozilla_django_oidc.urls")),
    # ...
]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment