4
4
from django .contrib .auth import get_user_model
5
5
from django .contrib .auth .backends import ModelBackend
6
6
from django .contrib .auth .models import Group
7
- from django .core .exceptions import (ImproperlyConfigured , ObjectDoesNotExist ,
8
- PermissionDenied )
7
+ from django .core .exceptions import (
8
+ ImproperlyConfigured ,
9
+ ObjectDoesNotExist ,
10
+ PermissionDenied ,
11
+ )
9
12
10
13
from django_auth_adfs import signals
11
14
from django_auth_adfs .config import provider_config , settings
15
18
16
19
17
20
class AdfsBaseBackend (ModelBackend ):
18
-
19
21
def _ms_request (self , action , url , data = None , ** kwargs ):
20
22
"""
21
23
Make a Microsoft Entra/GraphQL request
@@ -36,7 +38,10 @@ def _ms_request(self, action, url, data=None, **kwargs):
36
38
if response .status_code == 400 :
37
39
if response .json ().get ("error_description" , "" ).startswith ("AADSTS50076" ):
38
40
raise MFARequired
39
- logger .error ("ADFS server returned an error: %s" , response .json ()["error_description" ])
41
+ logger .error (
42
+ "ADFS server returned an error: %s" ,
43
+ response .json ()["error_description" ],
44
+ )
40
45
raise PermissionDenied
41
46
42
47
if response .status_code != 200 :
@@ -47,16 +52,18 @@ def _ms_request(self, action, url, data=None, **kwargs):
47
52
def exchange_auth_code (self , authorization_code , request ):
48
53
logger .debug ("Received authorization code: %s" , authorization_code )
49
54
data = {
50
- ' grant_type' : ' authorization_code' ,
51
- ' client_id' : settings .CLIENT_ID ,
52
- ' redirect_uri' : provider_config .redirect_uri (request ),
53
- ' code' : authorization_code ,
55
+ " grant_type" : " authorization_code" ,
56
+ " client_id" : settings .CLIENT_ID ,
57
+ " redirect_uri" : provider_config .redirect_uri (request ),
58
+ " code" : authorization_code ,
54
59
}
55
60
if settings .CLIENT_SECRET :
56
- data [' client_secret' ] = settings .CLIENT_SECRET
61
+ data [" client_secret" ] = settings .CLIENT_SECRET
57
62
58
63
logger .debug ("Getting access token at: %s" , provider_config .token_endpoint )
59
- response = self ._ms_request (provider_config .session .post , provider_config .token_endpoint , data )
64
+ response = self ._ms_request (
65
+ provider_config .session .post , provider_config .token_endpoint , data
66
+ )
60
67
adfs_response = response .json ()
61
68
return adfs_response
62
69
@@ -79,11 +86,13 @@ def get_obo_access_token(self, access_token):
79
86
"requested_token_use" : "on_behalf_of" ,
80
87
}
81
88
if provider_config .token_endpoint .endswith ("/v2.0/token" ):
82
- data ["scope" ] = ' GroupMember.Read.All'
89
+ data ["scope" ] = " GroupMember.Read.All"
83
90
else :
84
- data ["resource" ] = ' https://graph.microsoft.com'
91
+ data ["resource" ] = " https://graph.microsoft.com"
85
92
86
- response = self ._ms_request (provider_config .session .get , provider_config .token_endpoint , data )
93
+ response = self ._ms_request (
94
+ provider_config .session .get , provider_config .token_endpoint , data
95
+ )
87
96
obo_access_token = response .json ()["access_token" ]
88
97
logger .debug ("Received OBO access token: %s" , obo_access_token )
89
98
return obo_access_token
@@ -117,8 +126,10 @@ def get_group_memberships_from_ms_graph(self, obo_access_token):
117
126
Returns:
118
127
claim_groups (list): List of the users group memberships
119
128
"""
120
- graph_url = "https://{}/v1.0/me/transitiveMemberOf/microsoft.graph.group" .format (
121
- provider_config .msgraph_endpoint
129
+ graph_url = (
130
+ "https://{}/v1.0/me/transitiveMemberOf/microsoft.graph.group" .format (
131
+ provider_config .msgraph_endpoint
132
+ )
122
133
)
123
134
headers = {"Authorization" : "Bearer {}" .format (obo_access_token )}
124
135
response = self ._ms_request (
@@ -147,25 +158,25 @@ def validate_access_token(self, access_token):
147
158
# Explicit is better then implicit and it protects against
148
159
# changes in the defaults the jwt module uses.
149
160
options = {
150
- ' verify_signature' : True ,
151
- ' verify_exp' : True ,
152
- ' verify_nbf' : True ,
153
- ' verify_iat' : True ,
154
- ' verify_aud' : True ,
155
- ' verify_iss' : True ,
156
- ' require_exp' : False ,
157
- ' require_iat' : False ,
158
- ' require_nbf' : False
161
+ " verify_signature" : True ,
162
+ " verify_exp" : True ,
163
+ " verify_nbf" : True ,
164
+ " verify_iat" : True ,
165
+ " verify_aud" : True ,
166
+ " verify_iss" : True ,
167
+ " require_exp" : False ,
168
+ " require_iat" : False ,
169
+ " require_nbf" : False ,
159
170
}
160
171
# Validate token and return claims
161
172
return jwt .decode (
162
173
access_token ,
163
174
key = key ,
164
- algorithms = [' RS256' , ' RS384' , ' RS512' ],
175
+ algorithms = [" RS256" , " RS384" , " RS512" ],
165
176
audience = settings .AUDIENCE ,
166
177
issuer = provider_config .issuer ,
167
178
options = options ,
168
- leeway = settings .JWT_LEEWAY
179
+ leeway = settings .JWT_LEEWAY ,
169
180
)
170
181
except jwt .ExpiredSignatureError as error :
171
182
logger .info ("Signature has expired: %s" , error )
@@ -175,7 +186,7 @@ def validate_access_token(self, access_token):
175
186
if idx < len (provider_config .signing_keys ) - 1 :
176
187
continue
177
188
else :
178
- logger .info (' Error decoding signature: %s' , error )
189
+ logger .info (" Error decoding signature: %s" , error )
179
190
raise PermissionDenied
180
191
except jwt .InvalidTokenError as error :
181
192
logger .info (str (error ))
@@ -187,12 +198,8 @@ def process_access_token(self, access_token, adfs_response=None):
187
198
188
199
logger .debug ("Received access token: %s" , access_token )
189
200
claims = self .validate_access_token (access_token )
190
- if (
191
- settings .BLOCK_GUEST_USERS
192
- and claims .get ('tid' )
193
- != settings .TENANT_ID
194
- ):
195
- logger .info ('Guest user denied' )
201
+ if settings .BLOCK_GUEST_USERS and claims .get ("tid" ) != settings .TENANT_ID :
202
+ logger .info ("Guest user denied" )
196
203
raise PermissionDenied
197
204
if not claims :
198
205
raise PermissionDenied
@@ -204,10 +211,7 @@ def process_access_token(self, access_token, adfs_response=None):
204
211
self .update_user_flags (user , claims , groups )
205
212
206
213
signals .post_authenticate .send (
207
- sender = self ,
208
- user = user ,
209
- claims = claims ,
210
- adfs_response = adfs_response
214
+ sender = self , user = user , claims = claims , adfs_response = adfs_response
211
215
)
212
216
213
217
user .full_clean ()
@@ -235,7 +239,9 @@ def process_user_groups(self, claims, access_token):
235
239
if settings .GROUPS_CLAIM in claims :
236
240
groups = claims [settings .GROUPS_CLAIM ]
237
241
if not isinstance (groups , list ):
238
- groups = [groups , ]
242
+ groups = [
243
+ groups ,
244
+ ]
239
245
elif (
240
246
settings .TENANT_ID != "adfs"
241
247
and "_claim_names" in claims
@@ -244,8 +250,10 @@ def process_user_groups(self, claims, access_token):
244
250
obo_access_token = self .get_obo_access_token (access_token )
245
251
groups = self .get_group_memberships_from_ms_graph (obo_access_token )
246
252
else :
247
- logger .debug ("The configured groups claim %s was not found in the access token" ,
248
- settings .GROUPS_CLAIM )
253
+ logger .debug (
254
+ "The configured groups claim %s was not found in the access token" ,
255
+ settings .GROUPS_CLAIM ,
256
+ )
249
257
250
258
return groups
251
259
@@ -264,19 +272,21 @@ def create_user(self, claims):
264
272
guest_username_claim = settings .GUEST_USERNAME_CLAIM
265
273
usermodel = get_user_model ()
266
274
267
- iss = claims .get (' iss' )
268
- idp = claims .get (' idp' , iss )
275
+ iss = claims .get (" iss" )
276
+ idp = claims .get (" idp" , iss )
269
277
if (
270
278
guest_username_claim
271
279
and not claims .get (username_claim )
272
280
and not settings .BLOCK_GUEST_USERS
273
- and (claims .get (' tid' ) != settings .TENANT_ID or iss != idp )
281
+ and (claims .get (" tid" ) != settings .TENANT_ID or iss != idp )
274
282
):
275
283
username_claim = guest_username_claim
276
284
277
285
if not claims .get (username_claim ):
278
- logger .error ("User claim's doesn't have the claim '%s' in his claims: %s" %
279
- (username_claim , claims ))
286
+ logger .error (
287
+ "User claim's doesn't have the claim '%s' in his claims: %s"
288
+ % (username_claim , claims )
289
+ )
280
290
raise PermissionDenied
281
291
282
292
userdata = {usermodel .USERNAME_FIELD : claims [username_claim ]}
@@ -288,7 +298,10 @@ def create_user(self, claims):
288
298
user = usermodel .objects .create (** userdata )
289
299
logger .debug ("User '%s' has been created." , claims [username_claim ])
290
300
else :
291
- logger .debug ("User '%s' doesn't exist and creating users is disabled." , claims [username_claim ])
301
+ logger .debug (
302
+ "User '%s' doesn't exist and creating users is disabled." ,
303
+ claims [username_claim ],
304
+ )
292
305
raise PermissionDenied
293
306
if not user .password :
294
307
user .set_unusable_password ()
@@ -308,27 +321,47 @@ def update_user_attributes(self, user, claims, claim_mapping=None):
308
321
"""
309
322
if claim_mapping is None :
310
323
claim_mapping = settings .CLAIM_MAPPING
311
- required_fields = [field .name for field in user ._meta .get_fields () if getattr (field , 'blank' , True ) is False ]
324
+ required_fields = [
325
+ field .name
326
+ for field in user ._meta .get_fields ()
327
+ if getattr (field , "blank" , True ) is False
328
+ ]
312
329
313
330
for field , claim in claim_mapping .items ():
314
331
if hasattr (user , field ) or user ._meta .fields_map .get (field ):
315
332
if not isinstance (claim , dict ):
316
333
if claim in claims :
317
334
setattr (user , field , claims [claim ])
318
- logger .debug ("Attribute '%s' for instance '%s' was set to '%s'." , field , user , claims [claim ])
335
+ logger .debug (
336
+ "Attribute '%s' for instance '%s' was set to '%s'." ,
337
+ field ,
338
+ user ,
339
+ claims [claim ],
340
+ )
319
341
else :
320
342
if field in required_fields :
321
343
msg = "Claim not found in access token: '{}'. Check ADFS claims mapping."
322
344
raise ImproperlyConfigured (msg .format (claim ))
323
345
else :
324
- logger .warning ("Claim '%s' for field '%s' was not found in "
325
- "the access token for instance '%s'. "
326
- "Field is not required and will be left empty" , claim , field , user )
346
+ logger .warning (
347
+ "Claim '%s' for field '%s' was not found in "
348
+ "the access token for instance '%s'. "
349
+ "Field is not required and will be left empty" ,
350
+ claim ,
351
+ field ,
352
+ user ,
353
+ )
327
354
else :
328
355
try :
329
- self .update_user_attributes (getattr (user , field ), claims , claim_mapping = claim )
356
+ self .update_user_attributes (
357
+ getattr (user , field ), claims , claim_mapping = claim
358
+ )
330
359
except ObjectDoesNotExist :
331
- logger .warning ("Object for field '{}' does not exist for: '{}'." .format (field , user ))
360
+ logger .warning (
361
+ "Object for field '{}' does not exist for: '{}'." .format (
362
+ field , user
363
+ )
364
+ )
332
365
333
366
else :
334
367
msg = "Model '{}' has no field named '{}'. Check ADFS claims mapping."
@@ -358,10 +391,13 @@ def update_user_groups(self, user, claim_groups):
358
391
# bulk_create could have been used here but we want to send signals.
359
392
new_claimed_groups = [
360
393
Group .objects .get_or_create (name = name )[0 ]
361
- for name in claim_groups if name not in existing_claimed_group_names
394
+ for name in claim_groups
395
+ if name not in existing_claimed_group_names
362
396
]
363
397
# Associate the users to all claimed groups
364
- user .groups .set (tuple (existing_claimed_groups ) + tuple (new_claimed_groups ))
398
+ user .groups .set (
399
+ tuple (existing_claimed_groups ) + tuple (new_claimed_groups )
400
+ )
365
401
else :
366
402
# Associate the user to only existing claimed groups
367
403
user .groups .set (existing_claimed_groups )
@@ -381,23 +417,42 @@ def update_user_flags(self, user, claims, claim_groups):
381
417
if not isinstance (group , list ):
382
418
group = [group ]
383
419
384
- if any (group_list_item in claim_groups for group_list_item in group ):
420
+ if any (
421
+ group_list_item in claim_groups for group_list_item in group
422
+ ):
385
423
value = True
386
424
else :
387
425
value = False
388
426
setattr (user , flag , value )
389
- logger .debug ("Attribute '%s' for user '%s' was set to '%s'." , flag , user , value )
427
+ logger .debug (
428
+ "Attribute '%s' for user '%s' was set to '%s'." ,
429
+ flag ,
430
+ user ,
431
+ value ,
432
+ )
390
433
else :
391
434
msg = "User model has no field named '{}'. Check ADFS boolean claims mapping."
392
435
raise ImproperlyConfigured (msg .format (flag ))
393
436
394
437
for field , claim in settings .BOOLEAN_CLAIM_MAPPING .items ():
395
438
if hasattr (user , field ):
396
439
bool_val = False
397
- if claim in claims and str (claims [claim ]).lower () in ['y' , 'yes' , 't' , 'true' , 'on' , '1' ]:
440
+ if claim in claims and str (claims [claim ]).lower () in [
441
+ "y" ,
442
+ "yes" ,
443
+ "t" ,
444
+ "true" ,
445
+ "on" ,
446
+ "1" ,
447
+ ]:
398
448
bool_val = True
399
449
setattr (user , field , bool_val )
400
- logger .debug ("Attribute '%s' for user '%s' was set to '%s'." , field , user , bool_val )
450
+ logger .debug (
451
+ "Attribute '%s' for user '%s' was set to '%s'." ,
452
+ field ,
453
+ user ,
454
+ bool_val ,
455
+ )
401
456
else :
402
457
msg = "User model has no field named '{}'. Check ADFS boolean claims mapping."
403
458
raise ImproperlyConfigured (msg .format (field ))
@@ -411,8 +466,10 @@ class AdfsAuthCodeBackend(AdfsBaseBackend):
411
466
412
467
def authenticate (self , request = None , authorization_code = None , ** kwargs ):
413
468
# If there's no token or code, we pass control to the next authentication backend
414
- if authorization_code is None or authorization_code == '' :
415
- logger .debug ("Authentication backend was called but no authorization code was received" )
469
+ if authorization_code is None or authorization_code == "" :
470
+ logger .debug (
471
+ "Authentication backend was called but no authorization code was received"
472
+ )
416
473
return
417
474
418
475
# If loaded data is too old, reload it again
@@ -435,8 +492,10 @@ def authenticate(self, request=None, access_token=None, **kwargs):
435
492
provider_config .load_config ()
436
493
437
494
# If there's no token or code, we pass control to the next authentication backend
438
- if access_token is None or access_token == '' :
439
- logger .debug ("Authentication backend was called but no access token was received" )
495
+ if access_token is None or access_token == "" :
496
+ logger .debug (
497
+ "Authentication backend was called but no access token was received"
498
+ )
440
499
return
441
500
442
501
access_token = access_token .decode ()
@@ -445,5 +504,6 @@ def authenticate(self, request=None, access_token=None, **kwargs):
445
504
446
505
447
506
class AdfsBackend (AdfsAuthCodeBackend ):
448
- """ Backwards compatible class name """
507
+ """Backwards compatible class name"""
508
+
449
509
pass
0 commit comments