diff --git a/django_auth_adfs/backend.py b/django_auth_adfs/backend.py index c3165cf..20abd24 100644 --- a/django_auth_adfs/backend.py +++ b/django_auth_adfs/backend.py @@ -1,4 +1,5 @@ import logging +from datetime import datetime, timedelta import jwt from django.contrib.auth import get_user_model @@ -424,6 +425,63 @@ def authenticate(self, request=None, authorization_code=None, **kwargs): return user +class AdfsAuthCodeRefreshBackend(AdfsBaseBackend): + """ + Authentication backend that supports storing and refreshing ADFS tokens in the session. + Use this backend in conjunction with AdfsRefreshMiddleware. + """ + def authenticate(self, request=None, authorization_code=None, **kwargs): + # If there's no token or code, we pass control to the next authentication backend + if authorization_code is None or authorization_code == '': + logger.debug("Authentication backend was called but no authorization code was received") + return + + # If there's no request object, we pass control to the next authentication backend + if request is None: + logger.debug("Authentication backend was called without request") + return + + # If loaded data is too old, reload it again + provider_config.load_config() + + adfs_response = self.exchange_auth_code(authorization_code, request) + access_token = adfs_response["access_token"] + user = self.process_access_token(access_token, adfs_response) + self._store_adfs_tokens_in_session(request, adfs_response) + return user + + def ensure_valid_access_token(self, request): + now = datetime.now() + settings.REFRESH_THRESHOLD + expiry = datetime.fromisoformat(request.session["_adfs_token_expiry"]) + if now > expiry: + adfs_refresh_response = self._refresh_access_token( + request.session["_adfs_refresh_token"] + ) + self._store_adfs_tokens_in_session(request, adfs_refresh_response) + + def _refresh_access_token(self, refresh_token): + provider_config.load_config() + response = provider_config.session.post( + provider_config.token_endpoint, + data=f'client_id={settings.CLIENT_ID}&client_secret={settings.CLIENT_SECRET}&grant_type=refresh_token' + + f'&refresh_token={refresh_token}' + ) + response.raise_for_status() + adfs_response = response.json() + return adfs_response + + def _store_adfs_tokens_in_session(self, request, adfs_response): + assert "refresh_token" in adfs_response, ( + "AdfsAuthCodeRefreshBackend requires a refresh token to function correctly. " + "Make sure your ADFS server is configured to return a refresh token." + ) + request.session["_adfs_access_token"] = adfs_response["access_token"] + expiry = datetime.now() + timedelta(seconds=int(adfs_response["expires_in"])) + request.session["_adfs_token_expiry"] = expiry.isoformat() + request.session["_adfs_refresh_token"] = adfs_response["refresh_token"] + request.session.save() + + class AdfsAccessTokenBackend(AdfsBaseBackend): """ Authentication backend to allow authenticating users against a diff --git a/django_auth_adfs/config.py b/django_auth_adfs/config.py index 317781f..9964be2 100644 --- a/django_auth_adfs/config.py +++ b/django_auth_adfs/config.py @@ -72,6 +72,7 @@ def __init__(self): self.USERNAME_CLAIM = "winaccountname" self.GUEST_USERNAME_CLAIM = None self.JWT_LEEWAY = 0 + self.REFRESH_THRESHOLD = timedelta(minutes=5) self.CUSTOM_FAILED_RESPONSE_VIEW = lambda request, error_message, status: render( request, 'django_auth_adfs/login_failed.html', {'error_message': error_message}, status=status ) diff --git a/django_auth_adfs/middleware.py b/django_auth_adfs/middleware.py index 649a239..2c506fd 100644 --- a/django_auth_adfs/middleware.py +++ b/django_auth_adfs/middleware.py @@ -1,12 +1,19 @@ """ Based on https://djangosnippets.org/snippets/1179/ """ +import logging from re import compile +from requests import HTTPError from django.conf import settings as django_settings +from django.contrib import auth from django.contrib.auth.views import redirect_to_login +from django.contrib.auth import logout +from django.core.exceptions import (PermissionDenied) + from django.urls import reverse +from django_auth_adfs.backend import AdfsAuthCodeRefreshBackend from django_auth_adfs.exceptions import MFARequired from django_auth_adfs.config import settings @@ -19,6 +26,8 @@ if hasattr(settings, 'LOGIN_EXEMPT_URLS'): LOGIN_EXEMPT_URLS += [compile(expr) for expr in settings.LOGIN_EXEMPT_URLS] +logger = logging.getLogger("django_auth_adfs") + class LoginRequiredMiddleware: """ @@ -49,3 +58,41 @@ def __call__(self, request): return redirect_to_login('django_auth_adfs:login-force-mfa') return self.get_response(request) + + +class AdfsRefreshMiddleware: + """ + Middleware that refreshes the access token for the user if it is close to + expiring. This is done by checking the session for the '_adfs_token_expiry' + key and comparing it with the current time plus a threshold defined in + settings.REFRESH_THRESHOLD. + """ + + def __init__(self, get_response): + self.get_response = get_response + + def __call__(self, request): + if hasattr(django_settings, "SESSION_ENGINE"): + assert ( + django_settings.SESSION_ENGINE + != "django.contrib.sessions.backends.signed_cookies" + ), ( + "You are trying to use ADFS Refresh middleware with signed cookie-based sessions. " + "For security reasons, we do not recommend this configuration. " + "Please change SESSION_ENGINE to a different backend, such as 'django.contrib.sessions.backends.db' " + ) + + try: + backend_str = request.session[auth.BACKEND_SESSION_KEY] + except KeyError: + pass + else: + backend = auth.load_backend(backend_str) + if isinstance(backend, AdfsAuthCodeRefreshBackend): + try: + backend.ensure_valid_access_token(request) + except (PermissionDenied, HTTPError) as error: + logger.debug("Error refreshing access token: %s", error) + logout(request) + + return self.get_response(request) diff --git a/tests/test_authentication.py b/tests/test_authentication.py index aabe899..37fa615 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -1,5 +1,9 @@ import base64 +from datetime import datetime, timedelta + +from django.urls import reverse + from django_auth_adfs.exceptions import MFARequired try: @@ -10,13 +14,12 @@ from copy import deepcopy from django.contrib.auth.models import Group, User -from django.core.exceptions import ObjectDoesNotExist, PermissionDenied +from django.core.exceptions import ObjectDoesNotExist from django.db.models.signals import post_save -from django.test import RequestFactory, TestCase +from django.test import TestCase, override_settings from mock import Mock, patch from django_auth_adfs import signals -from django_auth_adfs.backend import AdfsAuthCodeBackend from django_auth_adfs.config import ProviderConfig, Settings from .models import Profile @@ -28,20 +31,18 @@ def setUp(self): Group.objects.create(name='group1') Group.objects.create(name='group2') Group.objects.create(name='group3') - self.request = RequestFactory().get('/oauth2/callback') self.signal_handler = Mock() signals.post_authenticate.connect(self.signal_handler) @mock_adfs("2012") def test_post_authenticate_signal_send(self): - backend = AdfsAuthCodeBackend() - backend.authenticate(self.request, authorization_code="dummycode") + self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) self.assertEqual(self.signal_handler.call_count, 1) @mock_adfs("2012") def test_with_auth_code_2012(self): - backend = AdfsAuthCodeBackend() - user = backend.authenticate(self.request, authorization_code="dummycode") + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) + user = response.wsgi_request.user self.assertIsInstance(user, User) self.assertEqual(user.first_name, "John") self.assertEqual(user.last_name, "Doe") @@ -52,8 +53,8 @@ def test_with_auth_code_2012(self): @mock_adfs("2016") def test_with_auth_code_2016(self): - backend = AdfsAuthCodeBackend() - user = backend.authenticate(self.request, authorization_code="dummycode") + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) + user = response.wsgi_request.user self.assertIsInstance(user, User) self.assertEqual(user.first_name, "John") self.assertEqual(user.last_name, "Doe") @@ -64,9 +65,15 @@ def test_with_auth_code_2016(self): @mock_adfs("2016", mfa_error=True) def test_mfa_error_backends(self): - with self.assertRaises(MFARequired): - backend = AdfsAuthCodeBackend() - backend.authenticate(self.request, authorization_code="dummycode") + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) + self.assertEqual(response.status_code, 302) + self.assertEqual( + response['Location'], + "https://adfs.example.com/adfs/oauth2/authorize/?response_type=code&" + "client_id=your-configured-client-id&resource=your-adfs-RPT-name&" + "redirect_uri=http%3A%2F%2Ftestserver%2Foauth2%2Fcallback&state=Lw%3D%3D&scope=openid&" + "amr_values=ngcmfa" + ) @mock_adfs("azure") def test_with_auth_code_azure(self): @@ -77,8 +84,8 @@ def test_with_auth_code_azure(self): with patch("django_auth_adfs.config.django_settings", settings): with patch("django_auth_adfs.config.settings", Settings()): with patch("django_auth_adfs.backend.provider_config", ProviderConfig()): - backend = AdfsAuthCodeBackend() - user = backend.authenticate(self.request, authorization_code="dummycode") + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) + user = response.wsgi_request.user self.assertIsInstance(user, User) self.assertEqual(user.first_name, "John") self.assertEqual(user.last_name, "Doe") @@ -100,9 +107,8 @@ def test_with_auth_code_azure_guest_block(self): with patch('django_auth_adfs.backend.settings', Settings()): with patch("django_auth_adfs.config.settings", Settings()): with patch("django_auth_adfs.backend.provider_config", ProviderConfig()): - with self.assertRaises(PermissionDenied, msg=''): - backend = AdfsAuthCodeBackend() - _ = backend.authenticate(self.request, authorization_code="dummycode") + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) + self.assertEqual(response.status_code, 401) @mock_adfs("azure", guest=True) def test_with_auth_code_azure_guest_no_block(self): @@ -117,8 +123,8 @@ def test_with_auth_code_azure_guest_no_block(self): with patch('django_auth_adfs.backend.settings', Settings()): with patch("django_auth_adfs.config.settings", Settings()): with patch("django_auth_adfs.backend.provider_config", ProviderConfig()): - backend = AdfsAuthCodeBackend() - user = backend.authenticate(self.request, authorization_code="dummycode") + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) + user = response.wsgi_request.user self.assertIsInstance(user, User) self.assertEqual(user.first_name, "John") self.assertEqual(user.last_name, "Doe") @@ -139,8 +145,8 @@ def test_version_two_endpoint_calls_correct_url(self): with patch('django_auth_adfs.backend.settings', Settings()): with patch("django_auth_adfs.config.settings", Settings()): with patch("django_auth_adfs.backend.provider_config", ProviderConfig()): - backend = AdfsAuthCodeBackend() - user = backend.authenticate(self.request, authorization_code="dummycode") + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) + user = response.wsgi_request.user self.assertIsInstance(user, User) self.assertEqual(user.first_name, "John") self.assertEqual(user.last_name, "Doe") @@ -151,14 +157,15 @@ def test_version_two_endpoint_calls_correct_url(self): @mock_adfs("2016") def test_empty(self): - backend = AdfsAuthCodeBackend() - self.assertIsNone(backend.authenticate(self.request)) + response = self.client.get(reverse('django_auth_adfs:callback')) + user = response.wsgi_request.user + self.assertTrue(user.is_anonymous) @mock_adfs("2016") def test_group_claim(self): - backend = AdfsAuthCodeBackend() with patch("django_auth_adfs.backend.settings.GROUPS_CLAIM", "nonexisting"): - user = backend.authenticate(self.request, authorization_code="dummycode") + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) + user = response.wsgi_request.user self.assertIsInstance(user, User) self.assertEqual(user.first_name, "John") self.assertEqual(user.last_name, "Doe") @@ -167,9 +174,9 @@ def test_group_claim(self): @mock_adfs("2016") def test_no_group_claim(self): - backend = AdfsAuthCodeBackend() with patch("django_auth_adfs.backend.settings.GROUPS_CLAIM", None): - user = backend.authenticate(self.request, authorization_code="dummycode") + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) + user = response.wsgi_request.user self.assertIsInstance(user, User) self.assertEqual(user.first_name, "John") self.assertEqual(user.last_name, "Doe") @@ -181,9 +188,9 @@ def test_group_claim_with_mirror_groups(self): # Remove one group Group.objects.filter(name="group1").delete() - backend = AdfsAuthCodeBackend() with patch("django_auth_adfs.backend.settings.MIRROR_GROUPS", True): - user = backend.authenticate(self.request, authorization_code="dummycode") + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) + user = response.wsgi_request.user self.assertIsInstance(user, User) self.assertEqual(user.first_name, "John") self.assertEqual(user.last_name, "Doe") @@ -197,9 +204,9 @@ def test_group_claim_without_mirror_groups(self): # Remove one group Group.objects.filter(name="group1").delete() - backend = AdfsAuthCodeBackend() with patch("django_auth_adfs.backend.settings.MIRROR_GROUPS", False): - user = backend.authenticate(self.request, authorization_code="dummycode") + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) + user = response.wsgi_request.user self.assertIsInstance(user, User) self.assertEqual(user.first_name, "John") self.assertEqual(user.last_name, "Doe") @@ -210,9 +217,9 @@ def test_group_claim_without_mirror_groups(self): @mock_adfs("2016", empty_keys=True) def test_empty_keys(self): - backend = AdfsAuthCodeBackend() with patch("django_auth_adfs.config.provider_config.signing_keys", []): - self.assertRaises(PermissionDenied, backend.authenticate, self.request, authorization_code='testcode') + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "testcode"}) + self.assertEqual(response.status_code, 401) @mock_adfs("2016") def test_group_removal(self): @@ -227,9 +234,8 @@ def test_group_removal(self): self.assertEqual(user.groups.all()[0].name, "group3") self.assertEqual(len(user.groups.all()), 1) - backend = AdfsAuthCodeBackend() - - user = backend.authenticate(self.request, authorization_code="dummycode") + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) + user = response.wsgi_request.user self.assertIsInstance(user, User) self.assertEqual(user.first_name, "John") self.assertEqual(user.last_name, "Doe") @@ -253,9 +259,8 @@ def test_group_removal_overlap(self): self.assertEqual(user.groups.all()[1].name, "group3") self.assertEqual(len(user.groups.all()), 2) - backend = AdfsAuthCodeBackend() - - user = backend.authenticate(self.request, authorization_code="dummycode") + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) + user = response.wsgi_request.user self.assertIsInstance(user, User) self.assertEqual(user.first_name, "John") self.assertEqual(user.last_name, "Doe") @@ -272,9 +277,8 @@ def test_group_to_flag_mapping(self): } with patch("django_auth_adfs.backend.settings.GROUP_TO_FLAG_MAPPING", group_to_flag_mapping): with patch("django_auth_adfs.backend.settings.BOOLEAN_CLAIM_MAPPING", {}): - backend = AdfsAuthCodeBackend() - - user = backend.authenticate(self.request, authorization_code="dummycode") + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) + user = response.wsgi_request.user self.assertIsInstance(user, User) self.assertEqual(user.first_name, "John") self.assertEqual(user.last_name, "Doe") @@ -289,9 +293,8 @@ def test_boolean_claim_mapping(self): "is_superuser": "user_is_superuser", } with patch("django_auth_adfs.backend.settings.BOOLEAN_CLAIM_MAPPING", boolean_claim_mapping): - backend = AdfsAuthCodeBackend() - - user = backend.authenticate(self.request, authorization_code="dummycode") + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) + user = response.wsgi_request.user self.assertIsInstance(user, User) self.assertEqual(user.first_name, "John") self.assertEqual(user.last_name, "Doe") @@ -312,9 +315,8 @@ def test_extended_model_claim_mapping_missing_instance(self): }, } with patch("django_auth_adfs.backend.settings.CLAIM_MAPPING", claim_mapping): - backend = AdfsAuthCodeBackend() - - user = backend.authenticate(self.request, authorization_code="dummycode") + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) + user = response.wsgi_request.user self.assertIsInstance(user, User) self.assertEqual(user.first_name, "John") self.assertEqual(user.last_name, "Doe") @@ -340,9 +342,8 @@ def create_profile(sender, instance, created, **kwargs): }, } with patch("django_auth_adfs.backend.settings.CLAIM_MAPPING", claim_mapping): - backend = AdfsAuthCodeBackend() - - user = backend.authenticate(self.request, authorization_code="dummycode") + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) + user = response.wsgi_request.user self.assertIsInstance(user, User) self.assertEqual(user.first_name, "John") self.assertEqual(user.last_name, "Doe") @@ -521,5 +522,49 @@ def test_nonexisting_user(self): settings.AUTH_ADFS["CREATE_NEW_USERS"] = False with patch("django_auth_adfs.config.django_settings", settings), \ patch("django_auth_adfs.backend.settings", Settings()): - backend = AdfsAuthCodeBackend() - self.assertRaises(PermissionDenied, backend.authenticate, self.request, authorization_code='testcode') + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "testcode"}) + self.assertEqual(response.status_code, 401) + + @mock_adfs("2016") + def test_access_token_unexpired(self): + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "testcode"}) + self.assertFalse(response.wsgi_request.user.is_anonymous) + response = self.client.get(reverse('test')) + self.assertEqual(response.status_code, 200) + + @mock_adfs("2016") + def test_access_token_expired(self): + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "testcode"}) + self.assertFalse(response.wsgi_request.user.is_anonymous) + fromisoformat = datetime.fromisoformat + with patch('django_auth_adfs.backend.datetime') as dt: + dt.fromisoformat = fromisoformat + dt.now.return_value = datetime.now() + timedelta(hours=1) + response = self.client.get(reverse('test')) + self.assertEqual(response.status_code, 200) + + @override_settings(AUTHENTICATION_BACKENDS=['django_auth_adfs.backend.AdfsAuthCodeRefreshBackend']) + @override_settings( + MIDDLEWARE=[ + 'django.contrib.sessions.middleware.SessionMiddleware', + 'django.middleware.common.CommonMiddleware', + 'django.middleware.csrf.CsrfViewMiddleware', + 'django.contrib.auth.middleware.AuthenticationMiddleware', + 'django.contrib.messages.middleware.MessageMiddleware', + 'django.middleware.clickjacking.XFrameOptionsMiddleware', + "django_auth_adfs.middleware.AdfsRefreshMiddleware", + "django_auth_adfs.middleware.LoginRequiredMiddleware", + ] + ) + @mock_adfs("2016", refresh_token_expired=True) + def test_refresh_token_expired(self): + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "testcode"}) + self.assertFalse(response.wsgi_request.user.is_anonymous) + fromisoformat = datetime.fromisoformat + with patch('django_auth_adfs.backend.datetime') as dt: + dt.fromisoformat = fromisoformat + dt.now.return_value = datetime.now() + timedelta(hours=2) + response = self.client.get(reverse('test')) + self.assertEqual(response.status_code, 302) + self.assertEqual(response['Location'], f"{reverse('django_auth_adfs:login')}?next=/") + self.assertTrue(response.wsgi_request.user.is_anonymous) diff --git a/tests/urls.py b/tests/urls.py index e3a608d..9ad8a6e 100644 --- a/tests/urls.py +++ b/tests/urls.py @@ -1,6 +1,9 @@ -from django.urls import include, re_path +from django.urls import include, re_path, path + +from tests.views import TestView urlpatterns = [ + path('', TestView.as_view(), name='test'), re_path(r'^oauth2/', include('django_auth_adfs.urls')), re_path(r'^oauth2/', include('django_auth_adfs.drf_urls')), ] diff --git a/tests/utils.py b/tests/utils.py index f6040d2..bda61c6 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -5,6 +5,7 @@ import time from datetime import datetime, tzinfo, timedelta from functools import partial +from urllib.parse import parse_qs import jwt import responses @@ -98,9 +99,14 @@ def build_access_token_azure_groups_in_claim_source(request): return do_build_access_token(request, issuer, groups_in_claim_names=True) +def build_access_token_adfs_expired(request): + issuer = "http://adfs.example.com/adfs/services/trust" + return do_build_access_token(request, issuer, refresh_token_expired=True) + + def do_build_mfa_error(request): response = {'error_description': 'AADSTS50076'} - return 400, [], json.dumps(response) + return 400, {}, json.dumps(response) def do_build_graph_response(request): @@ -111,7 +117,11 @@ def do_build_graph_response_no_group_perm(request): return do_build_ms_graph_groups(request, missing_group_names=True) -def do_build_access_token(request, issuer, schema=None, no_upn=False, idp=None, groups_in_claim_names=False): +def do_build_access_token(request, issuer, schema=None, no_upn=False, idp=None, groups_in_claim_names=False, + refresh_token_expired=False): + data = parse_qs(request.body) + if data.get('grant_type') == ['refresh_token'] and data.get('refresh_token') == ['expired_refresh_token']: + return 401, {}, None issued_at = int(time.time()) expires = issued_at + 3600 auth_time = datetime.utcnow() @@ -159,16 +169,20 @@ def do_build_access_token(request, issuer, schema=None, no_upn=False, idp=None, } } token = jwt.encode(claims, signing_key_b, algorithm="RS256") + if refresh_token_expired: + refresh_token = 'expired_refresh_token' + else: + refresh_token = 'random_refresh_token' response = { 'resource': 'django_website.adfs.relying_party_id', 'token_type': 'bearer', 'refresh_token_expires_in': 28799, - 'refresh_token': 'random_refresh_token', + 'refresh_token': refresh_token, 'expires_in': 3600, 'id_token': 'not_used', 'access_token': token.decode() if isinstance(token, bytes) else token # PyJWT>=2 returns a str instead of bytes } - return 200, [], json.dumps(response) + return 200, {}, json.dumps(response) def do_build_obo_access_token(request): @@ -228,7 +242,7 @@ def do_build_obo_access_token(request): 'refresh_token': 'not_used', 'access_token': token.decode() if isinstance(token, bytes) else token # PyJWT>=2 returns a str instead of bytes } - return 200, [], json.dumps(response) + return 200, {}, json.dumps(response) def do_build_ms_graph_groups(request, missing_group_names=False): @@ -308,7 +322,7 @@ def do_build_ms_graph_groups(request, missing_group_names=False): if missing_group_names: for group in response["value"]: group["displayName"] = None - return 200, [], json.dumps(response) + return 200, {}, json.dumps(response) def build_openid_keys(request, empty_keys=False): @@ -337,7 +351,7 @@ def build_openid_keys(request, empty_keys=False): }, ] } - return 200, [], json.dumps(keys) + return 200, {}, json.dumps(keys) def build_adfs_meta(request): @@ -345,7 +359,7 @@ def build_adfs_meta(request): data = "".join(f.readlines()) data = data.replace("REPLACE_WITH_CERT_A", base64.b64encode(signing_cert_a).decode()) data = data.replace("REPLACE_WITH_CERT_B", base64.b64encode(signing_cert_b).decode()) - return 200, [], data + return 200, {}, data def mock_adfs( @@ -356,6 +370,7 @@ def mock_adfs( version=None, requires_obo=False, missing_graph_group_perm=False, + refresh_token_expired=False, ): if adfs_version not in ["2012", "2016", "azure"]: raise NotImplementedError("This version of ADFS is not implemented") @@ -465,6 +480,12 @@ def wrapper(*original_args, **original_kwargs): callback=do_build_mfa_error, content_type='application/json', ) + elif refresh_token_expired: + rsps.add_callback( + rsps.POST, token_endpoint, + callback=build_access_token_adfs_expired, + content_type='application/json', + ) else: rsps.add_callback( rsps.POST, token_endpoint, diff --git a/tests/views.py b/tests/views.py index b16e402..7bb0bed 100644 --- a/tests/views.py +++ b/tests/views.py @@ -1,2 +1,11 @@ +from django.http import HttpResponse +from django.views import View + + def test_failed_response(request, error_message, status): pass + + +class TestView(View): + def get(self, request): + return HttpResponse('okay')