From 77748cb479a1253b7dbb897443d5acb129e74e4e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Raimbault?= Date: Thu, 12 Jun 2025 02:36:09 +0200 Subject: [PATCH 1/3] Format with ruff --- django_auth_adfs/backend.py | 188 ++++++++++++++++++++++++------------ 1 file changed, 124 insertions(+), 64 deletions(-) diff --git a/django_auth_adfs/backend.py b/django_auth_adfs/backend.py index c3165cf..38b2358 100644 --- a/django_auth_adfs/backend.py +++ b/django_auth_adfs/backend.py @@ -4,8 +4,11 @@ from django.contrib.auth import get_user_model from django.contrib.auth.backends import ModelBackend from django.contrib.auth.models import Group -from django.core.exceptions import (ImproperlyConfigured, ObjectDoesNotExist, - PermissionDenied) +from django.core.exceptions import ( + ImproperlyConfigured, + ObjectDoesNotExist, + PermissionDenied, +) from django_auth_adfs import signals from django_auth_adfs.config import provider_config, settings @@ -15,7 +18,6 @@ class AdfsBaseBackend(ModelBackend): - def _ms_request(self, action, url, data=None, **kwargs): """ Make a Microsoft Entra/GraphQL request @@ -36,7 +38,10 @@ def _ms_request(self, action, url, data=None, **kwargs): if response.status_code == 400: if response.json().get("error_description", "").startswith("AADSTS50076"): raise MFARequired - logger.error("ADFS server returned an error: %s", response.json()["error_description"]) + logger.error( + "ADFS server returned an error: %s", + response.json()["error_description"], + ) raise PermissionDenied if response.status_code != 200: @@ -47,16 +52,18 @@ def _ms_request(self, action, url, data=None, **kwargs): def exchange_auth_code(self, authorization_code, request): logger.debug("Received authorization code: %s", authorization_code) data = { - 'grant_type': 'authorization_code', - 'client_id': settings.CLIENT_ID, - 'redirect_uri': provider_config.redirect_uri(request), - 'code': authorization_code, + "grant_type": "authorization_code", + "client_id": settings.CLIENT_ID, + "redirect_uri": provider_config.redirect_uri(request), + "code": authorization_code, } if settings.CLIENT_SECRET: - data['client_secret'] = settings.CLIENT_SECRET + data["client_secret"] = settings.CLIENT_SECRET logger.debug("Getting access token at: %s", provider_config.token_endpoint) - response = self._ms_request(provider_config.session.post, provider_config.token_endpoint, data) + response = self._ms_request( + provider_config.session.post, provider_config.token_endpoint, data + ) adfs_response = response.json() return adfs_response @@ -79,11 +86,13 @@ def get_obo_access_token(self, access_token): "requested_token_use": "on_behalf_of", } if provider_config.token_endpoint.endswith("/v2.0/token"): - data["scope"] = 'GroupMember.Read.All' + data["scope"] = "GroupMember.Read.All" else: - data["resource"] = 'https://graph.microsoft.com' + data["resource"] = "https://graph.microsoft.com" - response = self._ms_request(provider_config.session.get, provider_config.token_endpoint, data) + response = self._ms_request( + provider_config.session.get, provider_config.token_endpoint, data + ) obo_access_token = response.json()["access_token"] logger.debug("Received OBO access token: %s", obo_access_token) return obo_access_token @@ -117,8 +126,10 @@ def get_group_memberships_from_ms_graph(self, obo_access_token): Returns: claim_groups (list): List of the users group memberships """ - graph_url = "https://{}/v1.0/me/transitiveMemberOf/microsoft.graph.group".format( - provider_config.msgraph_endpoint + graph_url = ( + "https://{}/v1.0/me/transitiveMemberOf/microsoft.graph.group".format( + provider_config.msgraph_endpoint + ) ) headers = {"Authorization": "Bearer {}".format(obo_access_token)} response = self._ms_request( @@ -147,25 +158,25 @@ def validate_access_token(self, access_token): # Explicit is better then implicit and it protects against # changes in the defaults the jwt module uses. options = { - 'verify_signature': True, - 'verify_exp': True, - 'verify_nbf': True, - 'verify_iat': True, - 'verify_aud': True, - 'verify_iss': True, - 'require_exp': False, - 'require_iat': False, - 'require_nbf': False + "verify_signature": True, + "verify_exp": True, + "verify_nbf": True, + "verify_iat": True, + "verify_aud": True, + "verify_iss": True, + "require_exp": False, + "require_iat": False, + "require_nbf": False, } # Validate token and return claims return jwt.decode( access_token, key=key, - algorithms=['RS256', 'RS384', 'RS512'], + algorithms=["RS256", "RS384", "RS512"], audience=settings.AUDIENCE, issuer=provider_config.issuer, options=options, - leeway=settings.JWT_LEEWAY + leeway=settings.JWT_LEEWAY, ) except jwt.ExpiredSignatureError as error: logger.info("Signature has expired: %s", error) @@ -175,7 +186,7 @@ def validate_access_token(self, access_token): if idx < len(provider_config.signing_keys) - 1: continue else: - logger.info('Error decoding signature: %s', error) + logger.info("Error decoding signature: %s", error) raise PermissionDenied except jwt.InvalidTokenError as error: logger.info(str(error)) @@ -187,12 +198,8 @@ def process_access_token(self, access_token, adfs_response=None): logger.debug("Received access token: %s", access_token) claims = self.validate_access_token(access_token) - if ( - settings.BLOCK_GUEST_USERS - and claims.get('tid') - != settings.TENANT_ID - ): - logger.info('Guest user denied') + if settings.BLOCK_GUEST_USERS and claims.get("tid") != settings.TENANT_ID: + logger.info("Guest user denied") raise PermissionDenied if not claims: raise PermissionDenied @@ -204,10 +211,7 @@ def process_access_token(self, access_token, adfs_response=None): self.update_user_flags(user, claims, groups) signals.post_authenticate.send( - sender=self, - user=user, - claims=claims, - adfs_response=adfs_response + sender=self, user=user, claims=claims, adfs_response=adfs_response ) user.full_clean() @@ -235,7 +239,9 @@ def process_user_groups(self, claims, access_token): if settings.GROUPS_CLAIM in claims: groups = claims[settings.GROUPS_CLAIM] if not isinstance(groups, list): - groups = [groups, ] + groups = [ + groups, + ] elif ( settings.TENANT_ID != "adfs" and "_claim_names" in claims @@ -244,8 +250,10 @@ def process_user_groups(self, claims, access_token): obo_access_token = self.get_obo_access_token(access_token) groups = self.get_group_memberships_from_ms_graph(obo_access_token) else: - logger.debug("The configured groups claim %s was not found in the access token", - settings.GROUPS_CLAIM) + logger.debug( + "The configured groups claim %s was not found in the access token", + settings.GROUPS_CLAIM, + ) return groups @@ -264,19 +272,21 @@ def create_user(self, claims): guest_username_claim = settings.GUEST_USERNAME_CLAIM usermodel = get_user_model() - iss = claims.get('iss') - idp = claims.get('idp', iss) + iss = claims.get("iss") + idp = claims.get("idp", iss) if ( guest_username_claim and not claims.get(username_claim) and not settings.BLOCK_GUEST_USERS - and (claims.get('tid') != settings.TENANT_ID or iss != idp) + and (claims.get("tid") != settings.TENANT_ID or iss != idp) ): username_claim = guest_username_claim if not claims.get(username_claim): - logger.error("User claim's doesn't have the claim '%s' in his claims: %s" % - (username_claim, claims)) + logger.error( + "User claim's doesn't have the claim '%s' in his claims: %s" + % (username_claim, claims) + ) raise PermissionDenied userdata = {usermodel.USERNAME_FIELD: claims[username_claim]} @@ -288,7 +298,10 @@ def create_user(self, claims): user = usermodel.objects.create(**userdata) logger.debug("User '%s' has been created.", claims[username_claim]) else: - logger.debug("User '%s' doesn't exist and creating users is disabled.", claims[username_claim]) + logger.debug( + "User '%s' doesn't exist and creating users is disabled.", + claims[username_claim], + ) raise PermissionDenied if not user.password: user.set_unusable_password() @@ -308,27 +321,47 @@ def update_user_attributes(self, user, claims, claim_mapping=None): """ if claim_mapping is None: claim_mapping = settings.CLAIM_MAPPING - required_fields = [field.name for field in user._meta.get_fields() if getattr(field, 'blank', True) is False] + required_fields = [ + field.name + for field in user._meta.get_fields() + if getattr(field, "blank", True) is False + ] for field, claim in claim_mapping.items(): if hasattr(user, field) or user._meta.fields_map.get(field): if not isinstance(claim, dict): if claim in claims: setattr(user, field, claims[claim]) - logger.debug("Attribute '%s' for instance '%s' was set to '%s'.", field, user, claims[claim]) + logger.debug( + "Attribute '%s' for instance '%s' was set to '%s'.", + field, + user, + claims[claim], + ) else: if field in required_fields: msg = "Claim not found in access token: '{}'. Check ADFS claims mapping." raise ImproperlyConfigured(msg.format(claim)) else: - logger.warning("Claim '%s' for field '%s' was not found in " - "the access token for instance '%s'. " - "Field is not required and will be left empty", claim, field, user) + logger.warning( + "Claim '%s' for field '%s' was not found in " + "the access token for instance '%s'. " + "Field is not required and will be left empty", + claim, + field, + user, + ) else: try: - self.update_user_attributes(getattr(user, field), claims, claim_mapping=claim) + self.update_user_attributes( + getattr(user, field), claims, claim_mapping=claim + ) except ObjectDoesNotExist: - logger.warning("Object for field '{}' does not exist for: '{}'.".format(field, user)) + logger.warning( + "Object for field '{}' does not exist for: '{}'.".format( + field, user + ) + ) else: msg = "Model '{}' has no field named '{}'. Check ADFS claims mapping." @@ -358,10 +391,13 @@ def update_user_groups(self, user, claim_groups): # bulk_create could have been used here but we want to send signals. new_claimed_groups = [ Group.objects.get_or_create(name=name)[0] - for name in claim_groups if name not in existing_claimed_group_names + for name in claim_groups + if name not in existing_claimed_group_names ] # Associate the users to all claimed groups - user.groups.set(tuple(existing_claimed_groups) + tuple(new_claimed_groups)) + user.groups.set( + tuple(existing_claimed_groups) + tuple(new_claimed_groups) + ) else: # Associate the user to only existing claimed groups user.groups.set(existing_claimed_groups) @@ -381,12 +417,19 @@ def update_user_flags(self, user, claims, claim_groups): if not isinstance(group, list): group = [group] - if any(group_list_item in claim_groups for group_list_item in group): + if any( + group_list_item in claim_groups for group_list_item in group + ): value = True else: value = False setattr(user, flag, value) - logger.debug("Attribute '%s' for user '%s' was set to '%s'.", flag, user, value) + logger.debug( + "Attribute '%s' for user '%s' was set to '%s'.", + flag, + user, + value, + ) else: msg = "User model has no field named '{}'. Check ADFS boolean claims mapping." raise ImproperlyConfigured(msg.format(flag)) @@ -394,10 +437,22 @@ def update_user_flags(self, user, claims, claim_groups): for field, claim in settings.BOOLEAN_CLAIM_MAPPING.items(): if hasattr(user, field): bool_val = False - if claim in claims and str(claims[claim]).lower() in ['y', 'yes', 't', 'true', 'on', '1']: + if claim in claims and str(claims[claim]).lower() in [ + "y", + "yes", + "t", + "true", + "on", + "1", + ]: bool_val = True setattr(user, field, bool_val) - logger.debug("Attribute '%s' for user '%s' was set to '%s'.", field, user, bool_val) + logger.debug( + "Attribute '%s' for user '%s' was set to '%s'.", + field, + user, + bool_val, + ) else: msg = "User model has no field named '{}'. Check ADFS boolean claims mapping." raise ImproperlyConfigured(msg.format(field)) @@ -411,8 +466,10 @@ class AdfsAuthCodeBackend(AdfsBaseBackend): def authenticate(self, request=None, authorization_code=None, **kwargs): # If there's no token or code, we pass control to the next authentication backend - if authorization_code is None or authorization_code == '': - logger.debug("Authentication backend was called but no authorization code was received") + if authorization_code is None or authorization_code == "": + logger.debug( + "Authentication backend was called but no authorization code was received" + ) return # If loaded data is too old, reload it again @@ -435,8 +492,10 @@ def authenticate(self, request=None, access_token=None, **kwargs): provider_config.load_config() # If there's no token or code, we pass control to the next authentication backend - if access_token is None or access_token == '': - logger.debug("Authentication backend was called but no access token was received") + if access_token is None or access_token == "": + logger.debug( + "Authentication backend was called but no access token was received" + ) return access_token = access_token.decode() @@ -445,5 +504,6 @@ def authenticate(self, request=None, access_token=None, **kwargs): class AdfsBackend(AdfsAuthCodeBackend): - """ Backwards compatible class name """ + """Backwards compatible class name""" + pass From ec4f1295a28f468cdd8064314085caf2a110f944 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Raimbault?= Date: Thu, 12 Jun 2025 02:37:47 +0200 Subject: [PATCH 2/3] Check the new groups are different before update Claimed groups can contain groups they are not relevant for our app so the first quick check is to ensure the list are different then to use a SQL query to fetch the groups of the user. Only set the groups if they are different, it avoids a SELECT of the Django ORM (called by the set()). --- django_auth_adfs/backend.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/django_auth_adfs/backend.py b/django_auth_adfs/backend.py index 38b2358..4b1e577 100644 --- a/django_auth_adfs/backend.py +++ b/django_auth_adfs/backend.py @@ -382,6 +382,13 @@ def update_user_groups(self, user, claim_groups): if sorted(claim_groups) != sorted(user_group_names): # Get the list of already existing groups in one SQL query existing_claimed_groups = Group.objects.filter(name__in=claim_groups) + existing_claimed_group_names = ( + group.name for group in existing_claimed_groups + ) + + if sorted(existing_claimed_group_names) == sorted(user_group_names): + # If the groups are already set, we don't need to do anything + return if settings.MIRROR_GROUPS: existing_claimed_group_names = ( From fe2a501f26311d1320ddfb3fac556ea164503bee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Raimbault?= Date: Thu, 12 Jun 2025 01:26:32 +0200 Subject: [PATCH 3/3] Mirror groups should remove any that are not in the claim --- django_auth_adfs/backend.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/django_auth_adfs/backend.py b/django_auth_adfs/backend.py index 4b1e577..f8bbd76 100644 --- a/django_auth_adfs/backend.py +++ b/django_auth_adfs/backend.py @@ -401,10 +401,9 @@ def update_user_groups(self, user, claim_groups): for name in claim_groups if name not in existing_claimed_group_names ] - # Associate the users to all claimed groups - user.groups.set( - tuple(existing_claimed_groups) + tuple(new_claimed_groups) - ) + # Set user's groups to all claimed groups (both existing and + # newly created) and remove any that are not in the claim. + user.groups.set(new_claimed_groups) else: # Associate the user to only existing claimed groups user.groups.set(existing_claimed_groups)