Skip to content

Reduce SQL queries and restrict to claimed groups #377

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 131 additions & 65 deletions django_auth_adfs/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@
from django.contrib.auth import get_user_model
from django.contrib.auth.backends import ModelBackend
from django.contrib.auth.models import Group
from django.core.exceptions import (ImproperlyConfigured, ObjectDoesNotExist,
PermissionDenied)
from django.core.exceptions import (
ImproperlyConfigured,
ObjectDoesNotExist,
PermissionDenied,
)

from django_auth_adfs import signals
from django_auth_adfs.config import provider_config, settings
Expand All @@ -15,7 +18,6 @@


class AdfsBaseBackend(ModelBackend):

def _ms_request(self, action, url, data=None, **kwargs):
"""
Make a Microsoft Entra/GraphQL request
Expand All @@ -36,7 +38,10 @@ def _ms_request(self, action, url, data=None, **kwargs):
if response.status_code == 400:
if response.json().get("error_description", "").startswith("AADSTS50076"):
raise MFARequired
logger.error("ADFS server returned an error: %s", response.json()["error_description"])
logger.error(
"ADFS server returned an error: %s",
response.json()["error_description"],
)
raise PermissionDenied

if response.status_code != 200:
Expand All @@ -47,16 +52,18 @@ def _ms_request(self, action, url, data=None, **kwargs):
def exchange_auth_code(self, authorization_code, request):
logger.debug("Received authorization code: %s", authorization_code)
data = {
'grant_type': 'authorization_code',
'client_id': settings.CLIENT_ID,
'redirect_uri': provider_config.redirect_uri(request),
'code': authorization_code,
"grant_type": "authorization_code",
"client_id": settings.CLIENT_ID,
"redirect_uri": provider_config.redirect_uri(request),
"code": authorization_code,
}
if settings.CLIENT_SECRET:
data['client_secret'] = settings.CLIENT_SECRET
data["client_secret"] = settings.CLIENT_SECRET

logger.debug("Getting access token at: %s", provider_config.token_endpoint)
response = self._ms_request(provider_config.session.post, provider_config.token_endpoint, data)
response = self._ms_request(
provider_config.session.post, provider_config.token_endpoint, data
)
adfs_response = response.json()
return adfs_response

Expand All @@ -79,11 +86,13 @@ def get_obo_access_token(self, access_token):
"requested_token_use": "on_behalf_of",
}
if provider_config.token_endpoint.endswith("/v2.0/token"):
data["scope"] = 'GroupMember.Read.All'
data["scope"] = "GroupMember.Read.All"
else:
data["resource"] = 'https://graph.microsoft.com'
data["resource"] = "https://graph.microsoft.com"

response = self._ms_request(provider_config.session.get, provider_config.token_endpoint, data)
response = self._ms_request(
provider_config.session.get, provider_config.token_endpoint, data
)
obo_access_token = response.json()["access_token"]
logger.debug("Received OBO access token: %s", obo_access_token)
return obo_access_token
Expand Down Expand Up @@ -117,8 +126,10 @@ def get_group_memberships_from_ms_graph(self, obo_access_token):
Returns:
claim_groups (list): List of the users group memberships
"""
graph_url = "https://{}/v1.0/me/transitiveMemberOf/microsoft.graph.group".format(
provider_config.msgraph_endpoint
graph_url = (
"https://{}/v1.0/me/transitiveMemberOf/microsoft.graph.group".format(
provider_config.msgraph_endpoint
)
)
headers = {"Authorization": "Bearer {}".format(obo_access_token)}
response = self._ms_request(
Expand Down Expand Up @@ -147,25 +158,25 @@ def validate_access_token(self, access_token):
# Explicit is better then implicit and it protects against
# changes in the defaults the jwt module uses.
options = {
'verify_signature': True,
'verify_exp': True,
'verify_nbf': True,
'verify_iat': True,
'verify_aud': True,
'verify_iss': True,
'require_exp': False,
'require_iat': False,
'require_nbf': False
"verify_signature": True,
"verify_exp": True,
"verify_nbf": True,
"verify_iat": True,
"verify_aud": True,
"verify_iss": True,
"require_exp": False,
"require_iat": False,
"require_nbf": False,
}
# Validate token and return claims
return jwt.decode(
access_token,
key=key,
algorithms=['RS256', 'RS384', 'RS512'],
algorithms=["RS256", "RS384", "RS512"],
audience=settings.AUDIENCE,
issuer=provider_config.issuer,
options=options,
leeway=settings.JWT_LEEWAY
leeway=settings.JWT_LEEWAY,
)
except jwt.ExpiredSignatureError as error:
logger.info("Signature has expired: %s", error)
Expand All @@ -175,7 +186,7 @@ def validate_access_token(self, access_token):
if idx < len(provider_config.signing_keys) - 1:
continue
else:
logger.info('Error decoding signature: %s', error)
logger.info("Error decoding signature: %s", error)
raise PermissionDenied
except jwt.InvalidTokenError as error:
logger.info(str(error))
Expand All @@ -187,12 +198,8 @@ def process_access_token(self, access_token, adfs_response=None):

logger.debug("Received access token: %s", access_token)
claims = self.validate_access_token(access_token)
if (
settings.BLOCK_GUEST_USERS
and claims.get('tid')
!= settings.TENANT_ID
):
logger.info('Guest user denied')
if settings.BLOCK_GUEST_USERS and claims.get("tid") != settings.TENANT_ID:
logger.info("Guest user denied")
raise PermissionDenied
if not claims:
raise PermissionDenied
Expand All @@ -204,10 +211,7 @@ def process_access_token(self, access_token, adfs_response=None):
self.update_user_flags(user, claims, groups)

signals.post_authenticate.send(
sender=self,
user=user,
claims=claims,
adfs_response=adfs_response
sender=self, user=user, claims=claims, adfs_response=adfs_response
)

user.full_clean()
Expand Down Expand Up @@ -235,7 +239,9 @@ def process_user_groups(self, claims, access_token):
if settings.GROUPS_CLAIM in claims:
groups = claims[settings.GROUPS_CLAIM]
if not isinstance(groups, list):
groups = [groups, ]
groups = [
groups,
]
elif (
settings.TENANT_ID != "adfs"
and "_claim_names" in claims
Expand All @@ -244,8 +250,10 @@ def process_user_groups(self, claims, access_token):
obo_access_token = self.get_obo_access_token(access_token)
groups = self.get_group_memberships_from_ms_graph(obo_access_token)
else:
logger.debug("The configured groups claim %s was not found in the access token",
settings.GROUPS_CLAIM)
logger.debug(
"The configured groups claim %s was not found in the access token",
settings.GROUPS_CLAIM,
)

return groups

Expand All @@ -264,19 +272,21 @@ def create_user(self, claims):
guest_username_claim = settings.GUEST_USERNAME_CLAIM
usermodel = get_user_model()

iss = claims.get('iss')
idp = claims.get('idp', iss)
iss = claims.get("iss")
idp = claims.get("idp", iss)
if (
guest_username_claim
and not claims.get(username_claim)
and not settings.BLOCK_GUEST_USERS
and (claims.get('tid') != settings.TENANT_ID or iss != idp)
and (claims.get("tid") != settings.TENANT_ID or iss != idp)
):
username_claim = guest_username_claim

if not claims.get(username_claim):
logger.error("User claim's doesn't have the claim '%s' in his claims: %s" %
(username_claim, claims))
logger.error(
"User claim's doesn't have the claim '%s' in his claims: %s"
% (username_claim, claims)
)
raise PermissionDenied

userdata = {usermodel.USERNAME_FIELD: claims[username_claim]}
Expand All @@ -288,7 +298,10 @@ def create_user(self, claims):
user = usermodel.objects.create(**userdata)
logger.debug("User '%s' has been created.", claims[username_claim])
else:
logger.debug("User '%s' doesn't exist and creating users is disabled.", claims[username_claim])
logger.debug(
"User '%s' doesn't exist and creating users is disabled.",
claims[username_claim],
)
raise PermissionDenied
if not user.password:
user.set_unusable_password()
Expand All @@ -308,27 +321,47 @@ def update_user_attributes(self, user, claims, claim_mapping=None):
"""
if claim_mapping is None:
claim_mapping = settings.CLAIM_MAPPING
required_fields = [field.name for field in user._meta.get_fields() if getattr(field, 'blank', True) is False]
required_fields = [
field.name
for field in user._meta.get_fields()
if getattr(field, "blank", True) is False
]

for field, claim in claim_mapping.items():
if hasattr(user, field) or user._meta.fields_map.get(field):
if not isinstance(claim, dict):
if claim in claims:
setattr(user, field, claims[claim])
logger.debug("Attribute '%s' for instance '%s' was set to '%s'.", field, user, claims[claim])
logger.debug(
"Attribute '%s' for instance '%s' was set to '%s'.",
field,
user,
claims[claim],
)
else:
if field in required_fields:
msg = "Claim not found in access token: '{}'. Check ADFS claims mapping."
raise ImproperlyConfigured(msg.format(claim))
else:
logger.warning("Claim '%s' for field '%s' was not found in "
"the access token for instance '%s'. "
"Field is not required and will be left empty", claim, field, user)
logger.warning(
"Claim '%s' for field '%s' was not found in "
"the access token for instance '%s'. "
"Field is not required and will be left empty",
claim,
field,
user,
)
else:
try:
self.update_user_attributes(getattr(user, field), claims, claim_mapping=claim)
self.update_user_attributes(
getattr(user, field), claims, claim_mapping=claim
)
except ObjectDoesNotExist:
logger.warning("Object for field '{}' does not exist for: '{}'.".format(field, user))
logger.warning(
"Object for field '{}' does not exist for: '{}'.".format(
field, user
)
)

else:
msg = "Model '{}' has no field named '{}'. Check ADFS claims mapping."
Expand All @@ -349,6 +382,13 @@ def update_user_groups(self, user, claim_groups):
if sorted(claim_groups) != sorted(user_group_names):
# Get the list of already existing groups in one SQL query
existing_claimed_groups = Group.objects.filter(name__in=claim_groups)
existing_claimed_group_names = (
group.name for group in existing_claimed_groups
)
Comment on lines +385 to +387
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's a redefinition of existing_claimed_group_names on line 394 that can be removed because of this.


if sorted(existing_claimed_group_names) == sorted(user_group_names):
# If the groups are already set, we don't need to do anything
return

if settings.MIRROR_GROUPS:
existing_claimed_group_names = (
Expand All @@ -358,10 +398,12 @@ def update_user_groups(self, user, claim_groups):
# bulk_create could have been used here but we want to send signals.
new_claimed_groups = [
Group.objects.get_or_create(name=name)[0]
for name in claim_groups if name not in existing_claimed_group_names
for name in claim_groups
if name not in existing_claimed_group_names
]
# Associate the users to all claimed groups
user.groups.set(tuple(existing_claimed_groups) + tuple(new_claimed_groups))
# Set user's groups to all claimed groups (both existing and
# newly created) and remove any that are not in the claim.
user.groups.set(new_claimed_groups)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From my understanding, the comment here matches the old code. The new code will set the user's groups to only the groups in the claim that didn't already have a Group instance for them.

else:
# Associate the user to only existing claimed groups
user.groups.set(existing_claimed_groups)
Expand All @@ -381,23 +423,42 @@ def update_user_flags(self, user, claims, claim_groups):
if not isinstance(group, list):
group = [group]

if any(group_list_item in claim_groups for group_list_item in group):
if any(
group_list_item in claim_groups for group_list_item in group
):
value = True
else:
value = False
setattr(user, flag, value)
logger.debug("Attribute '%s' for user '%s' was set to '%s'.", flag, user, value)
logger.debug(
"Attribute '%s' for user '%s' was set to '%s'.",
flag,
user,
value,
)
else:
msg = "User model has no field named '{}'. Check ADFS boolean claims mapping."
raise ImproperlyConfigured(msg.format(flag))

for field, claim in settings.BOOLEAN_CLAIM_MAPPING.items():
if hasattr(user, field):
bool_val = False
if claim in claims and str(claims[claim]).lower() in ['y', 'yes', 't', 'true', 'on', '1']:
if claim in claims and str(claims[claim]).lower() in [
"y",
"yes",
"t",
"true",
"on",
"1",
]:
bool_val = True
setattr(user, field, bool_val)
logger.debug("Attribute '%s' for user '%s' was set to '%s'.", field, user, bool_val)
logger.debug(
"Attribute '%s' for user '%s' was set to '%s'.",
field,
user,
bool_val,
)
else:
msg = "User model has no field named '{}'. Check ADFS boolean claims mapping."
raise ImproperlyConfigured(msg.format(field))
Expand All @@ -411,8 +472,10 @@ class AdfsAuthCodeBackend(AdfsBaseBackend):

def authenticate(self, request=None, authorization_code=None, **kwargs):
# If there's no token or code, we pass control to the next authentication backend
if authorization_code is None or authorization_code == '':
logger.debug("Authentication backend was called but no authorization code was received")
if authorization_code is None or authorization_code == "":
logger.debug(
"Authentication backend was called but no authorization code was received"
)
return

# If loaded data is too old, reload it again
Expand All @@ -435,8 +498,10 @@ def authenticate(self, request=None, access_token=None, **kwargs):
provider_config.load_config()

# If there's no token or code, we pass control to the next authentication backend
if access_token is None or access_token == '':
logger.debug("Authentication backend was called but no access token was received")
if access_token is None or access_token == "":
logger.debug(
"Authentication backend was called but no access token was received"
)
return

access_token = access_token.decode()
Expand All @@ -445,5 +510,6 @@ def authenticate(self, request=None, access_token=None, **kwargs):


class AdfsBackend(AdfsAuthCodeBackend):
""" Backwards compatible class name """
"""Backwards compatible class name"""

pass
Loading