diff --git a/django_auth_adfs/backend.py b/django_auth_adfs/backend.py index c3165cf..f8bbd76 100644 --- a/django_auth_adfs/backend.py +++ b/django_auth_adfs/backend.py @@ -4,8 +4,11 @@ from django.contrib.auth import get_user_model from django.contrib.auth.backends import ModelBackend from django.contrib.auth.models import Group -from django.core.exceptions import (ImproperlyConfigured, ObjectDoesNotExist, - PermissionDenied) +from django.core.exceptions import ( + ImproperlyConfigured, + ObjectDoesNotExist, + PermissionDenied, +) from django_auth_adfs import signals from django_auth_adfs.config import provider_config, settings @@ -15,7 +18,6 @@ class AdfsBaseBackend(ModelBackend): - def _ms_request(self, action, url, data=None, **kwargs): """ Make a Microsoft Entra/GraphQL request @@ -36,7 +38,10 @@ def _ms_request(self, action, url, data=None, **kwargs): if response.status_code == 400: if response.json().get("error_description", "").startswith("AADSTS50076"): raise MFARequired - logger.error("ADFS server returned an error: %s", response.json()["error_description"]) + logger.error( + "ADFS server returned an error: %s", + response.json()["error_description"], + ) raise PermissionDenied if response.status_code != 200: @@ -47,16 +52,18 @@ def _ms_request(self, action, url, data=None, **kwargs): def exchange_auth_code(self, authorization_code, request): logger.debug("Received authorization code: %s", authorization_code) data = { - 'grant_type': 'authorization_code', - 'client_id': settings.CLIENT_ID, - 'redirect_uri': provider_config.redirect_uri(request), - 'code': authorization_code, + "grant_type": "authorization_code", + "client_id": settings.CLIENT_ID, + "redirect_uri": provider_config.redirect_uri(request), + "code": authorization_code, } if settings.CLIENT_SECRET: - data['client_secret'] = settings.CLIENT_SECRET + data["client_secret"] = settings.CLIENT_SECRET logger.debug("Getting access token at: %s", provider_config.token_endpoint) - response = self._ms_request(provider_config.session.post, provider_config.token_endpoint, data) + response = self._ms_request( + provider_config.session.post, provider_config.token_endpoint, data + ) adfs_response = response.json() return adfs_response @@ -79,11 +86,13 @@ def get_obo_access_token(self, access_token): "requested_token_use": "on_behalf_of", } if provider_config.token_endpoint.endswith("/v2.0/token"): - data["scope"] = 'GroupMember.Read.All' + data["scope"] = "GroupMember.Read.All" else: - data["resource"] = 'https://graph.microsoft.com' + data["resource"] = "https://graph.microsoft.com" - response = self._ms_request(provider_config.session.get, provider_config.token_endpoint, data) + response = self._ms_request( + provider_config.session.get, provider_config.token_endpoint, data + ) obo_access_token = response.json()["access_token"] logger.debug("Received OBO access token: %s", obo_access_token) return obo_access_token @@ -117,8 +126,10 @@ def get_group_memberships_from_ms_graph(self, obo_access_token): Returns: claim_groups (list): List of the users group memberships """ - graph_url = "https://{}/v1.0/me/transitiveMemberOf/microsoft.graph.group".format( - provider_config.msgraph_endpoint + graph_url = ( + "https://{}/v1.0/me/transitiveMemberOf/microsoft.graph.group".format( + provider_config.msgraph_endpoint + ) ) headers = {"Authorization": "Bearer {}".format(obo_access_token)} response = self._ms_request( @@ -147,25 +158,25 @@ def validate_access_token(self, access_token): # Explicit is better then implicit and it protects against # changes in the defaults the jwt module uses. options = { - 'verify_signature': True, - 'verify_exp': True, - 'verify_nbf': True, - 'verify_iat': True, - 'verify_aud': True, - 'verify_iss': True, - 'require_exp': False, - 'require_iat': False, - 'require_nbf': False + "verify_signature": True, + "verify_exp": True, + "verify_nbf": True, + "verify_iat": True, + "verify_aud": True, + "verify_iss": True, + "require_exp": False, + "require_iat": False, + "require_nbf": False, } # Validate token and return claims return jwt.decode( access_token, key=key, - algorithms=['RS256', 'RS384', 'RS512'], + algorithms=["RS256", "RS384", "RS512"], audience=settings.AUDIENCE, issuer=provider_config.issuer, options=options, - leeway=settings.JWT_LEEWAY + leeway=settings.JWT_LEEWAY, ) except jwt.ExpiredSignatureError as error: logger.info("Signature has expired: %s", error) @@ -175,7 +186,7 @@ def validate_access_token(self, access_token): if idx < len(provider_config.signing_keys) - 1: continue else: - logger.info('Error decoding signature: %s', error) + logger.info("Error decoding signature: %s", error) raise PermissionDenied except jwt.InvalidTokenError as error: logger.info(str(error)) @@ -187,12 +198,8 @@ def process_access_token(self, access_token, adfs_response=None): logger.debug("Received access token: %s", access_token) claims = self.validate_access_token(access_token) - if ( - settings.BLOCK_GUEST_USERS - and claims.get('tid') - != settings.TENANT_ID - ): - logger.info('Guest user denied') + if settings.BLOCK_GUEST_USERS and claims.get("tid") != settings.TENANT_ID: + logger.info("Guest user denied") raise PermissionDenied if not claims: raise PermissionDenied @@ -204,10 +211,7 @@ def process_access_token(self, access_token, adfs_response=None): self.update_user_flags(user, claims, groups) signals.post_authenticate.send( - sender=self, - user=user, - claims=claims, - adfs_response=adfs_response + sender=self, user=user, claims=claims, adfs_response=adfs_response ) user.full_clean() @@ -235,7 +239,9 @@ def process_user_groups(self, claims, access_token): if settings.GROUPS_CLAIM in claims: groups = claims[settings.GROUPS_CLAIM] if not isinstance(groups, list): - groups = [groups, ] + groups = [ + groups, + ] elif ( settings.TENANT_ID != "adfs" and "_claim_names" in claims @@ -244,8 +250,10 @@ def process_user_groups(self, claims, access_token): obo_access_token = self.get_obo_access_token(access_token) groups = self.get_group_memberships_from_ms_graph(obo_access_token) else: - logger.debug("The configured groups claim %s was not found in the access token", - settings.GROUPS_CLAIM) + logger.debug( + "The configured groups claim %s was not found in the access token", + settings.GROUPS_CLAIM, + ) return groups @@ -264,19 +272,21 @@ def create_user(self, claims): guest_username_claim = settings.GUEST_USERNAME_CLAIM usermodel = get_user_model() - iss = claims.get('iss') - idp = claims.get('idp', iss) + iss = claims.get("iss") + idp = claims.get("idp", iss) if ( guest_username_claim and not claims.get(username_claim) and not settings.BLOCK_GUEST_USERS - and (claims.get('tid') != settings.TENANT_ID or iss != idp) + and (claims.get("tid") != settings.TENANT_ID or iss != idp) ): username_claim = guest_username_claim if not claims.get(username_claim): - logger.error("User claim's doesn't have the claim '%s' in his claims: %s" % - (username_claim, claims)) + logger.error( + "User claim's doesn't have the claim '%s' in his claims: %s" + % (username_claim, claims) + ) raise PermissionDenied userdata = {usermodel.USERNAME_FIELD: claims[username_claim]} @@ -288,7 +298,10 @@ def create_user(self, claims): user = usermodel.objects.create(**userdata) logger.debug("User '%s' has been created.", claims[username_claim]) else: - logger.debug("User '%s' doesn't exist and creating users is disabled.", claims[username_claim]) + logger.debug( + "User '%s' doesn't exist and creating users is disabled.", + claims[username_claim], + ) raise PermissionDenied if not user.password: user.set_unusable_password() @@ -308,27 +321,47 @@ def update_user_attributes(self, user, claims, claim_mapping=None): """ if claim_mapping is None: claim_mapping = settings.CLAIM_MAPPING - required_fields = [field.name for field in user._meta.get_fields() if getattr(field, 'blank', True) is False] + required_fields = [ + field.name + for field in user._meta.get_fields() + if getattr(field, "blank", True) is False + ] for field, claim in claim_mapping.items(): if hasattr(user, field) or user._meta.fields_map.get(field): if not isinstance(claim, dict): if claim in claims: setattr(user, field, claims[claim]) - logger.debug("Attribute '%s' for instance '%s' was set to '%s'.", field, user, claims[claim]) + logger.debug( + "Attribute '%s' for instance '%s' was set to '%s'.", + field, + user, + claims[claim], + ) else: if field in required_fields: msg = "Claim not found in access token: '{}'. Check ADFS claims mapping." raise ImproperlyConfigured(msg.format(claim)) else: - logger.warning("Claim '%s' for field '%s' was not found in " - "the access token for instance '%s'. " - "Field is not required and will be left empty", claim, field, user) + logger.warning( + "Claim '%s' for field '%s' was not found in " + "the access token for instance '%s'. " + "Field is not required and will be left empty", + claim, + field, + user, + ) else: try: - self.update_user_attributes(getattr(user, field), claims, claim_mapping=claim) + self.update_user_attributes( + getattr(user, field), claims, claim_mapping=claim + ) except ObjectDoesNotExist: - logger.warning("Object for field '{}' does not exist for: '{}'.".format(field, user)) + logger.warning( + "Object for field '{}' does not exist for: '{}'.".format( + field, user + ) + ) else: msg = "Model '{}' has no field named '{}'. Check ADFS claims mapping." @@ -349,6 +382,13 @@ def update_user_groups(self, user, claim_groups): if sorted(claim_groups) != sorted(user_group_names): # Get the list of already existing groups in one SQL query existing_claimed_groups = Group.objects.filter(name__in=claim_groups) + existing_claimed_group_names = ( + group.name for group in existing_claimed_groups + ) + + if sorted(existing_claimed_group_names) == sorted(user_group_names): + # If the groups are already set, we don't need to do anything + return if settings.MIRROR_GROUPS: existing_claimed_group_names = ( @@ -358,10 +398,12 @@ def update_user_groups(self, user, claim_groups): # bulk_create could have been used here but we want to send signals. new_claimed_groups = [ Group.objects.get_or_create(name=name)[0] - for name in claim_groups if name not in existing_claimed_group_names + for name in claim_groups + if name not in existing_claimed_group_names ] - # Associate the users to all claimed groups - user.groups.set(tuple(existing_claimed_groups) + tuple(new_claimed_groups)) + # Set user's groups to all claimed groups (both existing and + # newly created) and remove any that are not in the claim. + user.groups.set(new_claimed_groups) else: # Associate the user to only existing claimed groups user.groups.set(existing_claimed_groups) @@ -381,12 +423,19 @@ def update_user_flags(self, user, claims, claim_groups): if not isinstance(group, list): group = [group] - if any(group_list_item in claim_groups for group_list_item in group): + if any( + group_list_item in claim_groups for group_list_item in group + ): value = True else: value = False setattr(user, flag, value) - logger.debug("Attribute '%s' for user '%s' was set to '%s'.", flag, user, value) + logger.debug( + "Attribute '%s' for user '%s' was set to '%s'.", + flag, + user, + value, + ) else: msg = "User model has no field named '{}'. Check ADFS boolean claims mapping." raise ImproperlyConfigured(msg.format(flag)) @@ -394,10 +443,22 @@ def update_user_flags(self, user, claims, claim_groups): for field, claim in settings.BOOLEAN_CLAIM_MAPPING.items(): if hasattr(user, field): bool_val = False - if claim in claims and str(claims[claim]).lower() in ['y', 'yes', 't', 'true', 'on', '1']: + if claim in claims and str(claims[claim]).lower() in [ + "y", + "yes", + "t", + "true", + "on", + "1", + ]: bool_val = True setattr(user, field, bool_val) - logger.debug("Attribute '%s' for user '%s' was set to '%s'.", field, user, bool_val) + logger.debug( + "Attribute '%s' for user '%s' was set to '%s'.", + field, + user, + bool_val, + ) else: msg = "User model has no field named '{}'. Check ADFS boolean claims mapping." raise ImproperlyConfigured(msg.format(field)) @@ -411,8 +472,10 @@ class AdfsAuthCodeBackend(AdfsBaseBackend): def authenticate(self, request=None, authorization_code=None, **kwargs): # If there's no token or code, we pass control to the next authentication backend - if authorization_code is None or authorization_code == '': - logger.debug("Authentication backend was called but no authorization code was received") + if authorization_code is None or authorization_code == "": + logger.debug( + "Authentication backend was called but no authorization code was received" + ) return # If loaded data is too old, reload it again @@ -435,8 +498,10 @@ def authenticate(self, request=None, access_token=None, **kwargs): provider_config.load_config() # If there's no token or code, we pass control to the next authentication backend - if access_token is None or access_token == '': - logger.debug("Authentication backend was called but no access token was received") + if access_token is None or access_token == "": + logger.debug( + "Authentication backend was called but no access token was received" + ) return access_token = access_token.decode() @@ -445,5 +510,6 @@ def authenticate(self, request=None, access_token=None, **kwargs): class AdfsBackend(AdfsAuthCodeBackend): - """ Backwards compatible class name """ + """Backwards compatible class name""" + pass