Skip to content

Commit 77748cb

Browse files
committed
Format with ruff
1 parent 9b3496d commit 77748cb

File tree

1 file changed

+124
-64
lines changed

1 file changed

+124
-64
lines changed

django_auth_adfs/backend.py

Lines changed: 124 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,11 @@
44
from django.contrib.auth import get_user_model
55
from django.contrib.auth.backends import ModelBackend
66
from django.contrib.auth.models import Group
7-
from django.core.exceptions import (ImproperlyConfigured, ObjectDoesNotExist,
8-
PermissionDenied)
7+
from django.core.exceptions import (
8+
ImproperlyConfigured,
9+
ObjectDoesNotExist,
10+
PermissionDenied,
11+
)
912

1013
from django_auth_adfs import signals
1114
from django_auth_adfs.config import provider_config, settings
@@ -15,7 +18,6 @@
1518

1619

1720
class AdfsBaseBackend(ModelBackend):
18-
1921
def _ms_request(self, action, url, data=None, **kwargs):
2022
"""
2123
Make a Microsoft Entra/GraphQL request
@@ -36,7 +38,10 @@ def _ms_request(self, action, url, data=None, **kwargs):
3638
if response.status_code == 400:
3739
if response.json().get("error_description", "").startswith("AADSTS50076"):
3840
raise MFARequired
39-
logger.error("ADFS server returned an error: %s", response.json()["error_description"])
41+
logger.error(
42+
"ADFS server returned an error: %s",
43+
response.json()["error_description"],
44+
)
4045
raise PermissionDenied
4146

4247
if response.status_code != 200:
@@ -47,16 +52,18 @@ def _ms_request(self, action, url, data=None, **kwargs):
4752
def exchange_auth_code(self, authorization_code, request):
4853
logger.debug("Received authorization code: %s", authorization_code)
4954
data = {
50-
'grant_type': 'authorization_code',
51-
'client_id': settings.CLIENT_ID,
52-
'redirect_uri': provider_config.redirect_uri(request),
53-
'code': authorization_code,
55+
"grant_type": "authorization_code",
56+
"client_id": settings.CLIENT_ID,
57+
"redirect_uri": provider_config.redirect_uri(request),
58+
"code": authorization_code,
5459
}
5560
if settings.CLIENT_SECRET:
56-
data['client_secret'] = settings.CLIENT_SECRET
61+
data["client_secret"] = settings.CLIENT_SECRET
5762

5863
logger.debug("Getting access token at: %s", provider_config.token_endpoint)
59-
response = self._ms_request(provider_config.session.post, provider_config.token_endpoint, data)
64+
response = self._ms_request(
65+
provider_config.session.post, provider_config.token_endpoint, data
66+
)
6067
adfs_response = response.json()
6168
return adfs_response
6269

@@ -79,11 +86,13 @@ def get_obo_access_token(self, access_token):
7986
"requested_token_use": "on_behalf_of",
8087
}
8188
if provider_config.token_endpoint.endswith("/v2.0/token"):
82-
data["scope"] = 'GroupMember.Read.All'
89+
data["scope"] = "GroupMember.Read.All"
8390
else:
84-
data["resource"] = 'https://graph.microsoft.com'
91+
data["resource"] = "https://graph.microsoft.com"
8592

86-
response = self._ms_request(provider_config.session.get, provider_config.token_endpoint, data)
93+
response = self._ms_request(
94+
provider_config.session.get, provider_config.token_endpoint, data
95+
)
8796
obo_access_token = response.json()["access_token"]
8897
logger.debug("Received OBO access token: %s", obo_access_token)
8998
return obo_access_token
@@ -117,8 +126,10 @@ def get_group_memberships_from_ms_graph(self, obo_access_token):
117126
Returns:
118127
claim_groups (list): List of the users group memberships
119128
"""
120-
graph_url = "https://{}/v1.0/me/transitiveMemberOf/microsoft.graph.group".format(
121-
provider_config.msgraph_endpoint
129+
graph_url = (
130+
"https://{}/v1.0/me/transitiveMemberOf/microsoft.graph.group".format(
131+
provider_config.msgraph_endpoint
132+
)
122133
)
123134
headers = {"Authorization": "Bearer {}".format(obo_access_token)}
124135
response = self._ms_request(
@@ -147,25 +158,25 @@ def validate_access_token(self, access_token):
147158
# Explicit is better then implicit and it protects against
148159
# changes in the defaults the jwt module uses.
149160
options = {
150-
'verify_signature': True,
151-
'verify_exp': True,
152-
'verify_nbf': True,
153-
'verify_iat': True,
154-
'verify_aud': True,
155-
'verify_iss': True,
156-
'require_exp': False,
157-
'require_iat': False,
158-
'require_nbf': False
161+
"verify_signature": True,
162+
"verify_exp": True,
163+
"verify_nbf": True,
164+
"verify_iat": True,
165+
"verify_aud": True,
166+
"verify_iss": True,
167+
"require_exp": False,
168+
"require_iat": False,
169+
"require_nbf": False,
159170
}
160171
# Validate token and return claims
161172
return jwt.decode(
162173
access_token,
163174
key=key,
164-
algorithms=['RS256', 'RS384', 'RS512'],
175+
algorithms=["RS256", "RS384", "RS512"],
165176
audience=settings.AUDIENCE,
166177
issuer=provider_config.issuer,
167178
options=options,
168-
leeway=settings.JWT_LEEWAY
179+
leeway=settings.JWT_LEEWAY,
169180
)
170181
except jwt.ExpiredSignatureError as error:
171182
logger.info("Signature has expired: %s", error)
@@ -175,7 +186,7 @@ def validate_access_token(self, access_token):
175186
if idx < len(provider_config.signing_keys) - 1:
176187
continue
177188
else:
178-
logger.info('Error decoding signature: %s', error)
189+
logger.info("Error decoding signature: %s", error)
179190
raise PermissionDenied
180191
except jwt.InvalidTokenError as error:
181192
logger.info(str(error))
@@ -187,12 +198,8 @@ def process_access_token(self, access_token, adfs_response=None):
187198

188199
logger.debug("Received access token: %s", access_token)
189200
claims = self.validate_access_token(access_token)
190-
if (
191-
settings.BLOCK_GUEST_USERS
192-
and claims.get('tid')
193-
!= settings.TENANT_ID
194-
):
195-
logger.info('Guest user denied')
201+
if settings.BLOCK_GUEST_USERS and claims.get("tid") != settings.TENANT_ID:
202+
logger.info("Guest user denied")
196203
raise PermissionDenied
197204
if not claims:
198205
raise PermissionDenied
@@ -204,10 +211,7 @@ def process_access_token(self, access_token, adfs_response=None):
204211
self.update_user_flags(user, claims, groups)
205212

206213
signals.post_authenticate.send(
207-
sender=self,
208-
user=user,
209-
claims=claims,
210-
adfs_response=adfs_response
214+
sender=self, user=user, claims=claims, adfs_response=adfs_response
211215
)
212216

213217
user.full_clean()
@@ -235,7 +239,9 @@ def process_user_groups(self, claims, access_token):
235239
if settings.GROUPS_CLAIM in claims:
236240
groups = claims[settings.GROUPS_CLAIM]
237241
if not isinstance(groups, list):
238-
groups = [groups, ]
242+
groups = [
243+
groups,
244+
]
239245
elif (
240246
settings.TENANT_ID != "adfs"
241247
and "_claim_names" in claims
@@ -244,8 +250,10 @@ def process_user_groups(self, claims, access_token):
244250
obo_access_token = self.get_obo_access_token(access_token)
245251
groups = self.get_group_memberships_from_ms_graph(obo_access_token)
246252
else:
247-
logger.debug("The configured groups claim %s was not found in the access token",
248-
settings.GROUPS_CLAIM)
253+
logger.debug(
254+
"The configured groups claim %s was not found in the access token",
255+
settings.GROUPS_CLAIM,
256+
)
249257

250258
return groups
251259

@@ -264,19 +272,21 @@ def create_user(self, claims):
264272
guest_username_claim = settings.GUEST_USERNAME_CLAIM
265273
usermodel = get_user_model()
266274

267-
iss = claims.get('iss')
268-
idp = claims.get('idp', iss)
275+
iss = claims.get("iss")
276+
idp = claims.get("idp", iss)
269277
if (
270278
guest_username_claim
271279
and not claims.get(username_claim)
272280
and not settings.BLOCK_GUEST_USERS
273-
and (claims.get('tid') != settings.TENANT_ID or iss != idp)
281+
and (claims.get("tid") != settings.TENANT_ID or iss != idp)
274282
):
275283
username_claim = guest_username_claim
276284

277285
if not claims.get(username_claim):
278-
logger.error("User claim's doesn't have the claim '%s' in his claims: %s" %
279-
(username_claim, claims))
286+
logger.error(
287+
"User claim's doesn't have the claim '%s' in his claims: %s"
288+
% (username_claim, claims)
289+
)
280290
raise PermissionDenied
281291

282292
userdata = {usermodel.USERNAME_FIELD: claims[username_claim]}
@@ -288,7 +298,10 @@ def create_user(self, claims):
288298
user = usermodel.objects.create(**userdata)
289299
logger.debug("User '%s' has been created.", claims[username_claim])
290300
else:
291-
logger.debug("User '%s' doesn't exist and creating users is disabled.", claims[username_claim])
301+
logger.debug(
302+
"User '%s' doesn't exist and creating users is disabled.",
303+
claims[username_claim],
304+
)
292305
raise PermissionDenied
293306
if not user.password:
294307
user.set_unusable_password()
@@ -308,27 +321,47 @@ def update_user_attributes(self, user, claims, claim_mapping=None):
308321
"""
309322
if claim_mapping is None:
310323
claim_mapping = settings.CLAIM_MAPPING
311-
required_fields = [field.name for field in user._meta.get_fields() if getattr(field, 'blank', True) is False]
324+
required_fields = [
325+
field.name
326+
for field in user._meta.get_fields()
327+
if getattr(field, "blank", True) is False
328+
]
312329

313330
for field, claim in claim_mapping.items():
314331
if hasattr(user, field) or user._meta.fields_map.get(field):
315332
if not isinstance(claim, dict):
316333
if claim in claims:
317334
setattr(user, field, claims[claim])
318-
logger.debug("Attribute '%s' for instance '%s' was set to '%s'.", field, user, claims[claim])
335+
logger.debug(
336+
"Attribute '%s' for instance '%s' was set to '%s'.",
337+
field,
338+
user,
339+
claims[claim],
340+
)
319341
else:
320342
if field in required_fields:
321343
msg = "Claim not found in access token: '{}'. Check ADFS claims mapping."
322344
raise ImproperlyConfigured(msg.format(claim))
323345
else:
324-
logger.warning("Claim '%s' for field '%s' was not found in "
325-
"the access token for instance '%s'. "
326-
"Field is not required and will be left empty", claim, field, user)
346+
logger.warning(
347+
"Claim '%s' for field '%s' was not found in "
348+
"the access token for instance '%s'. "
349+
"Field is not required and will be left empty",
350+
claim,
351+
field,
352+
user,
353+
)
327354
else:
328355
try:
329-
self.update_user_attributes(getattr(user, field), claims, claim_mapping=claim)
356+
self.update_user_attributes(
357+
getattr(user, field), claims, claim_mapping=claim
358+
)
330359
except ObjectDoesNotExist:
331-
logger.warning("Object for field '{}' does not exist for: '{}'.".format(field, user))
360+
logger.warning(
361+
"Object for field '{}' does not exist for: '{}'.".format(
362+
field, user
363+
)
364+
)
332365

333366
else:
334367
msg = "Model '{}' has no field named '{}'. Check ADFS claims mapping."
@@ -358,10 +391,13 @@ def update_user_groups(self, user, claim_groups):
358391
# bulk_create could have been used here but we want to send signals.
359392
new_claimed_groups = [
360393
Group.objects.get_or_create(name=name)[0]
361-
for name in claim_groups if name not in existing_claimed_group_names
394+
for name in claim_groups
395+
if name not in existing_claimed_group_names
362396
]
363397
# Associate the users to all claimed groups
364-
user.groups.set(tuple(existing_claimed_groups) + tuple(new_claimed_groups))
398+
user.groups.set(
399+
tuple(existing_claimed_groups) + tuple(new_claimed_groups)
400+
)
365401
else:
366402
# Associate the user to only existing claimed groups
367403
user.groups.set(existing_claimed_groups)
@@ -381,23 +417,42 @@ def update_user_flags(self, user, claims, claim_groups):
381417
if not isinstance(group, list):
382418
group = [group]
383419

384-
if any(group_list_item in claim_groups for group_list_item in group):
420+
if any(
421+
group_list_item in claim_groups for group_list_item in group
422+
):
385423
value = True
386424
else:
387425
value = False
388426
setattr(user, flag, value)
389-
logger.debug("Attribute '%s' for user '%s' was set to '%s'.", flag, user, value)
427+
logger.debug(
428+
"Attribute '%s' for user '%s' was set to '%s'.",
429+
flag,
430+
user,
431+
value,
432+
)
390433
else:
391434
msg = "User model has no field named '{}'. Check ADFS boolean claims mapping."
392435
raise ImproperlyConfigured(msg.format(flag))
393436

394437
for field, claim in settings.BOOLEAN_CLAIM_MAPPING.items():
395438
if hasattr(user, field):
396439
bool_val = False
397-
if claim in claims and str(claims[claim]).lower() in ['y', 'yes', 't', 'true', 'on', '1']:
440+
if claim in claims and str(claims[claim]).lower() in [
441+
"y",
442+
"yes",
443+
"t",
444+
"true",
445+
"on",
446+
"1",
447+
]:
398448
bool_val = True
399449
setattr(user, field, bool_val)
400-
logger.debug("Attribute '%s' for user '%s' was set to '%s'.", field, user, bool_val)
450+
logger.debug(
451+
"Attribute '%s' for user '%s' was set to '%s'.",
452+
field,
453+
user,
454+
bool_val,
455+
)
401456
else:
402457
msg = "User model has no field named '{}'. Check ADFS boolean claims mapping."
403458
raise ImproperlyConfigured(msg.format(field))
@@ -411,8 +466,10 @@ class AdfsAuthCodeBackend(AdfsBaseBackend):
411466

412467
def authenticate(self, request=None, authorization_code=None, **kwargs):
413468
# If there's no token or code, we pass control to the next authentication backend
414-
if authorization_code is None or authorization_code == '':
415-
logger.debug("Authentication backend was called but no authorization code was received")
469+
if authorization_code is None or authorization_code == "":
470+
logger.debug(
471+
"Authentication backend was called but no authorization code was received"
472+
)
416473
return
417474

418475
# If loaded data is too old, reload it again
@@ -435,8 +492,10 @@ def authenticate(self, request=None, access_token=None, **kwargs):
435492
provider_config.load_config()
436493

437494
# If there's no token or code, we pass control to the next authentication backend
438-
if access_token is None or access_token == '':
439-
logger.debug("Authentication backend was called but no access token was received")
495+
if access_token is None or access_token == "":
496+
logger.debug(
497+
"Authentication backend was called but no access token was received"
498+
)
440499
return
441500

442501
access_token = access_token.decode()
@@ -445,5 +504,6 @@ def authenticate(self, request=None, access_token=None, **kwargs):
445504

446505

447506
class AdfsBackend(AdfsAuthCodeBackend):
448-
""" Backwards compatible class name """
507+
"""Backwards compatible class name"""
508+
449509
pass

0 commit comments

Comments
 (0)