-
-
Notifications
You must be signed in to change notification settings - Fork 106
Reduce SQL queries and restrict to claimed groups #377
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
stephane
wants to merge
3
commits into
snok:main
Choose a base branch
from
stephane:reduce-sql
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,8 +4,11 @@ | |
from django.contrib.auth import get_user_model | ||
from django.contrib.auth.backends import ModelBackend | ||
from django.contrib.auth.models import Group | ||
from django.core.exceptions import (ImproperlyConfigured, ObjectDoesNotExist, | ||
PermissionDenied) | ||
from django.core.exceptions import ( | ||
ImproperlyConfigured, | ||
ObjectDoesNotExist, | ||
PermissionDenied, | ||
) | ||
|
||
from django_auth_adfs import signals | ||
from django_auth_adfs.config import provider_config, settings | ||
|
@@ -15,7 +18,6 @@ | |
|
||
|
||
class AdfsBaseBackend(ModelBackend): | ||
|
||
def _ms_request(self, action, url, data=None, **kwargs): | ||
""" | ||
Make a Microsoft Entra/GraphQL request | ||
|
@@ -36,7 +38,10 @@ def _ms_request(self, action, url, data=None, **kwargs): | |
if response.status_code == 400: | ||
if response.json().get("error_description", "").startswith("AADSTS50076"): | ||
raise MFARequired | ||
logger.error("ADFS server returned an error: %s", response.json()["error_description"]) | ||
logger.error( | ||
"ADFS server returned an error: %s", | ||
response.json()["error_description"], | ||
) | ||
raise PermissionDenied | ||
|
||
if response.status_code != 200: | ||
|
@@ -47,16 +52,18 @@ def _ms_request(self, action, url, data=None, **kwargs): | |
def exchange_auth_code(self, authorization_code, request): | ||
logger.debug("Received authorization code: %s", authorization_code) | ||
data = { | ||
'grant_type': 'authorization_code', | ||
'client_id': settings.CLIENT_ID, | ||
'redirect_uri': provider_config.redirect_uri(request), | ||
'code': authorization_code, | ||
"grant_type": "authorization_code", | ||
"client_id": settings.CLIENT_ID, | ||
"redirect_uri": provider_config.redirect_uri(request), | ||
"code": authorization_code, | ||
} | ||
if settings.CLIENT_SECRET: | ||
data['client_secret'] = settings.CLIENT_SECRET | ||
data["client_secret"] = settings.CLIENT_SECRET | ||
|
||
logger.debug("Getting access token at: %s", provider_config.token_endpoint) | ||
response = self._ms_request(provider_config.session.post, provider_config.token_endpoint, data) | ||
response = self._ms_request( | ||
provider_config.session.post, provider_config.token_endpoint, data | ||
) | ||
adfs_response = response.json() | ||
return adfs_response | ||
|
||
|
@@ -79,11 +86,13 @@ def get_obo_access_token(self, access_token): | |
"requested_token_use": "on_behalf_of", | ||
} | ||
if provider_config.token_endpoint.endswith("/v2.0/token"): | ||
data["scope"] = 'GroupMember.Read.All' | ||
data["scope"] = "GroupMember.Read.All" | ||
else: | ||
data["resource"] = 'https://graph.microsoft.com' | ||
data["resource"] = "https://graph.microsoft.com" | ||
|
||
response = self._ms_request(provider_config.session.get, provider_config.token_endpoint, data) | ||
response = self._ms_request( | ||
provider_config.session.get, provider_config.token_endpoint, data | ||
) | ||
obo_access_token = response.json()["access_token"] | ||
logger.debug("Received OBO access token: %s", obo_access_token) | ||
return obo_access_token | ||
|
@@ -117,8 +126,10 @@ def get_group_memberships_from_ms_graph(self, obo_access_token): | |
Returns: | ||
claim_groups (list): List of the users group memberships | ||
""" | ||
graph_url = "https://{}/v1.0/me/transitiveMemberOf/microsoft.graph.group".format( | ||
provider_config.msgraph_endpoint | ||
graph_url = ( | ||
"https://{}/v1.0/me/transitiveMemberOf/microsoft.graph.group".format( | ||
provider_config.msgraph_endpoint | ||
) | ||
) | ||
headers = {"Authorization": "Bearer {}".format(obo_access_token)} | ||
response = self._ms_request( | ||
|
@@ -147,25 +158,25 @@ def validate_access_token(self, access_token): | |
# Explicit is better then implicit and it protects against | ||
# changes in the defaults the jwt module uses. | ||
options = { | ||
'verify_signature': True, | ||
'verify_exp': True, | ||
'verify_nbf': True, | ||
'verify_iat': True, | ||
'verify_aud': True, | ||
'verify_iss': True, | ||
'require_exp': False, | ||
'require_iat': False, | ||
'require_nbf': False | ||
"verify_signature": True, | ||
"verify_exp": True, | ||
"verify_nbf": True, | ||
"verify_iat": True, | ||
"verify_aud": True, | ||
"verify_iss": True, | ||
"require_exp": False, | ||
"require_iat": False, | ||
"require_nbf": False, | ||
} | ||
# Validate token and return claims | ||
return jwt.decode( | ||
access_token, | ||
key=key, | ||
algorithms=['RS256', 'RS384', 'RS512'], | ||
algorithms=["RS256", "RS384", "RS512"], | ||
audience=settings.AUDIENCE, | ||
issuer=provider_config.issuer, | ||
options=options, | ||
leeway=settings.JWT_LEEWAY | ||
leeway=settings.JWT_LEEWAY, | ||
) | ||
except jwt.ExpiredSignatureError as error: | ||
logger.info("Signature has expired: %s", error) | ||
|
@@ -175,7 +186,7 @@ def validate_access_token(self, access_token): | |
if idx < len(provider_config.signing_keys) - 1: | ||
continue | ||
else: | ||
logger.info('Error decoding signature: %s', error) | ||
logger.info("Error decoding signature: %s", error) | ||
raise PermissionDenied | ||
except jwt.InvalidTokenError as error: | ||
logger.info(str(error)) | ||
|
@@ -187,12 +198,8 @@ def process_access_token(self, access_token, adfs_response=None): | |
|
||
logger.debug("Received access token: %s", access_token) | ||
claims = self.validate_access_token(access_token) | ||
if ( | ||
settings.BLOCK_GUEST_USERS | ||
and claims.get('tid') | ||
!= settings.TENANT_ID | ||
): | ||
logger.info('Guest user denied') | ||
if settings.BLOCK_GUEST_USERS and claims.get("tid") != settings.TENANT_ID: | ||
logger.info("Guest user denied") | ||
raise PermissionDenied | ||
if not claims: | ||
raise PermissionDenied | ||
|
@@ -204,10 +211,7 @@ def process_access_token(self, access_token, adfs_response=None): | |
self.update_user_flags(user, claims, groups) | ||
|
||
signals.post_authenticate.send( | ||
sender=self, | ||
user=user, | ||
claims=claims, | ||
adfs_response=adfs_response | ||
sender=self, user=user, claims=claims, adfs_response=adfs_response | ||
) | ||
|
||
user.full_clean() | ||
|
@@ -235,7 +239,9 @@ def process_user_groups(self, claims, access_token): | |
if settings.GROUPS_CLAIM in claims: | ||
groups = claims[settings.GROUPS_CLAIM] | ||
if not isinstance(groups, list): | ||
groups = [groups, ] | ||
groups = [ | ||
groups, | ||
] | ||
elif ( | ||
settings.TENANT_ID != "adfs" | ||
and "_claim_names" in claims | ||
|
@@ -244,8 +250,10 @@ def process_user_groups(self, claims, access_token): | |
obo_access_token = self.get_obo_access_token(access_token) | ||
groups = self.get_group_memberships_from_ms_graph(obo_access_token) | ||
else: | ||
logger.debug("The configured groups claim %s was not found in the access token", | ||
settings.GROUPS_CLAIM) | ||
logger.debug( | ||
"The configured groups claim %s was not found in the access token", | ||
settings.GROUPS_CLAIM, | ||
) | ||
|
||
return groups | ||
|
||
|
@@ -264,19 +272,21 @@ def create_user(self, claims): | |
guest_username_claim = settings.GUEST_USERNAME_CLAIM | ||
usermodel = get_user_model() | ||
|
||
iss = claims.get('iss') | ||
idp = claims.get('idp', iss) | ||
iss = claims.get("iss") | ||
idp = claims.get("idp", iss) | ||
if ( | ||
guest_username_claim | ||
and not claims.get(username_claim) | ||
and not settings.BLOCK_GUEST_USERS | ||
and (claims.get('tid') != settings.TENANT_ID or iss != idp) | ||
and (claims.get("tid") != settings.TENANT_ID or iss != idp) | ||
): | ||
username_claim = guest_username_claim | ||
|
||
if not claims.get(username_claim): | ||
logger.error("User claim's doesn't have the claim '%s' in his claims: %s" % | ||
(username_claim, claims)) | ||
logger.error( | ||
"User claim's doesn't have the claim '%s' in his claims: %s" | ||
% (username_claim, claims) | ||
) | ||
raise PermissionDenied | ||
|
||
userdata = {usermodel.USERNAME_FIELD: claims[username_claim]} | ||
|
@@ -288,7 +298,10 @@ def create_user(self, claims): | |
user = usermodel.objects.create(**userdata) | ||
logger.debug("User '%s' has been created.", claims[username_claim]) | ||
else: | ||
logger.debug("User '%s' doesn't exist and creating users is disabled.", claims[username_claim]) | ||
logger.debug( | ||
"User '%s' doesn't exist and creating users is disabled.", | ||
claims[username_claim], | ||
) | ||
raise PermissionDenied | ||
if not user.password: | ||
user.set_unusable_password() | ||
|
@@ -308,27 +321,47 @@ def update_user_attributes(self, user, claims, claim_mapping=None): | |
""" | ||
if claim_mapping is None: | ||
claim_mapping = settings.CLAIM_MAPPING | ||
required_fields = [field.name for field in user._meta.get_fields() if getattr(field, 'blank', True) is False] | ||
required_fields = [ | ||
field.name | ||
for field in user._meta.get_fields() | ||
if getattr(field, "blank", True) is False | ||
] | ||
|
||
for field, claim in claim_mapping.items(): | ||
if hasattr(user, field) or user._meta.fields_map.get(field): | ||
if not isinstance(claim, dict): | ||
if claim in claims: | ||
setattr(user, field, claims[claim]) | ||
logger.debug("Attribute '%s' for instance '%s' was set to '%s'.", field, user, claims[claim]) | ||
logger.debug( | ||
"Attribute '%s' for instance '%s' was set to '%s'.", | ||
field, | ||
user, | ||
claims[claim], | ||
) | ||
else: | ||
if field in required_fields: | ||
msg = "Claim not found in access token: '{}'. Check ADFS claims mapping." | ||
raise ImproperlyConfigured(msg.format(claim)) | ||
else: | ||
logger.warning("Claim '%s' for field '%s' was not found in " | ||
"the access token for instance '%s'. " | ||
"Field is not required and will be left empty", claim, field, user) | ||
logger.warning( | ||
"Claim '%s' for field '%s' was not found in " | ||
"the access token for instance '%s'. " | ||
"Field is not required and will be left empty", | ||
claim, | ||
field, | ||
user, | ||
) | ||
else: | ||
try: | ||
self.update_user_attributes(getattr(user, field), claims, claim_mapping=claim) | ||
self.update_user_attributes( | ||
getattr(user, field), claims, claim_mapping=claim | ||
) | ||
except ObjectDoesNotExist: | ||
logger.warning("Object for field '{}' does not exist for: '{}'.".format(field, user)) | ||
logger.warning( | ||
"Object for field '{}' does not exist for: '{}'.".format( | ||
field, user | ||
) | ||
) | ||
|
||
else: | ||
msg = "Model '{}' has no field named '{}'. Check ADFS claims mapping." | ||
|
@@ -349,6 +382,13 @@ def update_user_groups(self, user, claim_groups): | |
if sorted(claim_groups) != sorted(user_group_names): | ||
# Get the list of already existing groups in one SQL query | ||
existing_claimed_groups = Group.objects.filter(name__in=claim_groups) | ||
existing_claimed_group_names = ( | ||
group.name for group in existing_claimed_groups | ||
) | ||
|
||
if sorted(existing_claimed_group_names) == sorted(user_group_names): | ||
# If the groups are already set, we don't need to do anything | ||
return | ||
|
||
if settings.MIRROR_GROUPS: | ||
existing_claimed_group_names = ( | ||
|
@@ -358,10 +398,12 @@ def update_user_groups(self, user, claim_groups): | |
# bulk_create could have been used here but we want to send signals. | ||
new_claimed_groups = [ | ||
Group.objects.get_or_create(name=name)[0] | ||
for name in claim_groups if name not in existing_claimed_group_names | ||
for name in claim_groups | ||
if name not in existing_claimed_group_names | ||
] | ||
# Associate the users to all claimed groups | ||
user.groups.set(tuple(existing_claimed_groups) + tuple(new_claimed_groups)) | ||
# Set user's groups to all claimed groups (both existing and | ||
# newly created) and remove any that are not in the claim. | ||
user.groups.set(new_claimed_groups) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. From my understanding, the comment here matches the old code. The new code will set the user's groups to only the groups in the claim that didn't already have a |
||
else: | ||
# Associate the user to only existing claimed groups | ||
user.groups.set(existing_claimed_groups) | ||
|
@@ -381,23 +423,42 @@ def update_user_flags(self, user, claims, claim_groups): | |
if not isinstance(group, list): | ||
group = [group] | ||
|
||
if any(group_list_item in claim_groups for group_list_item in group): | ||
if any( | ||
group_list_item in claim_groups for group_list_item in group | ||
): | ||
value = True | ||
else: | ||
value = False | ||
setattr(user, flag, value) | ||
logger.debug("Attribute '%s' for user '%s' was set to '%s'.", flag, user, value) | ||
logger.debug( | ||
"Attribute '%s' for user '%s' was set to '%s'.", | ||
flag, | ||
user, | ||
value, | ||
) | ||
else: | ||
msg = "User model has no field named '{}'. Check ADFS boolean claims mapping." | ||
raise ImproperlyConfigured(msg.format(flag)) | ||
|
||
for field, claim in settings.BOOLEAN_CLAIM_MAPPING.items(): | ||
if hasattr(user, field): | ||
bool_val = False | ||
if claim in claims and str(claims[claim]).lower() in ['y', 'yes', 't', 'true', 'on', '1']: | ||
if claim in claims and str(claims[claim]).lower() in [ | ||
"y", | ||
"yes", | ||
"t", | ||
"true", | ||
"on", | ||
"1", | ||
]: | ||
bool_val = True | ||
setattr(user, field, bool_val) | ||
logger.debug("Attribute '%s' for user '%s' was set to '%s'.", field, user, bool_val) | ||
logger.debug( | ||
"Attribute '%s' for user '%s' was set to '%s'.", | ||
field, | ||
user, | ||
bool_val, | ||
) | ||
else: | ||
msg = "User model has no field named '{}'. Check ADFS boolean claims mapping." | ||
raise ImproperlyConfigured(msg.format(field)) | ||
|
@@ -411,8 +472,10 @@ class AdfsAuthCodeBackend(AdfsBaseBackend): | |
|
||
def authenticate(self, request=None, authorization_code=None, **kwargs): | ||
# If there's no token or code, we pass control to the next authentication backend | ||
if authorization_code is None or authorization_code == '': | ||
logger.debug("Authentication backend was called but no authorization code was received") | ||
if authorization_code is None or authorization_code == "": | ||
logger.debug( | ||
"Authentication backend was called but no authorization code was received" | ||
) | ||
return | ||
|
||
# If loaded data is too old, reload it again | ||
|
@@ -435,8 +498,10 @@ def authenticate(self, request=None, access_token=None, **kwargs): | |
provider_config.load_config() | ||
|
||
# If there's no token or code, we pass control to the next authentication backend | ||
if access_token is None or access_token == '': | ||
logger.debug("Authentication backend was called but no access token was received") | ||
if access_token is None or access_token == "": | ||
logger.debug( | ||
"Authentication backend was called but no access token was received" | ||
) | ||
return | ||
|
||
access_token = access_token.decode() | ||
|
@@ -445,5 +510,6 @@ def authenticate(self, request=None, access_token=None, **kwargs): | |
|
||
|
||
class AdfsBackend(AdfsAuthCodeBackend): | ||
""" Backwards compatible class name """ | ||
"""Backwards compatible class name""" | ||
|
||
pass |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's a redefinition of
existing_claimed_group_names
on line 394 that can be removed because of this.