diff --git a/axes/admin.py b/axes/admin.py index 48aa95bd..877ab75f 100644 --- a/axes/admin.py +++ b/axes/admin.py @@ -7,14 +7,25 @@ class AccessAttemptAdmin(admin.ModelAdmin): - list_display = ( - "attempt_time", - "ip_address", - "user_agent", - "username", - "path_info", - "failures_since_start", - ) + if settings.AXES_USE_ATTEMPT_EXPIRATION: + list_display = ( + "attempt_time", + "expiration", + "ip_address", + "user_agent", + "username", + "path_info", + "failures_since_start", + ) + else: + list_display = ( + "attempt_time", + "ip_address", + "user_agent", + "username", + "path_info", + "failures_since_start", + ) list_filter = ["attempt_time", "path_info"] @@ -23,7 +34,7 @@ class AccessAttemptAdmin(admin.ModelAdmin): date_hierarchy = "attempt_time" fieldsets = ( - (None, {"fields": ("username", "path_info", "failures_since_start")}), + (None, {"fields": ("username", "path_info", "failures_since_start", "expiration")}), (_("Form Data"), {"fields": ("get_data", "post_data")}), (_("Meta Data"), {"fields": ("user_agent", "ip_address", "http_accept")}), ) @@ -38,11 +49,14 @@ class AccessAttemptAdmin(admin.ModelAdmin): "get_data", "post_data", "failures_since_start", + "expiration", ] def has_add_permission(self, request: HttpRequest) -> bool: return False + def expiration(self, obj: AccessAttempt): + return obj.expiration.expires_at if hasattr(obj, "expiration") else _("Not set") class AccessLogAdmin(admin.ModelAdmin): list_display = ( diff --git a/axes/conf.py b/axes/conf.py index d19a3f41..2de5a1b3 100644 --- a/axes/conf.py +++ b/axes/conf.py @@ -87,6 +87,8 @@ settings.AXES_COOLOFF_TIME = getattr(settings, "AXES_COOLOFF_TIME", None) +settings.AXES_USE_ATTEMPT_EXPIRATION = getattr(settings, "AXES_USE_ATTEMPT_EXPIRATION", False) + settings.AXES_VERBOSE = getattr(settings, "AXES_VERBOSE", settings.AXES_ENABLED) # whitelist and blacklist diff --git a/axes/handlers/database.py b/axes/handlers/database.py index 64f63576..c344300e 100644 --- a/axes/handlers/database.py +++ b/axes/handlers/database.py @@ -19,8 +19,9 @@ get_failure_limit, get_lockout_parameters, get_query_str, + get_attempt_expiration, ) -from axes.models import AccessAttempt, AccessFailureLog, AccessLog +from axes.models import AccessAttempt, AccessAttemptExpiration, AccessFailureLog, AccessLog from axes.signals import user_locked_out log = getLogger(__name__) @@ -219,6 +220,21 @@ def user_login_failed(self, sender, credentials: dict, request=None, **kwargs): client_str, ) + if settings.AXES_USE_ATTEMPT_EXPIRATION: + if not hasattr(attempt, "expiration") or attempt.expiration is None: + log.debug( + "AXES: Creating new AccessAttemptExpiration for %s", client_str + ) + attempt.expiration = AccessAttemptExpiration.objects.create( + access_attempt=attempt, + expires_at=get_attempt_expiration(request) + ) + else: + attempt.expiration.expires_at = max( + get_attempt_expiration(request), attempt.expiration.expires_at + ) + attempt.expiration.save() + # 3. or 4. database query: Calculate the current maximum failure number from the existing attempts failures_since_start = self.get_failures(request, credentials) request.axes_failures_since_start = failures_since_start @@ -382,13 +398,22 @@ def clean_expired_user_attempts( ) return 0 - threshold = get_cool_off_threshold(request) - count, _ = AccessAttempt.objects.filter(attempt_time__lt=threshold).delete() - log.info( - "AXES: Cleaned up %s expired access attempts from database that were older than %s", - count, - threshold, - ) + if settings.AXES_USE_ATTEMPT_EXPIRATION: + threshold = timezone.now() + count, _ = AccessAttempt.objects.filter(expiration__expires_at__lte=threshold).delete() + log.info( + "AXES: Cleaned up %s expired access attempts from database that expiry were older than %s", + count, + threshold, + ) + else: + threshold = get_cool_off_threshold(request) + count, _ = AccessAttempt.objects.filter(attempt_time__lte=threshold).delete() + log.info( + "AXES: Cleaned up %s expired access attempts from database that were older than %s", + count, + threshold, + ) return count def reset_user_attempts( diff --git a/axes/helpers.py b/axes/helpers.py index 198460f8..5e1a5f92 100644 --- a/axes/helpers.py +++ b/axes/helpers.py @@ -1,4 +1,4 @@ -from datetime import timedelta +from datetime import timedelta, datetime from hashlib import sha256 from logging import getLogger from string import Template @@ -100,6 +100,21 @@ def get_cool_off_iso8601(delta: timedelta) -> str: return f"P{days_str}T{time_str}" return f"P{days_str}" +def get_attempt_expiration(request: Optional[HttpRequest] = None) -> datetime: + """ + Get threshold for fetching access attempts from the database. + """ + + cool_off = get_cool_off(request) + if cool_off is None: + raise TypeError( + "Cool off threshold can not be calculated with settings.AXES_COOLOFF_TIME set to None" + ) + + attempt_time = request.axes_attempt_time + if attempt_time is None: + return datetime.now() + cool_off + return attempt_time + cool_off def get_credentials(username: Optional[str] = None, **kwargs) -> dict: """ diff --git a/axes/migrations/0010_accessattemptexpiration.py b/axes/migrations/0010_accessattemptexpiration.py new file mode 100644 index 00000000..74b36890 --- /dev/null +++ b/axes/migrations/0010_accessattemptexpiration.py @@ -0,0 +1,41 @@ +# Generated by Django 5.2.1 on 2025-06-10 20:21 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("axes", "0009_add_session_hash"), + ] + + operations = [ + migrations.CreateModel( + name="AccessAttemptExpiration", + fields=[ + ( + "access_attempt", + models.OneToOneField( + on_delete=django.db.models.deletion.CASCADE, + primary_key=True, + related_name="expiration", + serialize=False, + to="axes.accessattempt", + verbose_name="Access Attempt", + ), + ), + ( + "expires_at", + models.DateTimeField( + help_text="The time when access attempt expires and is no longer valid.", + verbose_name="Expires At", + ), + ), + ], + options={ + "verbose_name": "access attempt expiration", + "verbose_name_plural": "access attempt expirations", + }, + ), + ] diff --git a/axes/models.py b/axes/models.py index 9c8a7da3..5658cabb 100644 --- a/axes/models.py +++ b/axes/models.py @@ -51,6 +51,23 @@ class Meta: unique_together = [["username", "ip_address", "user_agent"]] +class AccessAttemptExpiration(models.Model): + access_attempt = models.OneToOneField( + AccessAttempt, + primary_key=True, + on_delete=models.CASCADE, + related_name="expiration", + verbose_name=_("Access Attempt"), + ) + expires_at = models.DateTimeField( + _("Expires At"), + help_text=_("The time when access attempt expires and is no longer valid."), + ) + + class Meta: + verbose_name = _("access attempt expiration") + verbose_name_plural = _("access attempt expirations") + class AccessLog(AccessBase): logout_time = models.DateTimeField(_("Logout Time"), null=True, blank=True) session_hash = models.CharField(_("Session key hash (sha256)"), default="", blank=True, max_length=64) diff --git a/tests/test_handlers.py b/tests/test_handlers.py index ccef945a..99bd1d53 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -1,18 +1,20 @@ from platform import python_implementation from unittest.mock import MagicMock, patch +from datetime import datetime, timezone as dt_timezone +from django.test import override_settings +from django.utils import timezone +from axes.handlers.database import AxesDatabaseHandler +from axes.models import AccessAttempt, AccessLog, AccessFailureLog, AccessAttemptExpiration from pytest import mark from django.core.cache import cache -from django.test import override_settings from django.urls import reverse -from django.utils import timezone from django.utils.timezone import timedelta from axes.conf import settings from axes.handlers.proxy import AxesProxyHandler from axes.helpers import get_client_str -from axes.models import AccessAttempt, AccessLog, AccessFailureLog from tests.base import AxesTestCase @@ -567,3 +569,170 @@ def test_handler_is_allowed(self): def test_handler_get_failures(self): self.assertEqual(0, AxesProxyHandler.get_failures(self.request, {})) + + +@override_settings(AXES_HANDLER="axes.handlers.database.AxesDatabaseHandler", AXES_COOLOFF_TIME=timezone.timedelta(seconds=10)) +class AxesDatabaseHandlerExpirationFlagTestCase(AxesTestCase): + def setUp(self): + super().setUp() + self.handler = AxesDatabaseHandler() + self.mock_request = MagicMock() + self.mock_credentials = None + + @override_settings(AXES_USE_ATTEMPT_EXPIRATION=True) + @patch("axes.handlers.database.log") + @patch("axes.models.AccessAttempt.objects.filter") + @patch("django.utils.timezone.now") + def test_clean_expired_user_attempts_expiration_true(self, mock_now, mock_filter, mock_log): + mock_now.return_value = datetime(2025, 1, 1, tzinfo=dt_timezone.utc) + mock_qs = MagicMock() + mock_filter.return_value = mock_qs + mock_qs.delete.return_value = (3, None) + + count = self.handler.clean_expired_user_attempts(request=None, credentials=None) + mock_filter.assert_called_once_with(expiration__expires_at__lte=mock_now.return_value) + mock_qs.delete.assert_called_once() + mock_log.info.assert_called_with( + "AXES: Cleaned up %s expired access attempts from database that expiry were older than %s", + 3, + mock_now.return_value, + ) + self.assertEqual(count, 3) + + @override_settings(AXES_USE_ATTEMPT_EXPIRATION=True) + @patch("axes.handlers.database.log") + def test_clean_expired_user_attempts_expiration_true_with_complete_deletion(self, mock_log): + AccessAttempt.objects.all().delete() + dummy_attempt = AccessAttempt.objects.create( + username="test_user", + ip_address="192.168.1.1", + failures_since_start=1, + user_agent="test_agent", + ) + dummy_attempt.expiration = AccessAttemptExpiration.objects.create( + access_attempt=dummy_attempt, + expires_at=timezone.now() - timezone.timedelta(days=1) # Set to expire in the past + ) + + count = self.handler.clean_expired_user_attempts(request=None, credentials=None) + mock_log.info.assert_called_once() + + # comparing count=2, as one is the dummy attempt and one is the expiration + self.assertEqual(count, 2) + self.assertEqual( + AccessAttempt.objects.count(), 0 + ) + self.assertEqual( + AccessAttemptExpiration.objects.count(), 0 + ) + + @override_settings(AXES_USE_ATTEMPT_EXPIRATION=True) + @patch("axes.handlers.database.log") + def test_clean_expired_user_attempts_expiration_true_with_partial_deletion(self, mock_log): + + attempt_not_expired = AccessAttempt.objects.create( + username="test_user", + ip_address="192.168.1.1", + failures_since_start=1, + user_agent="test_agent", + ) + attempt_not_expired.expiration = AccessAttemptExpiration.objects.create( + access_attempt=attempt_not_expired, + expires_at=timezone.now() + timezone.timedelta(days=1) # Set to expire in the future + ) + + attempt_expired = AccessAttempt.objects.create( + username="test_user_2", + ip_address="192.168.1.2", + failures_since_start=1, + user_agent="test_agent", + ) + attempt_expired.expiration = AccessAttemptExpiration.objects.create( + access_attempt=attempt_expired, + expires_at=timezone.now() - timezone.timedelta(days=1) # Set to expire in the past + ) + + access_attempt_count = AccessAttempt.objects.count() + access_attempt_expiration_count = AccessAttemptExpiration.objects.count() + + count = self.handler.clean_expired_user_attempts(request=None, credentials=None) + mock_log.info.assert_called_once() + + # comparing count=2, as one is the dummy attempt and one is the expiration + self.assertEqual(count, 2) + self.assertEqual( + AccessAttempt.objects.count(), access_attempt_count - 1 + ) + self.assertEqual( + AccessAttemptExpiration.objects.count(), access_attempt_expiration_count - 1 + ) + + @override_settings(AXES_USE_ATTEMPT_EXPIRATION=True) + @patch("axes.handlers.database.log") + def test_clean_expired_user_attempts_expiration_true_with_no_deletion(self, mock_log): + + attempt_not_expired_1 = AccessAttempt.objects.create( + username="test_user", + ip_address="192.168.1.1", + failures_since_start=1, + user_agent="test_agent", + ) + attempt_not_expired_1.expiration = AccessAttemptExpiration.objects.create( + access_attempt=attempt_not_expired_1, + expires_at=timezone.now() + timezone.timedelta(days=1) # Set to expire in the future + ) + + attempt_not_expired_2 = AccessAttempt.objects.create( + username="test_user_2", + ip_address="192.168.1.2", + failures_since_start=1, + user_agent="test_agent", + ) + attempt_not_expired_2.expiration = AccessAttemptExpiration.objects.create( + access_attempt=attempt_not_expired_2, + expires_at=timezone.now() + timezone.timedelta(days=2) # Set to expire in the future + ) + + access_attempt_count = AccessAttempt.objects.count() + access_attempt_expiration_count = AccessAttemptExpiration.objects.count() + + count = self.handler.clean_expired_user_attempts(request=None, credentials=None) + mock_log.info.assert_called_once() + + # comparing count=2, as one is the dummy attempt and one is the expiration + self.assertEqual(count, 0) + self.assertEqual( + AccessAttempt.objects.count(), access_attempt_count + ) + self.assertEqual( + AccessAttemptExpiration.objects.count(), access_attempt_expiration_count + ) + + @override_settings(AXES_USE_ATTEMPT_EXPIRATION=False) + @patch("axes.handlers.database.log") + @patch("axes.handlers.database.get_cool_off_threshold") + @patch("axes.models.AccessAttempt.objects.filter") + def test_clean_expired_user_attempts_expiration_false(self, mock_filter, mock_get_threshold, mock_log): + mock_get_threshold.return_value = "fake-threshold" + mock_qs = MagicMock() + mock_filter.return_value = mock_qs + mock_qs.delete.return_value = (2, None) + + count = self.handler.clean_expired_user_attempts(request=self.mock_request, credentials=None) + mock_filter.assert_called_once_with(attempt_time__lte="fake-threshold") + mock_qs.delete.assert_called_once() + mock_log.info.assert_called_with( + "AXES: Cleaned up %s expired access attempts from database that were older than %s", + 2, + "fake-threshold", + ) + self.assertEqual(count, 2) + + @override_settings(AXES_COOLOFF_TIME=None) + @patch("axes.handlers.database.log") + def test_clean_expired_user_attempts_no_cooloff(self, mock_log): + count = self.handler.clean_expired_user_attempts(request=None, credentials=None) + mock_log.debug.assert_called_with( + "AXES: Skipping clean for expired access attempts because no AXES_COOLOFF_TIME is configured" + ) + self.assertEqual(count, 0)