15
15
16
16
17
17
class AdfsBaseBackend (ModelBackend ):
18
- def exchange_auth_code (self , authorization_code , request ):
19
- logger .debug ("Received authorization code: %s" , authorization_code )
20
- data = {
21
- 'grant_type' : 'authorization_code' ,
22
- 'client_id' : settings .CLIENT_ID ,
23
- 'redirect_uri' : provider_config .redirect_uri (request ),
24
- 'code' : authorization_code ,
25
- }
26
- if settings .CLIENT_SECRET :
27
- data ['client_secret' ] = settings .CLIENT_SECRET
28
18
29
- logger .debug ("Getting access token at: %s" , provider_config .token_endpoint )
30
- response = provider_config .session .post (provider_config .token_endpoint , data , timeout = settings .TIMEOUT )
19
+ def _ms_request (self , action , url , data = None , ** kwargs ):
20
+ """
21
+ Make a Microsoft Entra/GraphQL request
22
+
23
+
24
+ Args:
25
+ action (callable): The callable for making a request.
26
+ url (str): The URL the request should be sent to.
27
+ data (dict): Optional dictionary of data to be sent in the request.
28
+
29
+ Returns:
30
+ response: The response from the server. If it's not a 200, a
31
+ PermissionDenied is raised.
32
+ """
33
+ response = action (url , data = data , timeout = settings .TIMEOUT , ** kwargs )
31
34
# 200 = valid token received
32
35
# 400 = 'something' is wrong in our request
33
36
if response .status_code == 400 :
@@ -39,7 +42,21 @@ def exchange_auth_code(self, authorization_code, request):
39
42
if response .status_code != 200 :
40
43
logger .error ("Unexpected ADFS response: %s" , response .content .decode ())
41
44
raise PermissionDenied
45
+ return response
46
+
47
+ def exchange_auth_code (self , authorization_code , request ):
48
+ logger .debug ("Received authorization code: %s" , authorization_code )
49
+ data = {
50
+ 'grant_type' : 'authorization_code' ,
51
+ 'client_id' : settings .CLIENT_ID ,
52
+ 'redirect_uri' : provider_config .redirect_uri (request ),
53
+ 'code' : authorization_code ,
54
+ }
55
+ if settings .CLIENT_SECRET :
56
+ data ['client_secret' ] = settings .CLIENT_SECRET
42
57
58
+ logger .debug ("Getting access token at: %s" , provider_config .token_endpoint )
59
+ response = self ._ms_request (provider_config .session .post , provider_config .token_endpoint , data )
43
60
adfs_response = response .json ()
44
61
return adfs_response
45
62
@@ -66,17 +83,7 @@ def get_obo_access_token(self, access_token):
66
83
else :
67
84
data ["resource" ] = 'https://graph.microsoft.com'
68
85
69
- response = provider_config .session .get (provider_config .token_endpoint , data = data , timeout = settings .TIMEOUT )
70
- # 200 = valid token received
71
- # 400 = 'something' is wrong in our request
72
- if response .status_code == 400 :
73
- logger .error ("ADFS server returned an error: %s" , response .json ()["error_description" ])
74
- raise PermissionDenied
75
-
76
- if response .status_code != 200 :
77
- logger .error ("Unexpected ADFS response: %s" , response .content .decode ())
78
- raise PermissionDenied
79
-
86
+ response = self ._ms_request (provider_config .session .get , provider_config .token_endpoint , data )
80
87
obo_access_token = response .json ()["access_token" ]
81
88
logger .debug ("Received OBO access token: %s" , obo_access_token )
82
89
return obo_access_token
@@ -95,17 +102,11 @@ def get_group_memberships_from_ms_graph(self, obo_access_token):
95
102
provider_config .msgraph_endpoint
96
103
)
97
104
headers = {"Authorization" : "Bearer {}" .format (obo_access_token )}
98
- response = provider_config .session .get (graph_url , headers = headers , timeout = settings .TIMEOUT )
99
- # 200 = valid token received
100
- # 400 = 'something' is wrong in our request
101
- if response .status_code in [400 , 401 ]:
102
- logger .error ("MS Graph server returned an error: %s" , response .json ()["message" ])
103
- raise PermissionDenied
104
-
105
- if response .status_code != 200 :
106
- logger .error ("Unexpected MS Graph response: %s" , response .content .decode ())
107
- raise PermissionDenied
108
-
105
+ response = self ._ms_request (
106
+ action = provider_config .session .get ,
107
+ url = graph_url ,
108
+ headers = headers ,
109
+ )
109
110
claim_groups = []
110
111
for group_data in response .json ()["value" ]:
111
112
if group_data ["displayName" ] is None :
0 commit comments