-
-
Save data-scientist-ml1/06a6d70d9700d4f24e61ead55423d1c2 to your computer and use it in GitHub Desktop.
DRF simple JWT logout flow
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from django.utils.text import gettext_lazy as _ | |
from rest_framework import serializers | |
from rest_framework_simplejwt.tokens import RefreshToken, TokenError | |
class RefreshTokenSerializer(serializers.Serializer):
    """Accept a refresh token and blacklist it on ``save()`` (logout).

    NOTE(review): ``RefreshToken.blacklist()`` needs simplejwt's
    ``token_blacklist`` app installed — confirm it is in INSTALLED_APPS.
    """

    refresh = serializers.CharField()

    default_error_messages = {
        'bad_token': _('Token is invalid or expired'),
    }

    def save(self, **kwargs):
        """Blacklist the validated refresh token.

        Must be called after a successful ``is_valid()``. The token is read
        from ``validated_data`` instead of state stashed during validation, so
        calling ``save()`` too early fails with DRF's explicit assertion rather
        than an ``AttributeError``.

        Raises:
            ValidationError: via ``self.fail('bad_token')`` when simplejwt
                raises ``TokenError`` for the submitted token.
        """
        token = self.validated_data['refresh']
        try:
            RefreshToken(token).blacklist()
        except TokenError:
            self.fail('bad_token')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from datetime import timedelta
from functools import partial
from unittest import mock

from rest_framework.reverse import reverse
from rest_framework.test import APITestCase
from rest_framework_simplejwt.exceptions import TokenError
from rest_framework_simplejwt.tokens import RefreshToken
from rest_framework_simplejwt.utils import aware_utcnow

from apps.account.models import User
class TestLoginCase(APITestCase):
    """End-to-end tests for the JWT login / logout flow.

    Logout blacklists the refresh token; these tests also check that the
    access token keeps working until it expires on its own.
    """

    login_url = reverse('token_obtain_pair')
    refresh_token_url = reverse('token_refresh')
    logout_url = reverse('logout')
    # NOTE(review): this value looks like an email-obfuscation artifact from
    # the page it was copied from; it works as long as login uses the same value.
    email = '[email protected]'
    password = 'kah2ie3urh4k'

    @property
    def profile_url(self):
        """URL of an endpoint that requires authentication.

        NOTE(review): the original tests used ``self.profile_url`` without
        defining it. 'profile' is an assumed URL name — confirm against the
        project's URLconf. It is resolved lazily so a wrong name breaks only
        the tests that use it, not the whole module at import time.
        """
        return reverse('profile')

    def setUp(self):
        # Password passed by keyword: positional order differs between Django's
        # default UserManager (username, email, password) and email-based
        # managers. NOTE(review): assumes create_user(email, password=...).
        self.user = User.objects.create_user(self.email, password=self.password)

    def _login(self):
        """Obtain a token pair and authenticate the client with the access token.

        Returns:
            ``(status_code, json_body)`` of the token-obtain response.
        """
        data = {'email': self.email, 'password': self.password}
        r = self.client.post(self.login_url, data)
        body = r.json()
        if 'access' in body:
            self.client.credentials(
                HTTP_AUTHORIZATION='Bearer %s' % body['access'])
        return r.status_code, body

    def test_logout_response_200(self):
        # Despite the name, a successful logout answers 204 with an empty body.
        _, body = self._login()
        data = {'refresh': body['refresh']}
        r = self.client.post(self.logout_url, data)
        body = r.content
        self.assertEqual(r.status_code, 204, body)
        self.assertFalse(body, body)

    def test_logout_with_bad_refresh_token_response_400(self):
        self._login()
        data = {'refresh': 'dsf.sdfsdf.sdf'}
        r = self.client.post(self.logout_url, data)
        body = r.json()
        self.assertEqual(r.status_code, 400, body)
        self.assertTrue(body, body)

    def test_logout_refresh_token_in_blacklist(self):
        _, body = self._login()
        r = self.client.post(self.logout_url, body)
        self.assertEqual(r.status_code, 204, r.content)
        # Re-parsing the now-blacklisted refresh token is expected to fail.
        with self.assertRaises(TokenError):
            RefreshToken(body['refresh'])

    def test_access_token_still_valid_after_logout(self):
        _, body = self._login()
        self.client.post(self.logout_url, body)
        r = self.client.get(self.profile_url)
        body = r.json()
        self.assertEqual(r.status_code, 200, body)
        self.assertTrue(body, body)

    def test_access_token_invalid_in_hour_after_logout(self):
        _, body = self._login()
        self.client.post(self.logout_url, body)
        # Shift simplejwt's notion of "now" one hour ahead.
        # NOTE(review): assumes ACCESS_TOKEN_LIFETIME < 60 minutes and that
        # simplejwt.tokens reads the time via its module-level aware_utcnow.
        later = aware_utcnow() + timedelta(minutes=60)
        with mock.patch('rest_framework_simplejwt.tokens.aware_utcnow',
                        return_value=later):
            r = self.client.get(self.profile_url)
        body = r.json()
        self.assertEqual(r.status_code, 401, body)
        self.assertTrue(body, body)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from rest_framework import permissions, status | |
from rest_framework.generics import GenericAPIView | |
from rest_framework.response import Response | |
from .serializers import RefreshTokenSerializer | |
class LogoutView(GenericAPIView):
    """Log out by blacklisting the refresh token sent in the request body.

    Expects ``{"refresh": "<token>"}``. On success it answers 204 No Content.
    An invalid token is rejected with a validation error raised by the
    serializer. Only authenticated callers are allowed.
    """

    serializer_class = RefreshTokenSerializer
    permission_classes = (permissions.IsAuthenticated, )

    def post(self, request, *args, **kwargs):
        # Accept **kwargs: DRF calls handler(request, *args, **kwargs), so a
        # URL pattern with named groups would otherwise raise TypeError.
        serializer = self.get_serializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        serializer.save()
        return Response(status=status.HTTP_204_NO_CONTENT)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.