Skip to content

Commit e51135c

Browse files
authored
Merge pull request #255 from stephane/speedup-group-creation
Reduce the number of SQL queries in updates of groups
2 parents de92fa1 + e440441 commit e51135c

File tree

2 files changed

+54
-20
lines changed

2 files changed

+54
-20
lines changed

django_auth_adfs/backend.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
from django.contrib.auth import get_user_model
55
from django.contrib.auth.backends import ModelBackend
66
from django.contrib.auth.models import Group
7-
from django.core.exceptions import ImproperlyConfigured, ObjectDoesNotExist, PermissionDenied
7+
from django.core.exceptions import (ImproperlyConfigured, ObjectDoesNotExist,
8+
PermissionDenied)
89

910
from django_auth_adfs import signals
1011
from django_auth_adfs.config import provider_config, settings
@@ -322,27 +323,27 @@ def update_user_groups(self, user, claim_groups):
322323
"""
323324
if settings.GROUPS_CLAIM is not None:
324325
# Update the user's group memberships
325-
django_groups = [group.name for group in user.groups.all()]
326+
user_group_names = user.groups.all().values_list("name", flat=True)
327+
328+
if sorted(claim_groups) != sorted(user_group_names):
329+
# Get the list of already existing groups in one SQL query
330+
existing_claimed_groups = Group.objects.filter(name__in=claim_groups)
326331

327-
if sorted(claim_groups) != sorted(django_groups):
328-
existing_groups = list(Group.objects.filter(name__in=claim_groups).iterator())
329-
existing_group_names = frozenset(group.name for group in existing_groups)
330-
new_groups = []
331332
if settings.MIRROR_GROUPS:
332-
new_groups = [
333+
existing_claimed_group_names = (
334+
group.name for group in existing_claimed_groups
335+
)
336+
# One SQL query by created group.
337+
# bulk_create could have been used here but we want to send signals.
338+
new_claimed_groups = [
333339
Group.objects.get_or_create(name=name)[0]
334-
for name in claim_groups
335-
if name not in existing_group_names
340+
for name in claim_groups if name not in existing_claimed_group_names
336341
]
342+
# Associate the users to all claimed groups
343+
user.groups.set(tuple(existing_claimed_groups) + tuple(new_claimed_groups))
337344
else:
338-
for name in claim_groups:
339-
if name not in existing_group_names:
340-
try:
341-
group = Group.objects.get(name=name)
342-
new_groups.append(group)
343-
except ObjectDoesNotExist:
344-
pass
345-
user.groups.set(existing_groups + new_groups)
345+
# Associate the user to only existing claimed groups
346+
user.groups.set(existing_claimed_groups)
346347

347348
def update_user_flags(self, user, claims, claim_groups):
348349
"""

tests/test_authentication.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,22 @@
33
from django_auth_adfs.exceptions import MFARequired
44

55
try:
6-
from urllib.parse import urlparse, parse_qs
6+
from urllib.parse import parse_qs, urlparse
77
except ImportError: # Python 2.7
88
from urlparse import urlparse, parse_qs
99

1010
from copy import deepcopy
1111

12-
from django.contrib.auth.models import User, Group
12+
from django.contrib.auth.models import Group, User
1313
from django.core.exceptions import ObjectDoesNotExist, PermissionDenied
1414
from django.db.models.signals import post_save
15-
from django.test import TestCase, RequestFactory
15+
from django.test import RequestFactory, TestCase
1616
from mock import Mock, patch
1717

1818
from django_auth_adfs import signals
1919
from django_auth_adfs.backend import AdfsAuthCodeBackend
2020
from django_auth_adfs.config import ProviderConfig, Settings
21+
2122
from .models import Profile
2223
from .utils import mock_adfs
2324

@@ -175,6 +176,38 @@ def test_no_group_claim(self):
175176
self.assertEqual(user.email, "john.doe@example.com")
176177
self.assertEqual(len(user.groups.all()), 0)
177178

179+
@mock_adfs("2016")
180+
def test_group_claim_with_mirror_groups(self):
181+
# Remove one group
182+
Group.objects.filter(name="group1").delete()
183+
184+
backend = AdfsAuthCodeBackend()
185+
with patch("django_auth_adfs.backend.settings.MIRROR_GROUPS", True):
186+
user = backend.authenticate(self.request, authorization_code="dummycode")
187+
self.assertIsInstance(user, User)
188+
self.assertEqual(user.first_name, "John")
189+
self.assertEqual(user.last_name, "Doe")
190+
self.assertEqual(user.email, "john.doe@example.com")
191+
# group1 is restored
192+
group_names = user.groups.order_by("name").values_list("name", flat=True)
193+
self.assertSequenceEqual(group_names, ['group1', 'group2'])
194+
195+
@mock_adfs("2016")
196+
def test_group_claim_without_mirror_groups(self):
197+
# Remove one group
198+
Group.objects.filter(name="group1").delete()
199+
200+
backend = AdfsAuthCodeBackend()
201+
with patch("django_auth_adfs.backend.settings.MIRROR_GROUPS", False):
202+
user = backend.authenticate(self.request, authorization_code="dummycode")
203+
self.assertIsInstance(user, User)
204+
self.assertEqual(user.first_name, "John")
205+
self.assertEqual(user.last_name, "Doe")
206+
self.assertEqual(user.email, "john.doe@example.com")
207+
# User is not added to group1 because the group doesn't exist
208+
group_names = user.groups.values_list("name", flat=True)
209+
self.assertSequenceEqual(group_names, ['group2'])
210+
178211
@mock_adfs("2016", empty_keys=True)
179212
def test_empty_keys(self):
180213
backend = AdfsAuthCodeBackend()

0 commit comments

Comments
 (0)