Skip to content

Commit 3fd34c7

Browse files
committed
Consolidate request handling logic into helper method.
The _ms_request conslidates the error handling logic into a single method. This should make overriding methods easier for end users.
1 parent 5ea3c87 commit 3fd34c7

File tree

2 files changed

+41
-35
lines changed

2 files changed

+41
-35
lines changed

django_auth_adfs/backend.py

Lines changed: 35 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,22 @@
1515

1616

1717
class AdfsBaseBackend(ModelBackend):
18-
def exchange_auth_code(self, authorization_code, request):
19-
logger.debug("Received authorization code: %s", authorization_code)
20-
data = {
21-
'grant_type': 'authorization_code',
22-
'client_id': settings.CLIENT_ID,
23-
'redirect_uri': provider_config.redirect_uri(request),
24-
'code': authorization_code,
25-
}
26-
if settings.CLIENT_SECRET:
27-
data['client_secret'] = settings.CLIENT_SECRET
2818

29-
logger.debug("Getting access token at: %s", provider_config.token_endpoint)
30-
response = provider_config.session.post(provider_config.token_endpoint, data, timeout=settings.TIMEOUT)
19+
def _ms_request(self, action, url, data=None, **kwargs):
20+
"""
21+
Make a Microsoft Entra/GraphQL request
22+
23+
24+
Args:
25+
action (callable): The callable for making a request.
26+
url (str): The URL the request should be sent to.
27+
data (dict): Optional dictionary of data to be sent in the request.
28+
29+
Returns:
30+
response: The response from the server. If it's not a 200, a
31+
PermissionDenied is raised.
32+
"""
33+
response = action(url, data=data, timeout=settings.TIMEOUT, **kwargs)
3134
# 200 = valid token received
3235
# 400 = 'something' is wrong in our request
3336
if response.status_code == 400:
@@ -39,7 +42,21 @@ def exchange_auth_code(self, authorization_code, request):
3942
if response.status_code != 200:
4043
logger.error("Unexpected ADFS response: %s", response.content.decode())
4144
raise PermissionDenied
45+
return response
46+
47+
def exchange_auth_code(self, authorization_code, request):
48+
logger.debug("Received authorization code: %s", authorization_code)
49+
data = {
50+
'grant_type': 'authorization_code',
51+
'client_id': settings.CLIENT_ID,
52+
'redirect_uri': provider_config.redirect_uri(request),
53+
'code': authorization_code,
54+
}
55+
if settings.CLIENT_SECRET:
56+
data['client_secret'] = settings.CLIENT_SECRET
4257

58+
logger.debug("Getting access token at: %s", provider_config.token_endpoint)
59+
response = self._ms_request(provider_config.session.post, provider_config.token_endpoint, data)
4360
adfs_response = response.json()
4461
return adfs_response
4562

@@ -66,17 +83,7 @@ def get_obo_access_token(self, access_token):
6683
else:
6784
data["resource"] = 'https://graph.microsoft.com'
6885

69-
response = provider_config.session.get(provider_config.token_endpoint, data=data, timeout=settings.TIMEOUT)
70-
# 200 = valid token received
71-
# 400 = 'something' is wrong in our request
72-
if response.status_code == 400:
73-
logger.error("ADFS server returned an error: %s", response.json()["error_description"])
74-
raise PermissionDenied
75-
76-
if response.status_code != 200:
77-
logger.error("Unexpected ADFS response: %s", response.content.decode())
78-
raise PermissionDenied
79-
86+
response = self._ms_request(provider_config.session.get, provider_config.token_endpoint, data)
8087
obo_access_token = response.json()["access_token"]
8188
logger.debug("Received OBO access token: %s", obo_access_token)
8289
return obo_access_token
@@ -95,17 +102,11 @@ def get_group_memberships_from_ms_graph(self, obo_access_token):
95102
provider_config.msgraph_endpoint
96103
)
97104
headers = {"Authorization": "Bearer {}".format(obo_access_token)}
98-
response = provider_config.session.get(graph_url, headers=headers, timeout=settings.TIMEOUT)
99-
# 200 = valid token received
100-
# 400 = 'something' is wrong in our request
101-
if response.status_code in [400, 401]:
102-
logger.error("MS Graph server returned an error: %s", response.json()["message"])
103-
raise PermissionDenied
104-
105-
if response.status_code != 200:
106-
logger.error("Unexpected MS Graph response: %s", response.content.decode())
107-
raise PermissionDenied
108-
105+
response = self._ms_request(
106+
action=provider_config.session.get,
107+
url=graph_url,
108+
headers=headers,
109+
)
109110
claim_groups = []
110111
for group_data in response.json()["value"]:
111112
if group_data["displayName"] is None:

django_auth_adfs/rest_framework.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
BaseAuthentication, get_authorization_header
77
)
88

9+
from django_auth_adfs.exceptions import MFARequired
10+
911

1012
class AdfsAccessTokenAuthentication(BaseAuthentication):
1113
"""
@@ -33,7 +35,10 @@ def authenticate(self, request):
3335
# Authenticate the user
3436
# The AdfsAuthCodeBackend authentication backend will notice the "access_token" parameter
3537
# and skip the request for an access token using the authorization code
36-
user = authenticate(access_token=auth[1])
38+
try:
39+
user = authenticate(access_token=auth[1])
40+
except MFARequired as e:
41+
raise exceptions.AuthenticationFailed('MFA auth is required.') from e
3742

3843
if user is None:
3944
raise exceptions.AuthenticationFailed('Invalid access token.')

0 commit comments

Comments
 (0)