diff --git a/axes/handlers/database.py b/axes/handlers/database.py index 64f63576..5852ab4c 100644 --- a/axes/handlers/database.py +++ b/axes/handlers/database.py @@ -1,6 +1,7 @@ from logging import getLogger -from typing import List, Optional +from typing import List, Optional, Type, TypeVar +from django.contrib.auth.base_user import AbstractBaseUser from django.db import router, transaction from django.db.models import F, Q, QuerySet, Sum, Value from django.db.models.functions import Concat @@ -23,6 +24,8 @@ from axes.models import AccessAttempt, AccessFailureLog, AccessLog from axes.signals import user_locked_out +UserModel = TypeVar("UserModel", bound=AbstractBaseUser) + log = getLogger(__name__) @@ -114,7 +117,7 @@ def get_failures(self, request, credentials: Optional[dict] = None) -> int: ) return attempt_count - def user_login_failed(self, sender, credentials: dict, request=None, **kwargs): + def user_login_failed(self, sender: str, credentials: dict, request: Optional[HttpRequest] = None, **kwargs): """ When user login fails, save AccessFailureLog record in database, save AccessAttempt record in database, mark request with @@ -254,7 +257,7 @@ def user_login_failed(self, sender, credentials: dict, request=None, **kwargs): ) self.remove_out_of_limit_failure_logs(username=username) - def user_logged_in(self, sender, request, user, **kwargs): + def user_logged_in(self, sender: Type[UserModel], request: HttpRequest, user: UserModel, **kwargs): """ When user logs in, update the AccessLog related to the user. """ @@ -297,7 +300,9 @@ def user_logged_in(self, sender, request, user, **kwargs): client_str, ) - def user_logged_out(self, sender, request, user, **kwargs): + def user_logged_out( + self, sender: Optional[Type[UserModel]], request: HttpRequest, user: Optional[UserModel], **kwargs + ): """ When user logs out, update the AccessLog related to the user. """ @@ -370,7 +375,7 @@ def get_user_attempts( ] def clean_expired_user_attempts( - self, request: Optional[HttpRequest] = None, credentials: Optional[dict] = None + self, request: HttpRequest, credentials: Optional[dict] = None ) -> int: """ Clean expired user attempts from the database. @@ -382,8 +387,9 @@ def clean_expired_user_attempts( ) return 0 + username = get_client_username(request, credentials) threshold = get_cool_off_threshold(request) - count, _ = AccessAttempt.objects.filter(attempt_time__lt=threshold).delete() + count, _ = AccessAttempt.objects.filter(username=username, attempt_time__lt=threshold).delete() log.info( "AXES: Cleaned up %s expired access attempts from database that were older than %s", count,