diff --git a/Makefile b/Makefile index 83b7a02e5..a612a1e85 100644 --- a/Makefile +++ b/Makefile @@ -37,10 +37,12 @@ build-docs: *confest* \ tests/* \ rest_framework_simplejwt/token_blacklist/* \ + rest_framework_simplejwt/token_family/* \ rest_framework_simplejwt/backends.py \ rest_framework_simplejwt/exceptions.py \ rest_framework_simplejwt/settings.py \ - rest_framework_simplejwt/state.py + rest_framework_simplejwt/state.py \ + rest_framework_simplejwt/cache.py $(MAKE) -C docs clean $(MAKE) -C docs html $(MAKE) -C docs doctest diff --git a/docs/cache_support.rst b/docs/cache_support.rst new file mode 100644 index 000000000..ec3ec8492 --- /dev/null +++ b/docs/cache_support.rst @@ -0,0 +1,64 @@ +Cache Support +============== + +SimpleJWT provides optional cache support for improving performance. +Currently, caching is available for: + +- Blacklisted refresh tokens +- Blacklisted token families + +To enable caching in SimpleJWT, you must first configure Django's ``CACHES`` setting: + +.. code-block:: python + + CACHES = { + "default": { + "BACKEND": "django.core.cache.backends.locmem.LocMemCache", + "LOCATION": "unique-name", + }, + "redis": { + "BACKEND": "django_redis.cache.RedisCache", + "LOCATION": "redis://127.0.0.1:6379/0", + "OPTIONS": { + "CLIENT_CLASS": "django_redis.client.DefaultClient", + } + } + } + +In this example, two cache backends are defined. You can choose which one to use by +setting the ``SJWT_CACHE_NAME`` option in your SimpleJWT configuration. For this case, +it could be either `"default"` or `"redis"`. + +.. Note:: + Ensure that your Django `CACHES` setting includes a cache matching the alias + defined by `SJWT_CACHE_NAME`. + + +Blacklist Cache +---------------- + +When enabled via the appropriate settings, blacklisted refresh tokens and token families +will be cached. This reduces the number of database queries when verifying whether a token +or family is blacklisted. + +.. code-block:: python + + SIMPLE_JWT = { + ... + + "SJWT_CACHE_NAME": "default", + "CACHE_BLACKLISTED_REFRESH_TOKENS": True, + "CACHE_BLACKLISTED_FAMILIES": True, + "CACHE_TTL_BLACKLISTED_REFRESH_TOKENS": 3600, #time in seconds + "CACHE_TTL_BLACKLISTED_FAMILIES": 3600, #time in seconds + "CACHE_KEY_PREFIX_BLACKLISTED_REFRESH_TOKENS": "sjwt_brt", + "CACHE_KEY_PREFIX_BLACKLISTED_FAMILIES": "sjwt_btf", + + ... + } + + +Cache keys follow this format: + +- Refresh token: ``sjwt_brt:`` +- Token family: ``sjwt_btf:`` diff --git a/docs/conf.py b/docs/conf.py index f8efc6b2f..346d72911 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -34,6 +34,7 @@ def django_configure(): "rest_framework", "rest_framework_simplejwt", "rest_framework_simplejwt.token_blacklist", + "rest_framework_simplejwt.token_family", ), ) diff --git a/docs/family_app.rst b/docs/family_app.rst new file mode 100644 index 000000000..0025cbf8a --- /dev/null +++ b/docs/family_app.rst @@ -0,0 +1,116 @@ +.. _family_app: + +Family app +=========== + +The **Token Family** system provides a way to group refresh and access tokens into logical families, +allowing developers to track, manage, and invalidate related tokens as a unit. + +This feature is especially useful in enhancing security and traceability in authentication flows. +Each token family is identified by a unique ``family_id``, which is included in the token payload. +This enables the system to: + +- Detect and respond to refresh token reuse by invalidating the entire token family. +- Revoke all related tokens at once. +- Enforce expiration policies at the family level via a ``family_exp`` claim. + +A new token family is automatically created every time a user successfully obtains a pair of tokens +from the ``TokenObtainPairView`` (i.e., when starting a new session). From that point onward, as long as the user +continues to refresh their tokens, the newly issued access and refresh tokens will retain the same +``family_id`` and ``family_exp`` values. This means all tokens issued as part of a session are +considered to belong to the same token family. + +This session-based grouping allows administrators or systems to treat the token family as the unit of trust. +If suspicious activity is detected, the entire session can be invalidated at once by blacklisting the +associated token family. + +The Token Family system is optional and customizable. It works best when paired with the +:doc:`/blacklist_app`. + +By organizing tokens into families, you gain finer control over user sessions and potential compromises. +For example, when the settings ``BLACKLIST_AFTER_ROTATION`` and ``TOKEN_FAMILY_BLACKLIST_ON_REUSE`` are +set to ``True``, if a refresh token is stolen and used by both the valid user and the attacker, +the system will detect the reuse of the token and automatically blacklist the associated family. +This invalidates every token that shares the same ``family_id`` as the reused refresh token, effectively +cutting off access without waiting for individual token expiration. + +------- + +Simple JWT includes an app that provides token family functionality. To use +this app, include it in your list of installed apps in ``settings.py``: + +.. code-block:: python + + # Django project settings.py + + ... + + INSTALLED_APPS = ( + ... + 'rest_framework_simplejwt.token_family', + ... + ) + +Also, make sure to run ``python manage.py migrate`` to run the app's +migrations. + +If the token family app is detected in ``INSTALLED_APPS`` and the setting +``TOKEN_FAMILY_ENABLED`` is set to ``True``, Simple JWT will add a new family +to the family list and will also add two new claims, "family_id" and +"family_exp", to the refresh tokens. It will also check that the +family indicated in the token's payload does not appear in a blacklist of +families before it considers it as valid, and it will also check that the +family expiration date ("family_exp") has not passed; if the family is expired, +then the token will be considered as invalid. + +The Simple JWT family app implements its family and blacklisted family +lists using two models: ``TokenFamily`` and ``BlacklistedTokenFamily``. Model +admins are defined for both of these models. To add a family to the blacklist, +find its corresponding ``TokenFamily`` record in the admin and use the +admin again to create a ``BlacklistedTokenFamily`` record that points to the +``TokenFamily`` record. + +Alternatively, you can blacklist a family by creating a ``FamilyMixin`` +subclass instance and calling the instance's ``blacklist_family`` method: + +.. code-block:: python + + from rest_framework_simplejwt.tokens import RefreshToken + + token = RefreshToken(base64_encoded_token_string) + token.blacklist_family() + +Keep in mind that the ``base64_encoded_token_string`` should already +contain a family ID claim in its payload. + +This will create a unique family and blacklist records for the token's +"family_id" claim or whichever claim is specified by the ``TOKEN_FAMILY_CLAIM`` setting. + + +In a ``urls.py`` file, you can also include a route for ``TokenFamilyBlacklistView``: + +.. code-block:: python + + from rest_framework_simplejwt.views import TokenFamilyBlacklistView + + urlpatterns = [ + ... + path('api/token/family/blacklist/', TokenFamilyBlacklistView.as_view(), name='token_family_blacklist'), + ... + ] + +It allows API users to blacklist token families sending them to ``/api/token/family/blacklist/``, for example using curl: + +.. code-block:: bash + + curl \ + -X POST \ + -H "Content-Type: application/json" \ + -d '{"refresh":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ0b2tlbl90eXBlIjoicmVmcmVzaCIsImV4cCI6MTc0NzI0OTU1MywiaWF0IjoxNzQ3MjQ0MTUzLCJqdGkiOiI1YmMzMjlmMjVkODE0OGFhOTY1ODI1YjgwNDQ1ZDQ5OCIsInVzZXJfaWQiOjIsImZhbWlseV9pZCI6ImMyZGYyM2M1YjU1NjRmYjNhNTA3MjFhYzVkMTljNThmIiwiZmFtaWx5X2V4cCI6MTc0NzI0OTE1M30.4oDOmtkgot_W2mXByKuCyJLi6_xeMZtDQJmHIBXZx98"}' \ + http://localhost:8000/api/token/family/blacklist/ + +The family app also provides a management command, ``flushexpiredfamilies``, +which will delete any families from the token family list and family blacklist that have +expired. The command will not affect families that have ``None`` as their expiration. +You should set up a cron job on your server or hosting platform which +runs this command daily. diff --git a/docs/index.rst b/docs/index.rst index fbf54629e..d675758ba 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -51,7 +51,9 @@ Contents creating_tokens_manually token_types blacklist_app + family_app stateless_user_authentication + cache_support development_and_contributing drf_yasg_integration rest_framework_simplejwt diff --git a/docs/settings.rst b/docs/settings.rst index c99fb45c3..3055e2bbf 100644 --- a/docs/settings.rst +++ b/docs/settings.rst @@ -20,6 +20,11 @@ Some of Simple JWT's behavior can be customized through settings variables in "BLACKLIST_AFTER_ROTATION": False, "UPDATE_LAST_LOGIN": False, + "TOKEN_FAMILY_ENABLED": False, + "TOKEN_FAMILY_LIFETIME": timedelta(days=30), + "TOKEN_FAMILY_CHECK_ON_ACCESS": False, + "TOKEN_FAMILY_BLACKLIST_ON_REUSE": False, + "ALGORITHM": "HS256", "SIGNING_KEY": settings.SECRET_KEY, "VERIFYING_KEY": "", @@ -29,6 +34,14 @@ Some of Simple JWT's behavior can be customized through settings variables in "JWK_URL": None, "LEEWAY": 0, + "SJWT_CACHE_NAME": "default", + "CACHE_BLACKLISTED_REFRESH_TOKENS": False, + "CACHE_BLACKLISTED_FAMILIES": False, + "CACHE_TTL_BLACKLISTED_REFRESH_TOKENS": 3600, # time is seconds + "CACHE_TTL_BLACKLISTED_FAMILIES": 3600, # time in seconds + "CACHE_KEY_PREFIX_BLACKLISTED_REFRESH_TOKENS": "sjwt_brt", + "CACHE_KEY_PREFIX_BLACKLISTED_FAMILIES": "sjwt_btf", + "AUTH_HEADER_TYPES": ("Bearer",), "AUTH_HEADER_NAME": "HTTP_AUTHORIZATION", "USER_ID_FIELD": "id", @@ -43,6 +56,9 @@ Some of Simple JWT's behavior can be customized through settings variables in "JTI_CLAIM": "jti", + "TOKEN_FAMILY_CLAIM": "family_id", + "TOKEN_FAMILY_EXPIRATION_CLAIM": "family_exp", + "SLIDING_TOKEN_REFRESH_EXP_CLAIM": "refresh_exp", "SLIDING_TOKEN_LIFETIME": timedelta(minutes=5), "SLIDING_TOKEN_REFRESH_LIFETIME": timedelta(days=1), @@ -53,6 +69,7 @@ Some of Simple JWT's behavior can be customized through settings variables in "TOKEN_BLACKLIST_SERIALIZER": "rest_framework_simplejwt.serializers.TokenBlacklistSerializer", "SLIDING_TOKEN_OBTAIN_SERIALIZER": "rest_framework_simplejwt.serializers.TokenObtainSlidingSerializer", "SLIDING_TOKEN_REFRESH_SERIALIZER": "rest_framework_simplejwt.serializers.TokenRefreshSlidingSerializer", + "TOKEN_FAMILY_BLACKLIST_SERIALIZER": "rest_framework_simplejwt.serializers.TokenFamilyBlacklistSerializer", } Above, the default values for these settings are shown. @@ -105,6 +122,46 @@ login (TokenObtainPairView). a security vulnerability. If you really want this, throttle the endpoint with DRF at the very least. +``TOKEN_FAMILY_ENABLED`` +---------------------------- + +When set to ``True``, enables the Token Family tracking system. This allows +refresh tokens to be grouped into families using a shared identifier. By default, +this identifier is only included in refresh tokens, but it can also be added to +access tokens if ``TOKEN_FAMILY_CHECK_ON_ACCESS`` is set to ``True``. +Families can be invalidated as a whole, meaning all tokens associated with the +same family will then be considered invalid. +You need to add ``'rest_framework_simplejwt.token_family',`` to your +``INSTALLED_APPS`` in the settings file to use this setting. + +This feature is most effective when used in conjunction with the :doc:`/blacklist_app` + +Learn more about :doc:`/family_app`. + +``TOKEN_FAMILY_LIFETIME`` +---------------------------- + +A ``datetime.timedelta`` object that specifies how long a token family is considered valid. +This ``timedelta`` value is added to the current UTC time during token generation +to obtain the token's default "family_exp" claim value. +This setting can also be set to ``None``, in which case the "family_exp" claim +will not be included in the token payload and the token family will never expire automatically. +In that case, the only way to invalidate the family is by blacklisting it. + +``TOKEN_FAMILY_CHECK_ON_ACCESS`` +------------------------------------- + +When set to ``True``, the token family claims ("family_id" and "family_exp") will be included +in the access token payload. Requests authenticated with access tokens will then verify +that the token's family is valid, meaning it has not expired and has not been blacklisted. + +``TOKEN_FAMILY_BLACKLIST_ON_REUSE`` +------------------------------------- + +When set to ``True``, any detected reuse of a refresh token will trigger blacklisting of +the entire token family. This invalidates all tokens that share the same family identifier. +This feature can be enhanced when used together with ``BLACKLIST_AFTER_ROTATION`` set to ``True``. + ``ALGORITHM`` ------------- @@ -175,6 +232,47 @@ integer for seconds or a ``datetime.timedelta``. Please reference https://pyjwt.readthedocs.io/en/latest/usage.html#expiration-time-claim-exp for more information. +``SJWT_CACHE_NAME`` +--------------------- + +Specifies the Django cache alias to use. This must match a defined entry +in Django's ``CACHES`` setting. + +Learn more about :doc:`/cache_support`. + +``CACHE_BLACKLISTED_REFRESH_TOKENS`` +-------------------------------------- + +When set to ``True``, enables caching of blacklisted refresh tokens. +Blacklisted refresh token entries will be cached for a period defined +by ``CACHE_TTL_BLACKLISTED_REFRESH_TOKENS``. + +``CACHE_BLACKLISTED_FAMILIES`` +-------------------------------- + +When set to ``True``, enables caching of blacklisted token families. +Blacklisted family entries will be cached for a period defined +by ``CACHE_TTL_BLACKLISTED_FAMILIES``. + +``CACHE_TTL_BLACKLISTED_REFRESH_TOKENS`` +------------------------------------------ + +Time-to-live (TTL) in seconds for cached refresh token blacklist entries. + +``CACHE_TTL_BLACKLISTED_FAMILIES`` +------------------------------------ + +Time-to-live (TTL) in seconds for cached token family blacklist entries. + +``CACHE_KEY_PREFIX_BLACKLISTED_REFRESH_TOKENS`` +------------------------------------------------- + +Prefix used for cache keys when storing blacklisted refresh tokens. + +``CACHE_KEY_PREFIX_BLACKLISTED_FAMILIES`` +------------------------------------------- + +Prefix used for cache keys when storing blacklisted token families. ``AUTH_HEADER_TYPES`` --------------------- @@ -259,6 +357,18 @@ identifier is used to identify revoked tokens in the blacklist app. It may be necessary in some cases to use another claim besides the default "jti" claim to store such a value. +``TOKEN_FAMILY_CLAIM`` +--------------------------- + +The claim name used to store the token family's unique identifier in the token +payload. Defaults to "family_id". + +``TOKEN_FAMILY_EXPIRATION_CLAIM`` +------------------------------------- + +The claim name used to store the token family's expiration date in the token +payload. Defaults to "family_exp". + ``TOKEN_USER_CLASS`` -------------------- diff --git a/rest_framework_simplejwt/cache.py b/rest_framework_simplejwt/cache.py new file mode 100644 index 000000000..54df443fb --- /dev/null +++ b/rest_framework_simplejwt/cache.py @@ -0,0 +1,65 @@ +from django.core.cache import caches + +from .settings import api_settings + + +class BlacklistCache: + """ + Cache implementation for Simple JWT blacklist functionalities. + Provides caching for both blacklisted refresh tokens and token families. + """ + + @property + def cache(self): + """Get the configured cache backend for Simple JWT.""" + return caches[api_settings.SJWT_CACHE_NAME] + + @property + def is_refresh_tokens_cache_enabled(self): + """Check if refresh token caching is enabled.""" + return api_settings.CACHE_BLACKLISTED_REFRESH_TOKENS + + @property + def is_families_cache_enabled(self): + """Check if token family caching is enabled.""" + return api_settings.CACHE_BLACKLISTED_FAMILIES + + def _get_refresh_token_key(self, jti: str) -> str: + """Generate cache key for a refresh token JTI.""" + return f"{api_settings.CACHE_KEY_PREFIX_BLACKLISTED_REFRESH_TOKENS}:{jti}" + + def _get_family_key(self, family_id: str) -> str: + """Generate cache key for a token family ID.""" + return f"{api_settings.CACHE_KEY_PREFIX_BLACKLISTED_FAMILIES}:{family_id}" + + def add_refresh_token(self, jti: str) -> None: + """Stores refresh token JTI in the cache.""" + key = self._get_refresh_token_key(jti) + self.cache.set( + key, True, timeout=api_settings.CACHE_TTL_BLACKLISTED_REFRESH_TOKENS + ) + + def is_refresh_token_blacklisted(self, jti: str) -> bool: + """Checks if a refresh token JTI is blacklisted in cache.""" + return self.cache.get(self._get_refresh_token_key(jti), False) + + def add_token_family(self, family_id: str) -> None: + """Stores a token family ID in the cache.""" + key = self._get_family_key(family_id) + self.cache.set(key, True, timeout=api_settings.CACHE_TTL_BLACKLISTED_FAMILIES) + + def is_token_family_blacklisted(self, family_id: str) -> bool: + """Checks if a token family is blacklisted in cache.""" + return self.cache.get(self._get_family_key(family_id), False) + + def delete_refresh_token_from_cache(self, jti: str) -> bool: + """Returns True if the token jti was successfully deleted, False otherwise.""" + return self.cache.delete(self._get_refresh_token_key(jti)) + + def delete_family_from_cache(self, family_id: str) -> bool: + """Returns True if the family ID was successfully deleted, False otherwise.""" + return self.cache.delete(self._get_family_key(family_id)) + + +# Singleton instance for centralized cache management +blacklist_cache = BlacklistCache() diff --git a/rest_framework_simplejwt/exceptions.py b/rest_framework_simplejwt/exceptions.py index 8cc58e976..07b867225 100644 --- a/rest_framework_simplejwt/exceptions.py +++ b/rest_framework_simplejwt/exceptions.py @@ -20,6 +20,12 @@ class TokenBackendExpiredToken(TokenBackendError): pass +class RefreshTokenBlacklistedError(TokenError): + """Raised when a refresh token is found in the blacklist.""" + + pass + + class DetailDictMixin: default_detail: str default_code: str @@ -54,3 +60,11 @@ class InvalidToken(AuthenticationFailed): status_code = status.HTTP_401_UNAUTHORIZED default_detail = _("Token is invalid or expired") default_code = "token_not_valid" + + +class TokenFamilyNotConfigured(DetailDictMixin, exceptions.APIException): + status_code = status.HTTP_501_NOT_IMPLEMENTED + default_detail = _( + "Token family functionality is not enabled or available. Please check your configuration." + ) + default_code = "token_family_not_configured" diff --git a/rest_framework_simplejwt/serializers.py b/rest_framework_simplejwt/serializers.py index ab76a0ff7..1c4256eb9 100644 --- a/rest_framework_simplejwt/serializers.py +++ b/rest_framework_simplejwt/serializers.py @@ -8,16 +8,20 @@ from rest_framework.exceptions import AuthenticationFailed, ValidationError from rest_framework.request import Request +from .cache import blacklist_cache +from .exceptions import ( + RefreshTokenBlacklistedError, + TokenError, + TokenFamilyNotConfigured, +) from .models import TokenUser from .settings import api_settings -from .tokens import RefreshToken, SlidingToken, Token, UntypedToken +from .token_blacklist.models import BlacklistedToken +from .tokens import FamilyMixin, RefreshToken, SlidingToken, Token, UntypedToken from .utils import get_md5_hash_password AuthUser = TypeVar("AuthUser", AbstractBaseUser, TokenUser) -if api_settings.BLACKLIST_AFTER_ROTATION: - from .token_blacklist.models import BlacklistedToken - class PasswordField(serializers.CharField): def __init__(self, *args, **kwargs) -> None: @@ -115,7 +119,7 @@ class TokenRefreshSerializer(serializers.Serializer): } def validate(self, attrs: dict[str, Any]) -> dict[str, str]: - refresh = self.token_class(attrs["refresh"]) + refresh = self._get_refresh_token(attrs["refresh"]) user_id = refresh.payload.get(api_settings.USER_ID_CLAIM, None) if user_id: @@ -174,6 +178,26 @@ def validate(self, attrs: dict[str, Any]) -> dict[str, str]: return data + def _get_refresh_token(self, token_str: str) -> RefreshToken: + """ + Handles refresh token instantiation and family blacklisting if enabled. + """ + try: + return self.token_class(token_str) + except RefreshTokenBlacklistedError as e: + if ( + api_settings.TOKEN_FAMILY_ENABLED + and api_settings.TOKEN_FAMILY_BLACKLIST_ON_REUSE + and "rest_framework_simplejwt.token_family" in settings.INSTALLED_APPS + ): + refresh = self.token_class(token=token_str, verify=False) + family_id = refresh.get_family_id() + + if family_id: + refresh.blacklist_family() + + raise e + class TokenRefreshSlidingSerializer(serializers.Serializer): token = serializers.CharField() @@ -241,13 +265,26 @@ def validate(self, attrs: dict[str, None]) -> dict[Any, Any]: token = UntypedToken(attrs["token"]) if ( - api_settings.BLACKLIST_AFTER_ROTATION + token.get(api_settings.TOKEN_TYPE_CLAIM) == RefreshToken.token_type and "rest_framework_simplejwt.token_blacklist" in settings.INSTALLED_APPS ): jti = token.get(api_settings.JTI_CLAIM) + if ( + blacklist_cache.is_refresh_tokens_cache_enabled + and blacklist_cache.is_refresh_token_blacklisted(jti) + ): + raise ValidationError(_("Token is blacklisted")) + if BlacklistedToken.objects.filter(token__jti=jti).exists(): raise ValidationError(_("Token is blacklisted")) + if ( + api_settings.TOKEN_FAMILY_ENABLED + and "rest_framework_simplejwt.token_family" in settings.INSTALLED_APPS + ): + FamilyMixin.check_family_expiration(token=token) + FamilyMixin.check_family_blacklist(token=token) + return {} @@ -264,6 +301,22 @@ def validate(self, attrs: dict[str, Any]) -> dict[Any, Any]: return {} +class TokenFamilyBlacklistSerializer(serializers.Serializer): + refresh = serializers.CharField(write_only=True) + token_class = RefreshToken + + def validate(self, attrs: dict[str, Any]) -> dict[Any, Any]: + refresh = self.token_class(attrs["refresh"]) + try: + refresh.blacklist_family() + except AttributeError: + raise TokenFamilyNotConfigured() + except TokenError as e: + raise serializers.ValidationError({"refresh": str(e)}) + + return {"message": "Token Family blacklisted"} + + def default_on_login_success(user: AuthUser, request: Optional[Request]) -> None: update_last_login(None, user) diff --git a/rest_framework_simplejwt/settings.py b/rest_framework_simplejwt/settings.py index 5c4d9fc4a..e2a454c10 100644 --- a/rest_framework_simplejwt/settings.py +++ b/rest_framework_simplejwt/settings.py @@ -16,6 +16,10 @@ "ROTATE_REFRESH_TOKENS": False, "BLACKLIST_AFTER_ROTATION": False, "UPDATE_LAST_LOGIN": False, + "TOKEN_FAMILY_ENABLED": False, + "TOKEN_FAMILY_LIFETIME": timedelta(days=30), + "TOKEN_FAMILY_CHECK_ON_ACCESS": False, + "TOKEN_FAMILY_BLACKLIST_ON_REUSE": False, "ALGORITHM": "HS256", "SIGNING_KEY": settings.SECRET_KEY, "VERIFYING_KEY": "", @@ -24,6 +28,13 @@ "JSON_ENCODER": None, "JWK_URL": None, "LEEWAY": 0, + "SJWT_CACHE_NAME": "default", + "CACHE_BLACKLISTED_REFRESH_TOKENS": False, + "CACHE_BLACKLISTED_FAMILIES": False, + "CACHE_TTL_BLACKLISTED_REFRESH_TOKENS": 3600, # time is seconds + "CACHE_TTL_BLACKLISTED_FAMILIES": 3600, # time in seconds + "CACHE_KEY_PREFIX_BLACKLISTED_REFRESH_TOKENS": "sjwt_brt", + "CACHE_KEY_PREFIX_BLACKLISTED_FAMILIES": "sjwt_btf", "AUTH_HEADER_TYPES": ("Bearer",), "AUTH_HEADER_NAME": "HTTP_AUTHORIZATION", "USER_ID_FIELD": "id", @@ -34,6 +45,8 @@ "AUTH_TOKEN_CLASSES": ("rest_framework_simplejwt.tokens.AccessToken",), "TOKEN_TYPE_CLAIM": "token_type", "JTI_CLAIM": "jti", + "TOKEN_FAMILY_CLAIM": "family_id", + "TOKEN_FAMILY_EXPIRATION_CLAIM": "family_exp", "TOKEN_USER_CLASS": "rest_framework_simplejwt.models.TokenUser", "SLIDING_TOKEN_REFRESH_EXP_CLAIM": "refresh_exp", "SLIDING_TOKEN_LIFETIME": timedelta(minutes=5), @@ -44,6 +57,7 @@ "TOKEN_BLACKLIST_SERIALIZER": "rest_framework_simplejwt.serializers.TokenBlacklistSerializer", "SLIDING_TOKEN_OBTAIN_SERIALIZER": "rest_framework_simplejwt.serializers.TokenObtainSlidingSerializer", "SLIDING_TOKEN_REFRESH_SERIALIZER": "rest_framework_simplejwt.serializers.TokenRefreshSlidingSerializer", + "TOKEN_FAMILY_BLACKLIST_SERIALIZER": "rest_framework_simplejwt.serializers.TokenFamilyBlacklistSerializer", "CHECK_REVOKE_TOKEN": False, "REVOKE_TOKEN_CLAIM": "hash_password", "CHECK_USER_IS_ACTIVE": True, diff --git a/rest_framework_simplejwt/token_family/__init__.py b/rest_framework_simplejwt/token_family/__init__.py new file mode 100644 index 000000000..7a45bb583 --- /dev/null +++ b/rest_framework_simplejwt/token_family/__init__.py @@ -0,0 +1,4 @@ +from django import VERSION + +if VERSION < (3, 2): + default_app_config = "rest_framework_simplejwt.token_family.apps.TokenFamilyConfig" diff --git a/rest_framework_simplejwt/token_family/admin.py b/rest_framework_simplejwt/token_family/admin.py new file mode 100644 index 000000000..c3c966364 --- /dev/null +++ b/rest_framework_simplejwt/token_family/admin.py @@ -0,0 +1,101 @@ +from datetime import datetime +from typing import Any, Optional, TypeVar + +from django.contrib import admin +from django.contrib.auth.models import AbstractBaseUser +from django.db.models import QuerySet +from django.utils.translation import gettext_lazy as _ +from rest_framework.request import Request + +from ..models import TokenUser +from .models import BlacklistedTokenFamily, TokenFamily + +AuthUser = TypeVar("AuthUser", AbstractBaseUser, TokenUser) + + +class TokenFamilyAdmin(admin.ModelAdmin): + list_display = ( + "family_id", + "user", + "created_at", + "expires_at", + ) + search_fields = ( + "user__id", + "family_id", + ) + ordering = ("user",) + + def get_queryset(self, *args, **kwargs) -> QuerySet: + qs = super().get_queryset(*args, **kwargs) + + return qs.select_related("user") + + # Read-only behavior defined below + actions = None + + def get_readonly_fields(self, *args, **kwargs) -> list[Any]: + return [f.name for f in self.model._meta.fields] + + def has_add_permission(self, *args, **kwargs) -> bool: + return False + + def has_delete_permission(self, *args, **kwargs) -> bool: + return False + + def has_change_permission( + self, request: Request, obj: Optional[object] = None + ) -> bool: + return request.method in ["GET", "HEAD"] and super().has_change_permission( + request, obj + ) + + +admin.site.register(TokenFamily, TokenFamilyAdmin) + + +class BlacklistedTokenFamilyAdmin(admin.ModelAdmin): + list_display = ( + "family_id", + "family_user", + "family_created_at", + "family_expires_at", + "blacklisted_at", + ) + search_fields = ( + "family__user__id", + "family__family_id", + ) + ordering = ("family__user",) + + def get_queryset(self, *args, **kwargs) -> QuerySet: + qs = super().get_queryset(*args, **kwargs) + + return qs.select_related("family__user") + + def family_id(self, obj: BlacklistedTokenFamily) -> str: + return obj.family.family_id + + family_id.short_description = _("family id") # type: ignore + family_id.admin_order_field = "family__family_id" # type: ignore + + def family_user(self, obj: BlacklistedTokenFamily) -> AuthUser: + return obj.family.user + + family_user.short_description = _("user") # type: ignore + family_user.admin_order_field = "family__user" # type: ignore + + def family_created_at(self, obj: BlacklistedTokenFamily) -> datetime: + return obj.family.created_at + + family_created_at.short_description = _("created at") # type: ignore + family_created_at.admin_order_field = "family__created_at" # type: ignore + + def family_expires_at(self, obj: BlacklistedTokenFamily) -> datetime: + return obj.family.expires_at + + family_expires_at.short_description = _("expires at") # type: ignore + family_expires_at.admin_order_field = "family__expires_at" # type: ignore + + +admin.site.register(BlacklistedTokenFamily, BlacklistedTokenFamilyAdmin) diff --git a/rest_framework_simplejwt/token_family/apps.py b/rest_framework_simplejwt/token_family/apps.py new file mode 100644 index 000000000..18f7bfd73 --- /dev/null +++ b/rest_framework_simplejwt/token_family/apps.py @@ -0,0 +1,76 @@ +from datetime import timedelta + +from django.apps import AppConfig +from django.core.exceptions import ImproperlyConfigured +from django.utils.module_loading import import_string +from django.utils.translation import gettext_lazy as _ + +from rest_framework_simplejwt.settings import api_settings + + +class TokenFamilyConfig(AppConfig): + name = "rest_framework_simplejwt.token_family" + verbose_name = _("Token Family") + default_auto_field = "django.db.models.BigAutoField" + + def ready(self): + """Validate token family settings at startup.""" + try: + self._validate_family_settings() + except (ImproperlyConfigured, ImportError) as e: + raise ImproperlyConfigured(f"Invalid Token Family settings: {e}") from e + + @staticmethod + def _validate_family_settings() -> None: + """ + Ensures that required token family settings are properly configured. + This way we prevent undesired behavior. + """ + family_claim = api_settings.TOKEN_FAMILY_CLAIM + if not isinstance(family_claim, str) or not family_claim.strip(): + raise ImproperlyConfigured("TOKEN_FAMILY_CLAIM must be a non-empty string") + + family_exp_claim = api_settings.TOKEN_FAMILY_EXPIRATION_CLAIM + if not isinstance(family_exp_claim, str) or not family_exp_claim.strip(): + raise ImproperlyConfigured( + "TOKEN_FAMILY_EXPIRATION_CLAIM must be a non-empty string" + ) + + family_lifetime = api_settings.TOKEN_FAMILY_LIFETIME + if family_lifetime is not None and not isinstance(family_lifetime, timedelta): + raise ImproperlyConfigured( + "TOKEN_FAMILY_LIFETIME must be of type timedelta or None" + ) + + family_enabled = api_settings.TOKEN_FAMILY_ENABLED + if not isinstance(family_enabled, bool): + raise ImproperlyConfigured("TOKEN_FAMILY_ENABLED must be of type bool") + + check_on_access = api_settings.TOKEN_FAMILY_CHECK_ON_ACCESS + if not isinstance(check_on_access, bool): + raise ImproperlyConfigured( + "TOKEN_FAMILY_CHECK_ON_ACCESS must be of type bool" + ) + + blacklist_on_reuse = api_settings.TOKEN_FAMILY_BLACKLIST_ON_REUSE + if not isinstance(blacklist_on_reuse, bool): + raise ImproperlyConfigured( + "TOKEN_FAMILY_BLACKLIST_ON_REUSE must be of type bool" + ) + + # Validate TOKEN_FAMILY_BLACKLIST_SERIALIZER + blacklist_serializer_path = api_settings.TOKEN_FAMILY_BLACKLIST_SERIALIZER + if ( + not isinstance(blacklist_serializer_path, str) + or not blacklist_serializer_path.strip() + ): + raise ImproperlyConfigured( + "TOKEN_FAMILY_BLACKLIST_SERIALIZER must be a non-empty string" + ) + + try: + import_string(blacklist_serializer_path) + except ImportError as e: + raise ImportError( + f"Could not import serializer '{blacklist_serializer_path}': {e}" + ) from e diff --git a/rest_framework_simplejwt/token_family/management/__init__.py b/rest_framework_simplejwt/token_family/management/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/rest_framework_simplejwt/token_family/management/commands/__init__.py b/rest_framework_simplejwt/token_family/management/commands/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/rest_framework_simplejwt/token_family/management/commands/flushexpiredfamilies.py b/rest_framework_simplejwt/token_family/management/commands/flushexpiredfamilies.py new file mode 100644 index 000000000..781918cfa --- /dev/null +++ b/rest_framework_simplejwt/token_family/management/commands/flushexpiredfamilies.py @@ -0,0 +1,12 @@ +from django.core.management.base import BaseCommand + +from rest_framework_simplejwt.utils import aware_utcnow + +from ...models import TokenFamily + + +class Command(BaseCommand): + help = "Flushes expired token families that have a defined expiration date. Families without an expiration date are not affected." + + def handle(self, *args, **kwargs) -> None: + TokenFamily.objects.filter(expires_at__lte=aware_utcnow()).delete() diff --git a/rest_framework_simplejwt/token_family/migrations/0001_initial.py b/rest_framework_simplejwt/token_family/migrations/0001_initial.py new file mode 100644 index 000000000..b225a3644 --- /dev/null +++ b/rest_framework_simplejwt/token_family/migrations/0001_initial.py @@ -0,0 +1,58 @@ +# Generated by Django 5.1.7 on 2025-04-02 02:00 + +import django.db.models.deletion +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + initial = True + + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.CreateModel( + name="TokenFamily", + fields=[ + ("id", models.BigAutoField(primary_key=True, serialize=False)), + ("family_id", models.CharField(max_length=255, unique=True)), + ("created_at", models.DateTimeField(blank=True)), + ("expires_at", models.DateTimeField(blank=True, null=True)), + ( + "user", + models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="token_families", + to=settings.AUTH_USER_MODEL, + ), + ), + ], + options={ + "verbose_name": "Token Family", + "verbose_name_plural": "Token Families", + }, + ), + migrations.CreateModel( + name="TokenFamilyBlacklist", + fields=[ + ("id", models.BigAutoField(primary_key=True, serialize=False)), + ("blacklisted_at", models.DateTimeField(auto_now_add=True)), + ( + "family", + models.OneToOneField( + on_delete=django.db.models.deletion.CASCADE, + related_name="blacklisted", + to="token_family.tokenfamily", + ), + ), + ], + options={ + "verbose_name": "Token Family Blacklist", + "verbose_name_plural": "Blacklisted Token Families", + }, + ), + ] diff --git a/rest_framework_simplejwt/token_family/migrations/0002_rename_tokenfamilyblacklist_blacklistedtokenfamily.py b/rest_framework_simplejwt/token_family/migrations/0002_rename_tokenfamilyblacklist_blacklistedtokenfamily.py new file mode 100644 index 000000000..bca3ada46 --- /dev/null +++ b/rest_framework_simplejwt/token_family/migrations/0002_rename_tokenfamilyblacklist_blacklistedtokenfamily.py @@ -0,0 +1,16 @@ +# Generated by Django 5.1.7 on 2025-04-28 05:38 + +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ("token_family", "0001_initial"), + ] + + operations = [ + migrations.RenameModel( + old_name="TokenFamilyBlacklist", + new_name="BlacklistedTokenFamily", + ), + ] diff --git a/rest_framework_simplejwt/token_family/migrations/__init__.py b/rest_framework_simplejwt/token_family/migrations/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/rest_framework_simplejwt/token_family/models.py b/rest_framework_simplejwt/token_family/models.py new file mode 100644 index 000000000..b0575edf1 --- /dev/null +++ b/rest_framework_simplejwt/token_family/models.py @@ -0,0 +1,72 @@ +from django.conf import settings +from django.db import models +from django.utils.translation import gettext_lazy as _ + + +class TokenFamily(models.Model): + id = models.BigAutoField(primary_key=True, serialize=False) + user = models.ForeignKey( + settings.AUTH_USER_MODEL, + on_delete=models.SET_NULL, + null=True, + blank=True, + related_name="token_families", + ) + + family_id = models.CharField( + unique=True, null=False, max_length=255 + ) # Unique token family identifier + + created_at = models.DateTimeField(null=False, blank=True) + expires_at = models.DateTimeField(null=True, blank=True) + + class Meta: + verbose_name = _("Token Family") + verbose_name_plural = _("Token Families") + # Work around for a bug in Django: + # https://code.djangoproject.com/ticket/19422 + # + # Also see corresponding ticket: + # https://github.com/encode/django-rest-framework/issues/705 + # + # NOTE: Although this issue did not manifest in the Django shell (calling save() + # raised an error as expected), it did occur when running the tests. + abstract = ( + "rest_framework_simplejwt.token_family" not in settings.INSTALLED_APPS + ) + + def __str__(self) -> str: + return _("Token Family for %(user)s (%(family_id)s)") % { + "user": self.user, + "family_id": self.family_id, + } + + +class BlacklistedTokenFamily(models.Model): + id = models.BigAutoField(primary_key=True, serialize=False) + family = models.OneToOneField( + TokenFamily, on_delete=models.CASCADE, related_name="blacklisted" + ) + + blacklisted_at = models.DateTimeField(auto_now_add=True) + + class Meta: + verbose_name = _("Token Family Blacklist") + verbose_name_plural = _("Blacklisted Token Families") + # Work around for a bug in Django: + # https://code.djangoproject.com/ticket/19422 + # + # Also see corresponding ticket: + # https://github.com/encode/django-rest-framework/issues/705 + # + # NOTE: Although this issue did not manifest in the Django shell (calling save() + # raised an error as expected), it did occur when running the tests. + abstract = ( + "rest_framework_simplejwt.token_family" not in settings.INSTALLED_APPS + ) + + def __str__(self) -> str: + return _("Blacklisted Token Family (%(family_id)s) for %(user)s") % { + "family_id": self.family.family_id, + "user": self.family.user, + } diff --git a/rest_framework_simplejwt/tokens.py b/rest_framework_simplejwt/tokens.py index 11a05770a..f48c49491 100644 --- a/rest_framework_simplejwt/tokens.py +++ b/rest_framework_simplejwt/tokens.py @@ -8,8 +8,10 @@ from django.utils.module_loading import import_string from django.utils.translation import gettext_lazy as _ +from .cache import blacklist_cache from .exceptions import ( ExpiredTokenError, + RefreshTokenBlacklistedError, TokenBackendError, TokenBackendExpiredToken, TokenError, @@ -17,6 +19,7 @@ from .models import TokenUser from .settings import api_settings from .token_blacklist.models import BlacklistedToken, OutstandingToken +from .token_family.models import BlacklistedTokenFamily, TokenFamily from .utils import ( aware_utcnow, datetime_from_epoch, @@ -276,10 +279,16 @@ def check_blacklist(self) -> None: """ jti = self.payload[api_settings.JTI_CLAIM] + if ( + blacklist_cache.is_refresh_tokens_cache_enabled + and blacklist_cache.is_refresh_token_blacklisted(jti) + ): + raise RefreshTokenBlacklistedError(_("Token is blacklisted")) + if BlacklistedToken.objects.filter(token__jti=jti).exists(): - raise TokenError(_("Token is blacklisted")) + raise RefreshTokenBlacklistedError(_("Token is blacklisted")) - def blacklist(self) -> BlacklistedToken: + def blacklist(self) -> tuple[BlacklistedToken, bool]: """ Ensures this token is included in the outstanding token list and adds it to the blacklist. @@ -304,7 +313,14 @@ def blacklist(self) -> BlacklistedToken: }, ) - return BlacklistedToken.objects.get_or_create(token=token) + blacklisted_token, created = BlacklistedToken.objects.get_or_create( + token=token + ) + + if blacklist_cache.is_refresh_tokens_cache_enabled: + blacklist_cache.add_refresh_token(jti) + + return blacklisted_token, created def outstand(self) -> Optional[OutstandingToken]: """ @@ -352,6 +368,183 @@ def for_user(cls: type[T], user: AuthUser) -> T: return token +class FamilyMixin(Generic[T]): + """ + Tokens created from FamilyMixin subclasses will track token families, + enhancing the ability to detect and manage unwanted refresh token reuse. + + This is useful for implementing security measures such as blacklisting + entire token families upon detected misuse. + """ + + payload: dict[str, Any] + + if ( + api_settings.TOKEN_FAMILY_ENABLED + and "rest_framework_simplejwt.token_family" in settings.INSTALLED_APPS + ): + + def verify(self, *args, **kwargs) -> None: + """ + Runs verification checks for token family expiration and blacklist status + before calling the superclass verification. + """ + self.__class__.check_family_expiration(token=self) + self.__class__.check_family_blacklist(token=self) + + super().verify(*args, **kwargs) # type: ignore + + def blacklist_family(self) -> tuple[BlacklistedTokenFamily, bool]: + """ + Blacklists the token family. + """ + family_id = self.get_family_id() + if not family_id: + raise TokenError(_("Token has no family ID")) + + # Ensure Family exist with the given family_id + family, created = TokenFamily.objects.get_or_create( + family_id=family_id, + defaults={ + "user": self._get_user(), + "created_at": self.current_time, + "expires_at": self.get_family_expiration_date(), + }, + ) + + # Blacklist the entire family + blacklisted_fam, created = BlacklistedTokenFamily.objects.get_or_create( + family=family + ) + + if blacklist_cache.is_families_cache_enabled: + blacklist_cache.add_token_family(family_id) + + return blacklisted_fam, created + + def get_family_id(self) -> Optional[str]: + return self.payload.get(api_settings.TOKEN_FAMILY_CLAIM, None) + + def get_family_expiration_date(self) -> Optional[datetime]: + expires_at = self.payload.get( + api_settings.TOKEN_FAMILY_EXPIRATION_CLAIM, None + ) + + if expires_at is None: + return None + + return datetime_from_epoch(expires_at) + + def _get_user(self) -> Optional[AuthUser]: + """ + Retrieves the user associated with this token. + Returns None if the user does not exist. + """ + user_id = self.payload.get(api_settings.USER_ID_CLAIM) + if not user_id: + return None + + User = get_user_model() + try: + return User.objects.get(**{api_settings.USER_ID_FIELD: user_id}) + except User.DoesNotExist: + return None + + @staticmethod + def check_family_blacklist(token: T) -> None: + """ + Checks if this token's family is blacklisted. + Raises `TokenError` if so. + + If the token does not have a `family_id`, it is either an old + token (before this feature was added/enabled) or it could be a + manually issued JWT without family tracking. In such cases, we + skip the blacklist check. + """ + family_id = token.get(api_settings.TOKEN_FAMILY_CLAIM) + + if not family_id: + user_id = token.get(api_settings.USER_ID_CLAIM) + logger.warning( + f"Token of user:{user_id} does not have a family_id. Skipping family blacklist check." + ) + return + + if ( + blacklist_cache.is_families_cache_enabled + and blacklist_cache.is_token_family_blacklisted(family_id) + ): + raise TokenError(_("Token family is blacklisted")) + + if BlacklistedTokenFamily.objects.filter( + family__family_id=family_id + ).exists(): + raise TokenError(_("Token family is blacklisted")) + + @staticmethod + def check_family_expiration( + token: T, current_time: Optional[datetime] = None + ) -> None: + """ + Checks whether the token family's expiration timestamp has passed + (relative to the given `current_time`). + + Raises a `TokenError` with a user-facing error message if the family has expired. + If no expiration claim is set, the check is skipped. + """ + expires_at = token.get(api_settings.TOKEN_FAMILY_EXPIRATION_CLAIM, None) + + if expires_at is None: + return # No expiration set, so we skip this check. + + expiration_date = datetime_from_epoch(expires_at) + + if current_time is None: + current_time = aware_utcnow() + + if expiration_date <= current_time: + raise TokenError(_("Token family has expired")) + + @classmethod + def for_user(cls: type[T], user: AuthUser) -> T: + """ + Generates a new token instance with a unique family ID and optional family expiration. + + This method: + - Creates a unique `family_id`. + - Assigns a family expiration timestamp if `TOKEN_FAMILY_LIFETIME` is set. + - Saves the token family information in the database. + """ + token = super().for_user(user) # type: ignore + + # Generate a new family ID + family_id = uuid4().hex + token[api_settings.TOKEN_FAMILY_CLAIM] = family_id + + family_lifetime = api_settings.TOKEN_FAMILY_LIFETIME + expires_at: Optional[datetime] + + # Since the token_family settings values are checked at startup, + # we don't have to worry about checking the value type again. + if family_lifetime is None: + expires_at = None + else: + expires_at = token.current_time + family_lifetime + token[api_settings.TOKEN_FAMILY_EXPIRATION_CLAIM] = datetime_to_epoch( + expires_at + ) + + # Create the token family + TokenFamily.objects.create( + user=user, + family_id=family_id, + created_at=token.current_time, + expires_at=expires_at, + ) + + return token + + class SlidingToken(BlacklistMixin["SlidingToken"], Token): token_type = "sliding" lifetime = api_settings.SLIDING_TOKEN_LIFETIME @@ -372,8 +565,20 @@ class AccessToken(Token): token_type = "access" lifetime = api_settings.ACCESS_TOKEN_LIFETIME + def verify(self): + """Runs standard verification and optionally checks token family status.""" + super().verify() -class RefreshToken(BlacklistMixin["RefreshToken"], Token): + if ( + api_settings.TOKEN_FAMILY_ENABLED + and api_settings.TOKEN_FAMILY_CHECK_ON_ACCESS + and "rest_framework_simplejwt.token_family" in settings.INSTALLED_APPS + ): + FamilyMixin.check_family_expiration(token=self) + FamilyMixin.check_family_blacklist(token=self) + + +class RefreshToken(BlacklistMixin["RefreshToken"], FamilyMixin["RefreshToken"], Token): token_type = "refresh" lifetime = api_settings.REFRESH_TOKEN_LIFETIME no_copy_claims = ( @@ -404,7 +609,23 @@ def access_token(self) -> AccessToken: # a pair. access.set_exp(from_time=self.current_time) - no_copy = self.no_copy_claims + # Convert tuple to set for efficient updates. + # This allows us to dynamically add or remove claims without creating a new tuple + no_copy = set(self.no_copy_claims) + + # If TOKEN_FAMILY_CHECK_ON_ACCESS is False, the family claims are not needed in the access token. + # We exclude them from being copied to reduce unnecessary token size. + if ( + api_settings.TOKEN_FAMILY_ENABLED + and not api_settings.TOKEN_FAMILY_CHECK_ON_ACCESS + ): + no_copy.update( + { + api_settings.TOKEN_FAMILY_CLAIM, + api_settings.TOKEN_FAMILY_EXPIRATION_CLAIM, + } + ) + for claim, value in self.payload.items(): if claim in no_copy: continue diff --git a/rest_framework_simplejwt/views.py b/rest_framework_simplejwt/views.py index f0c257dc6..1e41958ed 100644 --- a/rest_framework_simplejwt/views.py +++ b/rest_framework_simplejwt/views.py @@ -120,3 +120,15 @@ class TokenBlacklistView(TokenViewBase): token_blacklist = TokenBlacklistView.as_view() + + +class TokenFamilyBlacklistView(TokenViewBase): + """ + Takes a token's family and blacklists it. Must be used with the + `rest_framework_simplejwt.token_family` app installed. + """ + + _serializer_class = api_settings.TOKEN_FAMILY_BLACKLIST_SERIALIZER + + +token_family_blacklist = TokenFamilyBlacklistView.as_view() diff --git a/tests/conftest.py b/tests/conftest.py index 5248ff08c..6898c51ce 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -36,11 +36,24 @@ def pytest_configure(): "rest_framework", "rest_framework_simplejwt", "rest_framework_simplejwt.token_blacklist", + "rest_framework_simplejwt.token_family", "tests", ), PASSWORD_HASHERS=("django.contrib.auth.hashers.MD5PasswordHasher",), SIMPLE_JWT={ "BLACKLIST_AFTER_ROTATION": True, + "TOKEN_FAMILY_ENABLED": True, + "SJWT_CACHE_NAME": "default", + }, + CACHES={ + "default": { + "BACKEND": "django.core.cache.backends.locmem.LocMemCache", + "LOCATION": "unique-snowflake", + }, + "alternate": { + "BACKEND": "django.core.cache.backends.locmem.LocMemCache", + "LOCATION": "alternate-snowflake", + }, }, ) diff --git a/tests/test_cache.py b/tests/test_cache.py new file mode 100644 index 000000000..5890d92af --- /dev/null +++ b/tests/test_cache.py @@ -0,0 +1,388 @@ +# These tests verify the cache blacklisting functionalities. +# +# We primarily use Django's `LocMemCache` backend for these tests. +# This is because LocMemCache is an in-memory backend that is fast, +# requires no external services (like Redis or Memcached), and adheres +# strictly to the standard Django cache API. +# +# The cache integration logic managed via the BlacklistCache class *only* uses the standard +# methods provided by the Django cache API (like `set`, `get`, `delete`). +# It does NOT rely on any backend-specific features. +# +# BECAUSE the code only uses this standard API, testing thoroughly with +# a compliant backend like LocMemCache is considered sufficient. This ensures +# that the logic works correctly with *any other* compliant Django cache backend +# (such as RedisCache, MemcachedCache, etc.) that a user might configure in production, +# as they all implement the same standard API. + +from datetime import timedelta + +from django.contrib.auth.models import User +from django.core.cache import InvalidCacheBackendError, caches +from django.test import TestCase +from freezegun import freeze_time + +from rest_framework_simplejwt.cache import BlacklistCache, blacklist_cache +from rest_framework_simplejwt.exceptions import RefreshTokenBlacklistedError, TokenError +from rest_framework_simplejwt.settings import api_settings +from rest_framework_simplejwt.token_blacklist.models import BlacklistedToken +from rest_framework_simplejwt.token_family.models import BlacklistedTokenFamily +from rest_framework_simplejwt.tokens import RefreshToken +from rest_framework_simplejwt.utils import aware_utcnow +from tests.utils import override_api_settings + + +class TestBlacklistCache(TestCase): + """ + Test the BlacklistCache class and its integration with token blacklisting. + """ + + def setUp(self): + # Create a fresh cache instance for each test + self.cache = BlacklistCache() + # Clear the cache before each test + self.cache.cache.clear() + + def test_cache_properties(self): + """Test that the cache properties read from settings correctly""" + with override_api_settings( + CACHE_BLACKLISTED_REFRESH_TOKENS=True, CACHE_BLACKLISTED_FAMILIES=False + ): + self.assertTrue(self.cache.is_refresh_tokens_cache_enabled) + self.assertFalse(self.cache.is_families_cache_enabled) + + with override_api_settings( + CACHE_BLACKLISTED_REFRESH_TOKENS=False, CACHE_BLACKLISTED_FAMILIES=True + ): + self.assertFalse(self.cache.is_refresh_tokens_cache_enabled) + self.assertTrue(self.cache.is_families_cache_enabled) + + @override_api_settings( + CACHE_KEY_PREFIX_BLACKLISTED_REFRESH_TOKENS="test_rt", + CACHE_KEY_PREFIX_BLACKLISTED_FAMILIES="test_fam", + ) + def test_cache_key_generation(self): + """Test that cache keys are generated correctly""" + refresh_key = self.cache._get_refresh_token_key("abc123") + family_key = self.cache._get_family_key("fam456") + + self.assertEqual(refresh_key, "test_rt:abc123") + self.assertEqual(family_key, "test_fam:fam456") + + def test_add_refresh_token(self): + jti = "test-jti-123" + + # Add the token to the cache and verify it's there + self.cache.add_refresh_token(jti) + self.assertTrue(self.cache.is_refresh_token_blacklisted(jti)) + + def test_add_token_family(self): + family_id = "test-family-456" + + # Add the family to the cache and verify it's there + self.cache.add_token_family(family_id) + self.assertTrue(self.cache.is_token_family_blacklisted(family_id)) + + def test_is_refresh_token_blacklisted(self): + jti_blacklisted = "blacklisted-jti" + jti_not_blacklisted = "clean-jti" + + self.cache.add_refresh_token(jti_blacklisted) + + self.assertTrue(self.cache.is_refresh_token_blacklisted(jti_blacklisted)) + self.assertFalse(self.cache.is_refresh_token_blacklisted(jti_not_blacklisted)) + + def test_is_token_family_blacklisted(self): + family_blacklisted = "blacklisted-family" + family_not_blacklisted = "clean-family" + + self.cache.add_token_family(family_blacklisted) + + self.assertTrue(self.cache.is_token_family_blacklisted(family_blacklisted)) + self.assertFalse(self.cache.is_token_family_blacklisted(family_not_blacklisted)) + + def test_delete_refresh_token_from_cache(self): + jti = "delete-me-jti" + self.cache.add_refresh_token(jti) + + # Verify it's in the cache before deletion + self.assertTrue(self.cache.is_refresh_token_blacklisted(jti)) + + # Delete and verify it's gone + result = self.cache.delete_refresh_token_from_cache(jti) + self.assertTrue(result) + self.assertFalse(self.cache.is_refresh_token_blacklisted(jti)) + + # Try deleting again - should return False + result = self.cache.delete_refresh_token_from_cache(jti) + self.assertFalse(result) + + def test_delete_family_from_cache(self): + family_id = "delete-me-family" + self.cache.add_token_family(family_id) + + # Verify it's in the cache before deletion + self.assertTrue(self.cache.is_token_family_blacklisted(family_id)) + + # Delete and verify it's gone + result = self.cache.delete_family_from_cache(family_id) + self.assertTrue(result) + self.assertFalse(self.cache.is_token_family_blacklisted(family_id)) + + # Try deleting again - should return False + result = self.cache.delete_family_from_cache(family_id) + self.assertFalse(result) + + @override_api_settings( + CACHE_BLACKLISTED_REFRESH_TOKENS=True, + CACHE_TTL_BLACKLISTED_REFRESH_TOKENS=30, # Set to 30 seconds + ) + def test_refresh_token_ttl(self): + """Test that refresh tokens are gone after the configured TTL.""" + jti = "expiring-jti" + + with freeze_time(aware_utcnow() - timedelta(seconds=31)): + self.cache.add_refresh_token(jti) + self.assertTrue(self.cache.is_refresh_token_blacklisted(jti)) + + # Should be gone now + self.assertFalse(self.cache.is_refresh_token_blacklisted(jti)) + + @override_api_settings( + CACHE_BLACKLISTED_FAMILIES=True, + CACHE_TTL_BLACKLISTED_FAMILIES=30, # Set to 30 seconds + ) + def test_token_family_ttl(self): + """Test that token families are gone after the configured TTL.""" + family_id = "expiring-family" + + with freeze_time(aware_utcnow() - timedelta(seconds=31)): + self.cache.add_token_family(family_id) + self.assertTrue(self.cache.is_token_family_blacklisted(family_id)) + + # Should be gone now + self.assertFalse(self.cache.is_token_family_blacklisted(family_id)) + + +class TestBlacklistMixinCacheIntegration(TestCase): + """ + Test the integration of BlacklistCache with the BlacklistMixin. + """ + + def setUp(self): + # Clear the cache before each test + blacklist_cache.cache.clear() + + @override_api_settings(CACHE_BLACKLISTED_REFRESH_TOKENS=True) + def test_refresh_token_blacklist_method_with_cache_enabled(self): + """Test that blacklisted tokens are added to cache when enabled""" + # Create a token and blacklist it + token = RefreshToken() + jti = token.payload[api_settings.JTI_CLAIM] + + token.blacklist() + + # Check that it was added to the cache + self.assertTrue(blacklist_cache.is_refresh_token_blacklisted(jti)) + + # check that is was added to the DB + try: + BlacklistedToken.objects.get(token__jti=jti) + except BlacklistedToken.DoesNotExist: + self.fail(f"Expected BlacklistedToken object with jti='{jti}'") + + @override_api_settings(CACHE_BLACKLISTED_REFRESH_TOKENS=True) + def test_refresh_token_check_blacklist_method_with_cache_enabled(self): + """Test that check_blacklist checks the cache when enabled""" + # Create a token + token = RefreshToken() + jti = token.payload[api_settings.JTI_CLAIM] + + # Add it to the cache directly (not the database) + blacklist_cache.add_refresh_token(jti) + + # Verify that check_blacklist finds it in the cache + with self.assertRaises(RefreshTokenBlacklistedError): + token.check_blacklist() + + # call blacklist method to add the token to the DB blacklist + token.blacklist() + blacklist_cache.cache.clear() + + self.assertFalse(blacklist_cache.is_refresh_token_blacklisted(jti)) + + # here we verify that the DB is check if the token is not in the cache + with self.assertRaises(RefreshTokenBlacklistedError): + token.check_blacklist() + + @override_api_settings(CACHE_BLACKLISTED_REFRESH_TOKENS=False) + def test_refresh_token_blacklist_method_with_cache_disabled(self): + # Create a token and blacklist it + token = RefreshToken() + jti = token.payload[api_settings.JTI_CLAIM] + + token.blacklist() + + # Check that is not present in the cache + self.assertFalse(blacklist_cache.is_refresh_token_blacklisted(jti)) + + # check that is was added to the DB + try: + BlacklistedToken.objects.get(token__jti=jti) + except BlacklistedToken.DoesNotExist: + self.fail(f"Expected BlacklistedToken object with jti='{jti}'") + + @override_api_settings(CACHE_BLACKLISTED_REFRESH_TOKENS=False) + def test_refresh_token_check_blacklist_method_with_cache_disabled(self): + """Test that check_blacklist checks the cache when enabled""" + # Create a token + token = RefreshToken() + jti = token.payload[api_settings.JTI_CLAIM] + + # Add it to the cache directly (not the database) + blacklist_cache.add_refresh_token(jti) + + # it should pass and not raise and error since the token + # is only in the cache and not the DB + token.check_blacklist() + + # clean cache to remove the previous token + blacklist_cache.cache.clear() + + # call blacklist method to add the token to the DB blacklist + token.blacklist() + + self.assertFalse(blacklist_cache.is_refresh_token_blacklisted(jti)) + + with self.assertRaises(RefreshTokenBlacklistedError): + token.check_blacklist() + + +class TestFamilyMixinCacheIntegration(TestCase): + """ + Test the integration of BlacklistCache with the FamilyMixin. + """ + + def setUp(self): + self.user = User.objects.create(username="test_user", password="test_password") + # Clear the cache before each test + blacklist_cache.cache.clear() + + @override_api_settings(CACHE_BLACKLISTED_FAMILIES=True, TOKEN_FAMILY_ENABLED=True) + def test_token_family_blacklist_method_with_cache_enabled(self): + # Create a token and blacklist its family + token = RefreshToken.for_user(self.user) + family_id = token.payload.get(api_settings.TOKEN_FAMILY_CLAIM) + + # Ensure there is a family_id + self.assertIsNotNone(family_id) + + token.blacklist_family() + + # Check that it was added to the cache + self.assertTrue(blacklist_cache.is_token_family_blacklisted(family_id)) + + # check that is was added to the DB + try: + BlacklistedTokenFamily.objects.get(family__family_id=family_id) + except BlacklistedTokenFamily.DoesNotExist: + self.fail( + f"Expected BlacklistedTokenFamily object with family_id='{family_id}'" + ) + + @override_api_settings(CACHE_BLACKLISTED_FAMILIES=True, TOKEN_FAMILY_ENABLED=True) + def test_token_family_check_blacklist_method_with_cache_enabled(self): + # Create a token + token = RefreshToken.for_user(self.user) + family_id = token.payload.get(api_settings.TOKEN_FAMILY_CLAIM) + + # Ensure there is a family_id + self.assertIsNotNone(family_id) + + # Add it to the cache directly (not the database) + blacklist_cache.add_token_family(family_id) + + with self.assertRaises(TokenError): + RefreshToken.check_family_blacklist(token) + + # call blacklist method to add the token to the DB blacklist + token.blacklist_family() + blacklist_cache.cache.clear() + + self.assertFalse(blacklist_cache.is_token_family_blacklisted(family_id)) + + # here we verify that the DB is check if the token is not in the cache + with self.assertRaises(TokenError): + token.check_family_blacklist(token) + + @override_api_settings(CACHE_BLACKLISTED_FAMILIES=False, TOKEN_FAMILY_ENABLED=True) + def test_token_family_blacklist_method_with_cache_disabled(self): + # Create a token and blacklist its family + token = RefreshToken.for_user(self.user) + family_id = token.payload.get(api_settings.TOKEN_FAMILY_CLAIM) + + # Ensure there is a family_id + self.assertIsNotNone(family_id) + + token.blacklist_family() + + # Check that token is not present in the cache + self.assertFalse(blacklist_cache.is_token_family_blacklisted(family_id)) + + # check that is was added to the DB + try: + BlacklistedTokenFamily.objects.get(family__family_id=family_id) + except BlacklistedTokenFamily.DoesNotExist: + self.fail( + f"Expected BlacklistedTokenFamily object with family_id='{family_id}'" + ) + + @override_api_settings(CACHE_BLACKLISTED_FAMILIES=False, TOKEN_FAMILY_ENABLED=True) + def test_token_family_check_blacklist_method_with_cache_disabled(self): + # Create a token + token = RefreshToken.for_user(self.user) + family_id = token.payload.get(api_settings.TOKEN_FAMILY_CLAIM) + + # Ensure there is a family_id + self.assertIsNotNone(family_id) + + # Add it to the cache directly (not the database) + blacklist_cache.add_token_family(family_id) + + # it should pass and not raise any errors + RefreshToken.check_family_blacklist(token) + + # clean cache to remove the previous family + blacklist_cache.cache.clear() + + # call blacklist method to add the token to the DB blacklist + token.blacklist_family() + + self.assertFalse(blacklist_cache.is_token_family_blacklisted(family_id)) + + with self.assertRaises(TokenError): + token.check_family_blacklist(token) + + +class TestCacheBackend(TestCase): + """ + Test the cache backend configuration. + """ + + def test_cache_backend_configuration(self): + """Test that the correct cache backend is used""" + with override_api_settings(SJWT_CACHE_NAME="default"): + self.assertEqual(blacklist_cache.cache, caches["default"]) + + # If you have a different cache configured, you can test it + # Note: This requires the cache to be configured in settings + with override_api_settings(SJWT_CACHE_NAME="alternate"): + self.assertEqual(blacklist_cache.cache, caches["alternate"]) + + @override_api_settings(SJWT_CACHE_NAME="nonexistent_cache") + def test_invalid_cache_backend(self): + """Test that trying to use a non-existent cache backend raises an appropriate error.""" + cache = BlacklistCache() + with self.assertRaises(InvalidCacheBackendError): + # This should raise an exception when trying to access the cache property + _ = cache.cache diff --git a/tests/test_integration.py b/tests/test_integration.py index beee3d552..3eec15ec6 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -1,11 +1,15 @@ from datetime import timedelta +from importlib import reload from django.contrib.auth import get_user_model from django.urls import reverse +from freezegun import freeze_time from rest_framework.status import HTTP_200_OK, HTTP_401_UNAUTHORIZED +from rest_framework_simplejwt.cache import blacklist_cache from rest_framework_simplejwt.settings import api_settings -from rest_framework_simplejwt.tokens import AccessToken +from rest_framework_simplejwt.tokens import AccessToken, RefreshToken +from rest_framework_simplejwt.utils import aware_utcnow from .utils import APIViewTestCase, override_api_settings @@ -127,3 +131,266 @@ def test_user_can_get_access_and_refresh_tokens_and_use_them(self): self.assertEqual(res.status_code, HTTP_200_OK) self.assertEqual(res.data["foo"], "bar") + + @override_api_settings( + AUTH_TOKEN_CLASSES=("rest_framework_simplejwt.tokens.AccessToken",), + TOKEN_FAMILY_CHECK_ON_ACCESS=True, + ) + def test_access_token_performs_family_blacklist_check_when_enabled(self): + res = self.client.post( + reverse("token_obtain_pair"), + data={ + User.USERNAME_FIELD: self.username, + "password": self.password, + }, + ) + + access = res.data["access"] + refresh = res.data["refresh"] + + self.authenticate_with_token(api_settings.AUTH_HEADER_TYPES[0], access) + + res = self.view_get() + + self.assertEqual(res.status_code, HTTP_200_OK) + self.assertEqual(res.data["foo"], "bar") + + RefreshToken(str(refresh)).blacklist_family() + + self.authenticate_with_token(api_settings.AUTH_HEADER_TYPES[0], access) + + res = self.view_get() + + self.assertEqual(res.status_code, HTTP_401_UNAUTHORIZED) + self.assertEqual("token_not_valid", res.data["code"]) + + error_msg = res.data.get("messages")[0].get("message") + self.assertIn("family", error_msg) + self.assertIn("blacklisted", error_msg) + + @override_api_settings( + AUTH_TOKEN_CLASSES=("rest_framework_simplejwt.tokens.AccessToken",), + TOKEN_FAMILY_CHECK_ON_ACCESS=True, + # We use a smaller family lifetime than the default access token lifetime + # so we dont have to reload the tokens and serializers modules + TOKEN_FAMILY_LIFETIME=timedelta(minutes=2), + ) + def test_access_token_performs_family_expiration_check_when_enabled(self): + with freeze_time(aware_utcnow() - timedelta(minutes=2)): + res = self.client.post( + reverse("token_obtain_pair"), + data={ + User.USERNAME_FIELD: self.username, + "password": self.password, + }, + ) + + access = res.data["access"] + + self.authenticate_with_token(api_settings.AUTH_HEADER_TYPES[0], access) + + res = self.view_get() + + self.assertEqual(res.status_code, HTTP_401_UNAUTHORIZED) + self.assertEqual("token_not_valid", res.data["code"]) + + error_msg = res.data.get("messages")[0].get("message") + self.assertIn("family", error_msg) + self.assertIn("expired", error_msg) + + @override_api_settings( + AUTH_TOKEN_CLASSES=("rest_framework_simplejwt.tokens.AccessToken",), + TOKEN_FAMILY_CHECK_ON_ACCESS=False, + # We use a smaller family lifetime than the default access token lifetime + # so we dont have to reload the tokens and serializers modules + TOKEN_FAMILY_LIFETIME=timedelta(minutes=2), + ) + def test_access_token_does_not_performs_family_checks_when_disabled(self): + res = self.client.post( + reverse("token_obtain_pair"), + data={ + User.USERNAME_FIELD: self.username, + "password": self.password, + }, + ) + + access = res.data["access"] + refresh = res.data["refresh"] + + self.authenticate_with_token(api_settings.AUTH_HEADER_TYPES[0], access) + res = self.view_get() + + self.assertEqual(res.status_code, HTTP_200_OK) + self.assertEqual(res.data["foo"], "bar") + + # blacklisting the token family + RefreshToken(str(refresh)).blacklist_family() + + self.authenticate_with_token(api_settings.AUTH_HEADER_TYPES[0], access) + res = self.view_get() + + # response must be 200_OK since the family check for the access token is disabled + self.assertEqual(res.status_code, HTTP_200_OK) + self.assertEqual(res.data["foo"], "bar") + + # testing for family expiration now + with freeze_time(aware_utcnow() - timedelta(minutes=2)): + res = self.client.post( + reverse("token_obtain_pair"), + data={ + User.USERNAME_FIELD: self.username, + "password": self.password, + }, + ) + + access = res.data["access"] + + self.authenticate_with_token(api_settings.AUTH_HEADER_TYPES[0], access) + res = self.view_get() + + # response must be 200_OK since the family check for the access token is disabled + self.assertEqual(res.status_code, HTTP_200_OK) + self.assertEqual(res.data["foo"], "bar") + + +class TestBlacklistCacheIntegration(APIViewTestCase): + view_name = "test_view" + + def setUp(self): + self.username = "test_user" + self.password = "test_password" + self.user = User.objects.create_user( + username=self.username, + password=self.password, + ) + + blacklist_cache.cache.clear() + + @override_api_settings( + CACHE_BLACKLISTED_REFRESH_TOKENS=True, + BLACKLIST_AFTER_ROTATION=True, + ROTATE_REFRESH_TOKENS=True, + ) + def test_token_refresh_blacklists_in_cache(self): + """Test that token refresh adds the old token to the cache when configured.""" + # Get tokens + refresh = RefreshToken.for_user(self.user) + + # Get the JTI of the current refresh token before rotation + old_jti = refresh.payload[api_settings.JTI_CLAIM] + + # Use the token to refresh + res = self.client.post( + reverse("token_refresh"), + data={"refresh": str(refresh)}, + ) + + self.assertEqual(res.status_code, 200) + + # Verify the old refresh token was blacklisted and is in the cache + self.assertTrue(blacklist_cache.is_refresh_token_blacklisted(old_jti)) + + # Try to use the old token again - should fail + res = self.client.post( + reverse("token_refresh"), + data={"refresh": str(refresh)}, + ) + + self.assertEqual(res.status_code, 401) + + @override_api_settings(CACHE_BLACKLISTED_FAMILIES=True, TOKEN_FAMILY_ENABLED=True) + def test_token_verify_checks_blacklisted_family_in_cache(self): + """Test token verification checks family blacklist in cache.""" + # Get token + refresh = RefreshToken.for_user(self.user) + family_id = refresh.payload.get(api_settings.TOKEN_FAMILY_CLAIM) + + # Verify token works + res = self.client.post( + reverse("token_verify"), + data={"token": str(refresh)}, + ) + + self.assertEqual(res.status_code, 200) + + # Now blacklist the family directly in cache + blacklist_cache.add_token_family(family_id) + + # Verify token now fails + res = self.client.post( + reverse("token_verify"), + data={"token": str(refresh)}, + ) + + self.assertEqual(res.status_code, 401) + + error_msg = res.data.get("detail") + self.assertIn("family", error_msg) + self.assertIn("blacklisted", error_msg) + + @override_api_settings( + CACHE_BLACKLISTED_REFRESH_TOKENS=True, + BLACKLIST_AFTER_ROTATION=True, + ROTATE_REFRESH_TOKENS=True, + ) + def test_token_verify_checks_blacklisted_token_in_cache(self): + """Test token verification checks token blacklist in cache.""" + # Get token + refresh = RefreshToken.for_user(self.user) + + # Verify token works + res = self.client.post( + reverse("token_verify"), + data={"token": str(refresh)}, + ) + + self.assertEqual(res.status_code, 200) + + # Now blacklist the token directly in cache + blacklist_cache.add_refresh_token(refresh.get("jti")) + refresh.blacklist() + + # Verify token now fails + res = self.client.post( + reverse("token_verify"), + data={"token": str(refresh)}, + ) + + self.assertEqual(res.status_code, 400) + + @override_api_settings( + CACHE_BLACKLISTED_FAMILIES=True, + TOKEN_FAMILY_ENABLED=True, + TOKEN_FAMILY_CHECK_ON_ACCESS=True, + ) + def test_token_verify_checks_blacklisted_token_in_cache_2(self): + """Test token verification checks token blacklist in cache.""" + # Get tokens + res = self.client.post( + reverse("token_obtain_pair"), + data={ + User.USERNAME_FIELD: self.username, + "password": self.password, + }, + ) + + access = res.data["access"] + refresh = res.data["refresh"] + family_id = RefreshToken(refresh).get_family_id() + + self.authenticate_with_token(api_settings.AUTH_HEADER_TYPES[0], access) + res = self.view_get() + + self.assertEqual(res.status_code, 200) + + # Now blacklist the family directly in cache + blacklist_cache.add_token_family(family_id) + + self.authenticate_with_token(api_settings.AUTH_HEADER_TYPES[0], access) + res = self.view_get() + + self.assertEqual(res.status_code, 401) + + error_msg = res.data.get("messages")[0].get("message") + self.assertIn("family", error_msg) + self.assertIn("blacklisted", error_msg) diff --git a/tests/test_serializers.py b/tests/test_serializers.py index ed8645e54..266a7ff4b 100644 --- a/tests/test_serializers.py +++ b/tests/test_serializers.py @@ -5,11 +5,18 @@ from django.conf import settings from django.contrib.auth import get_user_model from django.test import TestCase, override_settings +from freezegun import freeze_time from rest_framework import exceptions as drf_exceptions +from rest_framework.serializers import ValidationError -from rest_framework_simplejwt.exceptions import TokenError +from rest_framework_simplejwt.exceptions import ( + RefreshTokenBlacklistedError, + TokenError, + TokenFamilyNotConfigured, +) from rest_framework_simplejwt.serializers import ( TokenBlacklistSerializer, + TokenFamilyBlacklistSerializer, TokenObtainPairSerializer, TokenObtainSerializer, TokenObtainSlidingSerializer, @@ -22,6 +29,10 @@ BlacklistedToken, OutstandingToken, ) +from rest_framework_simplejwt.token_family.models import ( + BlacklistedTokenFamily, + TokenFamily, +) from rest_framework_simplejwt.tokens import AccessToken, RefreshToken, SlidingToken from rest_framework_simplejwt.utils import ( aware_utcnow, @@ -604,6 +615,118 @@ def test_refresh_token_should_blacklist_after_password_change(self): self.assertTrue(OutstandingToken.objects.filter(jti=jti).exists()) self.assertTrue(BlacklistedToken.objects.filter(token__jti=jti).exists()) + @override_api_settings( + ROTATE_REFRESH_TOKENS=True, + BLACKLIST_AFTER_ROTATION=True, + TOKEN_FAMILY_BLACKLIST_ON_REUSE=True, + ) + def test_family_blacklisting_on_refresh_token_reuse_is_enabled(self): + """ + Tests that the token family is blacklisted upon refresh token reuse. + + If a refresh token has been rotated and subsequently added to the + blacklist, attempting to use that same original token again should + be detected as reuse. This reuse detection should then cause its + associated token family to be added to the BlacklistedTokenFamily model. + """ + refresh = RefreshToken() + refresh.payload[api_settings.TOKEN_FAMILY_CLAIM] = "random string" + + ser = TokenRefreshSerializer(data={"refresh": str(refresh)}) + ser.validate({"refresh": str(refresh)}) + + with self.assertRaises(RefreshTokenBlacklistedError) as e: + ser.validate({"refresh": str(refresh)}) + + self.assertIn("blacklisted", e.exception.args[0]) + + qs = BlacklistedTokenFamily.objects.filter(family__family_id="random string") + + self.assertEqual(qs.count(), 1) + + @override_api_settings( + ROTATE_REFRESH_TOKENS=True, + BLACKLIST_AFTER_ROTATION=True, + TOKEN_FAMILY_BLACKLIST_ON_REUSE=False, + ) + def test_family_blacklisting_on_refresh_token_reuse_is_disabled(self): + """ + Tests that refresh token reuse is detected, but the token family + is not added to the blacklist because the option for that is disabled. + + When a refresh token is reused and the TOKEN_FAMILY_BLACKLIST_ON_REUSE + setting is False, the reuse should be detected (raising the standard + blacklisted token error), but the token's associated family should NOT + be added to the BlacklistedTokenFamily model. + """ + refresh = RefreshToken() + refresh.payload[api_settings.TOKEN_FAMILY_CLAIM] = "random string" + + ser = TokenRefreshSerializer(data={"refresh": str(refresh)}) + ser.validate({"refresh": str(refresh)}) + + with self.assertRaises(RefreshTokenBlacklistedError) as e: + ser.validate({"refresh": str(refresh)}) + + self.assertIn("blacklisted", e.exception.args[0]) + + qs = BlacklistedTokenFamily.objects.all() + + self.assertEqual(qs.count(), 0) + + @override_api_settings( + ROTATE_REFRESH_TOKENS=True, + BLACKLIST_AFTER_ROTATION=True, + TOKEN_FAMILY_BLACKLIST_ON_REUSE=False, + ) + def test_family_blacklisting_on_token_reuse_is_enabled_but_token_has_no_family_id( + self, + ): + """ + Ensures that no error is raised and no family is blacklisted when a reused token lacks a family ID. + + This simulates the case where a token was issued before the token family feature was enabled. + Such tokens are still considered valid and should not trigger family-level blacklisting. + """ + # token has no family id + refresh = RefreshToken() + + ser = TokenRefreshSerializer(data={"refresh": str(refresh)}) + ser.validate({"refresh": str(refresh)}) + + self.assertEqual(BlacklistedToken.objects.all().count(), 1) + self.assertEqual(BlacklistedTokenFamily.objects.all().count(), 0) + + @override_api_settings(TOKEN_FAMILY_LIFETIME=timedelta(minutes=30)) + def test_raises_token_error_if_token_family_is_expired(self): + with freeze_time(aware_utcnow() - timedelta(minutes=30)): + token = RefreshToken() + fam_exp = datetime_to_epoch(aware_utcnow()) + token.payload[api_settings.TOKEN_FAMILY_EXPIRATION_CLAIM] = fam_exp + + ser = TokenRefreshSerializer(data={"refresh": str(token)}) + + with self.assertRaises(TokenError) as e: + ser.validate({"refresh": str(token)}) + + self.assertIn("family", e.exception.args[0]) + self.assertIn("expired", e.exception.args[0]) + + def test_raises_token_error_if_token_family_is_blacklisted(self): + token = RefreshToken() + token.payload[api_settings.TOKEN_FAMILY_CLAIM] = "random string" + token.blacklist_family() + + self.assertEqual(BlacklistedTokenFamily.objects.all().count(), 1) + + ser = TokenRefreshSerializer(data={"refresh": str(token)}) + + with self.assertRaises(TokenError) as e: + ser.validate({"refresh": str(token)}) + + self.assertIn("family", e.exception.args[0]) + self.assertIn("blacklisted", e.exception.args[0]) + class TestTokenVerifySerializer(TestCase): def test_it_should_raise_token_error_if_token_invalid(self): @@ -649,6 +772,36 @@ def test_it_should_return_given_token_if_everything_ok(self): self.assertEqual(len(s.validated_data), 0) + @override_api_settings(TOKEN_FAMILY_LIFETIME=timedelta(minutes=30)) + def test_raises_token_error_if_token_family_is_expired(self): + with freeze_time(aware_utcnow() - timedelta(minutes=30)): + token = RefreshToken() + fam_exp = datetime_to_epoch(aware_utcnow()) + token.payload[api_settings.TOKEN_FAMILY_EXPIRATION_CLAIM] = fam_exp + + ser = TokenVerifySerializer(data={"token": str(token)}) + + with self.assertRaises(TokenError) as e: + ser.validate({"token": str(token)}) + + self.assertIn("family", e.exception.args[0]) + self.assertIn("expired", e.exception.args[0]) + + def test_raises_token_error_if_token_family_is_blacklisted(self): + token = RefreshToken() + token.payload[api_settings.TOKEN_FAMILY_CLAIM] = "random string" + token.blacklist_family() + + self.assertEqual(BlacklistedTokenFamily.objects.all().count(), 1) + + ser = TokenVerifySerializer(data={"token": str(token)}) + + with self.assertRaises(TokenError) as e: + ser.validate({"token": str(token)}) + + self.assertIn("family", e.exception.args[0]) + self.assertIn("blacklisted", e.exception.args[0]) + class TestTokenBlacklistSerializer(TestCase): def test_it_should_raise_token_error_if_token_invalid(self): @@ -743,3 +896,188 @@ def test_blacklist_app_not_installed_should_pass(self): # Restore origin module without mock reload(tokens) reload(serializers) + + @override_api_settings(TOKEN_FAMILY_LIFETIME=timedelta(minutes=30)) + def test_raises_token_error_if_token_family_is_expired(self): + with freeze_time(aware_utcnow() - timedelta(minutes=30)): + token = RefreshToken() + fam_exp = datetime_to_epoch(aware_utcnow()) + token.payload[api_settings.TOKEN_FAMILY_EXPIRATION_CLAIM] = fam_exp + + ser = TokenBlacklistSerializer(data={"refresh": str(token)}) + + with self.assertRaises(TokenError) as e: + ser.validate({"refresh": str(token)}) + + self.assertIn("family", e.exception.args[0]) + self.assertIn("expired", e.exception.args[0]) + + def test_raises_token_error_if_token_family_is_blacklisted(self): + token = RefreshToken() + token.payload[api_settings.TOKEN_FAMILY_CLAIM] = "random string" + token.blacklist_family() + + self.assertEqual(BlacklistedTokenFamily.objects.all().count(), 1) + + ser = TokenBlacklistSerializer(data={"refresh": str(token)}) + + with self.assertRaises(TokenError) as e: + ser.validate({"refresh": str(token)}) + + self.assertIn("family", e.exception.args[0]) + self.assertIn("blacklisted", e.exception.args[0]) + + +class TestTokenFamilyBlacklistSerializer(TestCase): + def setUp(self): + self.user = User.objects.create_user( + username="test_user", password="test_password" + ) + + def test_it_should_raise_token_error_if_token_invalid(self): + token = RefreshToken() + del token["exp"] + + s = TokenFamilyBlacklistSerializer(data={"refresh": str(token)}) + + with self.assertRaises(TokenError) as e: + s.is_valid() + + self.assertIn("has no 'exp' claim", e.exception.args[0]) + + token.set_exp(lifetime=-timedelta(days=1)) + + s = TokenBlacklistSerializer(data={"refresh": str(token)}) + + with self.assertRaises(TokenError) as e: + s.is_valid() + + self.assertIn("expired", e.exception.args[0]) + + def test_it_should_raise_token_error_if_token_has_wrong_type(self): + token = RefreshToken() + token[api_settings.TOKEN_TYPE_CLAIM] = "wrong_type" + + s = TokenFamilyBlacklistSerializer(data={"refresh": str(token)}) + + with self.assertRaises(TokenError) as e: + s.is_valid() + + self.assertIn("wrong type", e.exception.args[0]) + + @override_api_settings(TOKEN_FAMILY_ENABLED=True) + def test_it_should_raise_token_family_not_confgured_if_family_app_is_not_installed( + self, + ): + from rest_framework_simplejwt import serializers, tokens + + # Remove family app + new_apps = list(settings.INSTALLED_APPS) + new_apps.remove("rest_framework_simplejwt.token_family") + + with self.settings(INSTALLED_APPS=tuple(new_apps)): + # Reload module that blacklist app not installed + reload(tokens) + reload(serializers) + + token = tokens.RefreshToken() + token[api_settings.TOKEN_FAMILY_CLAIM] = "random string" + + # Serializer validates + ser = serializers.TokenFamilyBlacklistSerializer( + data={"refresh": str(token)} + ) + + with self.assertRaises(TokenFamilyNotConfigured) as e: + ser.validate({"refresh": str(token)}) + + # Restore origin module without mock + reload(tokens) + reload(serializers) + + def test_it_should_raise_token_family_not_confgured_if_family_setting_is_disable( + self, + ): + """the default value of the setting TOKEN_FAMILY_ENABLED is False""" + from rest_framework_simplejwt import serializers, tokens + + with override_api_settings(TOKEN_FAMILY_ENABLED=False): + reload(tokens) + reload(serializers) + + token = tokens.RefreshToken() + token[api_settings.TOKEN_FAMILY_CLAIM] = "random string" + + # Serializer validates + ser = serializers.TokenFamilyBlacklistSerializer( + data={"refresh": str(token)} + ) + + with self.assertRaises(TokenFamilyNotConfigured) as e: + ser.validate({"refresh": str(token)}) + + # restore origin modules + with override_api_settings(TOKEN_FAMILY_ENABLED=True): + reload(tokens) + reload(serializers) + + def test_raises_validation_error_if_token_has_no_family_id_in_payload(self): + refresh = RefreshToken.for_user(self.user) + del refresh.payload[api_settings.TOKEN_FAMILY_CLAIM] + + ser = TokenFamilyBlacklistSerializer(data={"refresh": str(refresh)}) + + with self.assertRaises(ValidationError) as e: + ser.validate({"refresh": str(refresh)}) + + @override_api_settings(TOKEN_FAMILY_LIFETIME=timedelta(minutes=30)) + def test_raises_token_error_if_family_is_expired(self): + with freeze_time(aware_utcnow() - timedelta(minutes=30)): + token = RefreshToken.for_user(self.user) + + ser = TokenFamilyBlacklistSerializer(data={"refresh": str(token)}) + + with self.assertRaises(TokenError) as e: + ser.validate({"refresh": str(token)}) + + self.assertIn("family", e.exception.args[0]) + self.assertIn("expired", e.exception.args[0]) + + def test_it_should_blacklist_family_if_everything_is_ok(self): + token = RefreshToken.for_user(self.user) + ser = TokenFamilyBlacklistSerializer(data={"refresh": str(token)}) + ser.is_valid() + + qs = BlacklistedTokenFamily.objects.filter( + family__family_id=token.get_family_id() + ) + + self.assertTrue(qs.exists()) + self.assertEqual(qs.count(), 1) + + def test_raises_token_error_if_family_is_blacklisted(self): + token = RefreshToken.for_user(self.user) + token.blacklist_family() + + self.assertEqual(BlacklistedTokenFamily.objects.all().count(), 1) + + ser = TokenFamilyBlacklistSerializer(data={"refresh": str(token)}) + + with self.assertRaises(TokenError) as e: + ser.validate({"refresh": str(token)}) + + self.assertIn("family", e.exception.args[0]) + self.assertIn("blacklisted", e.exception.args[0]) + + def test_raises_token_error_if_token_is_blacklisted(self): + token = RefreshToken.for_user(self.user) + token.blacklist() + + self.assertEqual(BlacklistedToken.objects.all().count(), 1) + + ser = TokenFamilyBlacklistSerializer(data={"refresh": str(token)}) + + with self.assertRaises(TokenError) as e: + ser.validate({"refresh": str(token)}) + + self.assertIn("blacklisted", e.exception.args[0]) diff --git a/tests/test_token_blacklist.py b/tests/test_token_blacklist.py index 03bada52c..d1d4fa9de 100644 --- a/tests/test_token_blacklist.py +++ b/tests/test_token_blacklist.py @@ -312,7 +312,7 @@ def setUp(self): super().setUp() @override_api_settings(BLACKLIST_AFTER_ROTATION=True) - def test_token_verify_serializer_should_honour_blacklist_if_blacklisting_enabled( + def test_token_verify_serializer_should_honour_blacklist_if_rotation_enabled( self, ): refresh_token = RefreshToken.for_user(self.user) @@ -322,14 +322,14 @@ def test_token_verify_serializer_should_honour_blacklist_if_blacklisting_enabled self.assertFalse(serializer.is_valid()) @override_api_settings(BLACKLIST_AFTER_ROTATION=False) - def test_token_verify_serializer_should_not_honour_blacklist_if_blacklisting_not_enabled( + def test_token_verify_serializer_should_honour_blacklist_if_rotation_not_enabled( self, ): refresh_token = RefreshToken.for_user(self.user) refresh_token.blacklist() serializer = TokenVerifySerializer(data={"token": str(refresh_token)}) - self.assertTrue(serializer.is_valid()) + self.assertFalse(serializer.is_valid()) class TestBigAutoFieldIDMigration(MigrationTestCase): diff --git a/tests/test_token_family.py b/tests/test_token_family.py new file mode 100644 index 000000000..8e9dd7b0b --- /dev/null +++ b/tests/test_token_family.py @@ -0,0 +1,568 @@ +from datetime import timedelta +from importlib import reload +from uuid import uuid4 + +from django.contrib.auth.models import User +from django.core.management import call_command +from django.db import IntegrityError +from django.test import TestCase +from freezegun import freeze_time + +from rest_framework_simplejwt.exceptions import TokenError +from rest_framework_simplejwt.settings import api_settings +from rest_framework_simplejwt.token_family.models import ( + BlacklistedTokenFamily, + TokenFamily, +) +from rest_framework_simplejwt.tokens import ( + AccessToken, + FamilyMixin, + RefreshToken, +) +from rest_framework_simplejwt.utils import aware_utcnow + +from .utils import override_api_settings + + +class TestTokenFamilyModels(TestCase): + def setUp(self): + self.user = User.objects.create(username="test_user", password="test_password") + + def test_manual_family_creation(self): + token_fam = TokenFamily.objects.create( + user=self.user, + family_id=uuid4().hex, + created_at=aware_utcnow(), + expires_at=aware_utcnow() + timedelta(days=1), + ) + + self.assertEqual(self.user.id, token_fam.user.id) + + def test_expires_at_is_none(self): + token_fam = TokenFamily.objects.create( + user=self.user, + family_id=uuid4().hex, + created_at=aware_utcnow(), + expires_at=None, + ) + + self.assertIsNone(token_fam.expires_at) + + def test_expires_at_is_datetime(self): + fam_id = uuid4().hex + current_time = aware_utcnow() + expiration_date = current_time + timedelta(days=1) + token_fam = TokenFamily.objects.create( + user=self.user, + family_id=fam_id, + created_at=current_time, + expires_at=expiration_date, + ) + + self.assertEqual(token_fam.expires_at, expiration_date) + + def test_family_id_cannot_be_null(self): + """ + If the code inside the 'with' block raises ValueError or IntegrityError, the test passes. + If it does NOT raise one of these exceptions, the test fails. + """ + with self.assertRaises( + (ValueError, IntegrityError) + ): # Catch either potential exception + TokenFamily.objects.create( + user=self.user, + family_id=None, # Intentionally try to set it to None + created_at=aware_utcnow(), + expires_at=aware_utcnow() + timedelta(days=1), + ) + + def test_created_at_cannot_be_null(self): + """ + If the code inside the 'with' block raises ValueError or IntegrityError, the test passes. + If it does NOT raise one of these exceptions, the test fails. + """ + with self.assertRaises( + (ValueError, IntegrityError) + ): # Catch either potential exception + TokenFamily.objects.create( + user=self.user, + family_id=uuid4().hex, + created_at=None, + expires_at=aware_utcnow() + timedelta(days=1), + ) + + def test_family_id_must_be_unique(self): + fam_id = uuid4().hex + TokenFamily.objects.create( + user=self.user, + family_id=fam_id, + created_at=aware_utcnow(), + expires_at=aware_utcnow() + timedelta(days=1), + ) + + with self.assertRaises( + (ValueError, IntegrityError) + ): # Catch either potential exception + TokenFamily.objects.create( + user=self.user, + family_id=fam_id, + created_at=aware_utcnow(), + expires_at=aware_utcnow() + timedelta(days=1), + ) + + def test_user_can_have_multiple_families(self): + num_families_to_create = 3 + + # loop to create multiple TokenFamily instances for the same user + for i in range(num_families_to_create): + TokenFamily.objects.create( + user=self.user, + family_id=uuid4().hex, + created_at=aware_utcnow(), + expires_at=aware_utcnow() + timedelta(days=(i + 1)), + ) + + self.assertEqual( + TokenFamily.objects.filter(user=self.user).count(), num_families_to_create + ) + + def test_add_family_to_blacklist(self): + token_fam = TokenFamily.objects.create( + user=self.user, + family_id=uuid4().hex, + created_at=aware_utcnow(), + expires_at=aware_utcnow() + timedelta(days=1), + ) + + blacklisted_fam = BlacklistedTokenFamily.objects.create(family=token_fam) + + self.assertEqual(BlacklistedTokenFamily.objects.all().count(), 1) + self.assertEqual(token_fam.id, blacklisted_fam.family.id) + + def test_duplicate_blacklisted_token_family_fails(self): + """ + Ensures that attempting to create a second BlacklistedTokenFamily + for the same TokenFamily instance raises an IntegrityError or ValueError. + """ + token_fam = TokenFamily.objects.create( + user=self.user, + family_id=uuid4().hex, + created_at=aware_utcnow(), + expires_at=aware_utcnow() + timedelta(days=7), + ) + + BlacklistedTokenFamily.objects.create(family=token_fam) + + with self.assertRaises((IntegrityError, ValueError)): + BlacklistedTokenFamily.objects.create(family=token_fam) + + def test_family_delete_also_removes_the_blacklist_element(self): + token_fam = TokenFamily.objects.create( + user=self.user, + family_id=uuid4().hex, + created_at=aware_utcnow(), + expires_at=aware_utcnow() + timedelta(days=1), + ) + + BlacklistedTokenFamily.objects.create(family=token_fam) + + self.assertEqual(BlacklistedTokenFamily.objects.all().count(), 1) + + token_fam.delete() + + self.assertEqual(TokenFamily.objects.all().count(), 0) + self.assertEqual(BlacklistedTokenFamily.objects.all().count(), 0) + + +class TestTokenFamilyInRefreshTokens(TestCase): + def setUp(self): + self.user = User.objects.create_user( + username="test_user", password="test_password" + ) + + def test_get_family_attributes_from_token(self): + token = RefreshToken.for_user(self.user) + + self.assertIsNotNone(token.get_family_id()) + self.assertIsNotNone(token.get_family_expiration_date()) + + def test_token_family_id_payload_value_must_be_a_string(self): + token = RefreshToken.for_user(self.user) + + self.assertIsInstance(token[api_settings.TOKEN_FAMILY_CLAIM], str) + + def test_token_family_expiraton_payload_value_must_be_an_integer(self): + token = RefreshToken.for_user(self.user) + + self.assertIsInstance(token[api_settings.TOKEN_FAMILY_EXPIRATION_CLAIM], int) + + @override_api_settings(TOKEN_FAMILY_LIFETIME=None) + def test_token_family_expiration_can_be_none(self): + token = RefreshToken.for_user(self.user) + + self.assertIsNone(token.get_family_expiration_date()) + self.assertIsNone(token.get(api_settings.TOKEN_FAMILY_EXPIRATION_CLAIM)) + + def test_families_are_added_to_the_family_list(self): + token = RefreshToken.for_user(self.user) + qs = TokenFamily.objects.all() + token_family_obj = qs.first() + + self.assertEqual(qs.count(), 1) + self.assertEqual(token_family_obj.user, self.user) + self.assertEqual(token_family_obj.family_id, token.get_family_id()) + self.assertEqual(token_family_obj.created_at, token.current_time) + self.assertEqual( + token_family_obj.expires_at.replace(microsecond=0), + token.get_family_expiration_date().replace(microsecond=0), + ) + + def test_refresh_token_will_not_validate_if_family_is_blacklisted(self): + token = RefreshToken.for_user(self.user) + token_family_obj = TokenFamily.objects.first() + + # Should raise no exception + RefreshToken(str(token)) + + # Add family to blacklist + BlacklistedTokenFamily.objects.create(family=token_family_obj) + + with self.assertRaises(TokenError) as cm: + # Should raise exception + RefreshToken(str(token)) + + self.assertIn("family", cm.exception.args[0]) + self.assertIn("blacklisted", cm.exception.args[0]) + + @override_api_settings(TOKEN_FAMILY_LIFETIME=timedelta(minutes=30)) + def test_refresh_token_will_not_validate_if_family_is_expired(self): + with freeze_time("2025-04-28 10:00:00"): + token = RefreshToken.for_user(self.user) + + with freeze_time("2025-04-28 10:30:01"): # move time forward 30 mins + 1 sec + with self.assertRaises(TokenError) as cm: + RefreshToken(str(token)) + + self.assertIn("family", cm.exception.args[0]) + self.assertIn("expired", cm.exception.args[0]) + + def test_token_family_can_be_manually_blacklisted(self): + token = RefreshToken.for_user(self.user) + + # Should raise no exception + RefreshToken(str(token)) + + self.assertEqual(TokenFamily.objects.count(), 1) + + # Add family to blacklist + blacklisted_fam, created = token.blacklist_family() + + # Should not add family to tokenfamily list if already present + self.assertEqual(TokenFamily.objects.count(), 1) + + # Should return blacklist record + self.assertTrue(created) + self.assertEqual(blacklisted_fam.family.family_id, token.get_family_id()) + + with self.assertRaises(TokenError) as cm: + # Should raise exception + RefreshToken(str(token)) + + self.assertIn("blacklisted", cm.exception.args[0]) + + # This test checks that an error is raised when trying to blacklist a token + # that does not have a family ID. Tokens created without the `for_user` + # method do not include family-related attributes, which are provided + # by the `FamilyMixin`. Since the token does not have a family, + # the `blacklist_family()` method should raise a `TokenError`. + # + # Additionally, there is no point in creating a family to blacklist an + # isolated token, as the token is not part of any family, and creating + # a family for it will not affect the tokens that are related to it, + # since those other toknes will still have no family attribute. + new_token = RefreshToken() + with self.assertRaises(TokenError) as cm: + # Should raise exception + new_token.blacklist_family() + + self.assertIn("no family", cm.exception.args[0]) + + # Count should still be 1 since the previous token could not be added + self.assertEqual(TokenFamily.objects.count(), 1) + + def test_family_attributes_do_not_change_for_new_tokens(self): + token_1 = RefreshToken.for_user(self.user) + + # we genereate a new refresh token from token_1 + token_2 = RefreshToken(token=str(token_1)) + + self.assertEqual( + token_1[api_settings.TOKEN_FAMILY_CLAIM], + token_2[api_settings.TOKEN_FAMILY_CLAIM], + ) + self.assertEqual( + token_1[api_settings.TOKEN_FAMILY_EXPIRATION_CLAIM], + token_2[api_settings.TOKEN_FAMILY_EXPIRATION_CLAIM], + ) + + @override_api_settings(TOKEN_FAMILY_CLAIM="fam_test_claim") + def test_token_family_with_modify_family_claim(self): + token = RefreshToken.for_user(self.user) + + self.assertEqual(api_settings.TOKEN_FAMILY_CLAIM, "fam_test_claim") + self.assertIn("fam_test_claim", token.payload) + + @override_api_settings(TOKEN_FAMILY_EXPIRATION_CLAIM="fam_test_exp_claim") + def test_token_family_with_modify_family_expiration_claim(self): + token = RefreshToken.for_user(self.user) + + self.assertEqual( + api_settings.TOKEN_FAMILY_EXPIRATION_CLAIM, "fam_test_exp_claim" + ) + self.assertIn("fam_test_exp_claim", token.payload) + + @override_api_settings(TOKEN_FAMILY_LIFETIME=timedelta(minutes=30)) + def test_check_family_expiration_method(self): + token: RefreshToken + with freeze_time("2025-04-28 10:00:00"): + token = RefreshToken.for_user(self.user) + + with freeze_time("2025-04-28 10:30:01"): # move time forward 30 mins + 1 sec + with self.assertRaises(TokenError) as cm: + # Call static method directly from the mixin + FamilyMixin.check_family_expiration(token, aware_utcnow()) + + self.assertIn("family", cm.exception.args[0]) + self.assertIn("expired", cm.exception.args[0]) + + # Call via the instance (still a staticmethod, so same signature) + with self.assertRaises(TokenError) as cm: + token.check_family_expiration(token, aware_utcnow()) + + self.assertIn("family", cm.exception.args[0]) + self.assertIn("expired", cm.exception.args[0]) + + def test_family_blacklist_method(self): + token = RefreshToken.for_user(self.user) + token.blacklist_family() + + try: + BlacklistedTokenFamily.objects.get(family__family_id=token.get_family_id()) + except BlacklistedTokenFamily.DoesNotExist: + self.fail( + f"Expected BlacklistedTokenFamily object with family_family_id='{token.get_family_id()}' " + f"to exist after blacklisting the token, but it was not found in the database." + ) + + def test_check_family_blacklist_method(self): + token = RefreshToken.for_user(self.user) + token.blacklist_family() + + # Call static method directly from the mixin + with self.assertRaises(TokenError) as cm: + FamilyMixin.check_family_blacklist(token) + + self.assertIn("family", cm.exception.args[0]) + self.assertIn("blacklisted", cm.exception.args[0]) + + # Call via the instance (still a staticmethod, so same signature) + with self.assertRaises(TokenError) as cm: + token.check_family_blacklist(token) + + self.assertIn("family", cm.exception.args[0]) + self.assertIn("blacklisted", cm.exception.args[0]) + + def test_user_can_have_multiple_families(self): + amount_of_tokens = 5 + tokens = [RefreshToken.for_user(self.user) for _ in range(amount_of_tokens)] + + family_count = TokenFamily.objects.filter(user=self.user).count() + self.assertEqual(family_count, amount_of_tokens) + + def test_family_id_uniqueness(self): + """ + Tests that creating multiple separate tokens for the same user + results in each token having a unique family id. + """ + amount_of_tokens = 5 + families_ids = [ + RefreshToken.for_user(self.user).get_family_id() + for _ in range(amount_of_tokens) + ] + + # Converts the list of IDs to a set (which automatically removes duplicates) + families_ids_set = set(families_ids) + + self.assertEqual( + len(families_ids), + len(families_ids_set), + f"Expected {amount_of_tokens} unique family IDs, but found duplicates. IDs collected: {families_ids}", + ) + + def test_family_features_not_accessible_when_disabled(self): + from rest_framework_simplejwt import tokens + + with override_api_settings(TOKEN_FAMILY_ENABLED=False): + reload(tokens) + token = tokens.RefreshToken.for_user(self.user) + + with self.assertRaises(AttributeError): + token.get_family_id() + + # reloading to that the family mixin can have access to his methods + with override_api_settings(TOKEN_FAMILY_ENABLED=True): + reload(tokens) + + +class TestTokenFamilyInAccessTokens(TestCase): + def setUp(self): + self.user = User.objects.create_user( + username="test_user", password="test_password" + ) + + @override_api_settings(TOKEN_FAMILY_CHECK_ON_ACCESS=True) + def test_access_token_payload_can_include_family_claims(self): + refresh = RefreshToken.for_user(self.user) + access = refresh.access_token + + self.assertIn(api_settings.TOKEN_FAMILY_CLAIM, access.payload) + self.assertIn(api_settings.TOKEN_FAMILY_EXPIRATION_CLAIM, access.payload) + + @override_api_settings(TOKEN_FAMILY_CHECK_ON_ACCESS=False) + def test_access_token_payload_can_exclude_family_claims(self): + refresh = RefreshToken.for_user(self.user) + access = refresh.access_token + + self.assertNotIn(api_settings.TOKEN_FAMILY_CLAIM, access.payload) + self.assertNotIn(api_settings.TOKEN_FAMILY_EXPIRATION_CLAIM, access.payload) + + @override_api_settings(TOKEN_FAMILY_CHECK_ON_ACCESS=True) + def test_access_token_will_not_validate_if_family_is_blacklisted(self): + refresh = RefreshToken.for_user(self.user) + access = refresh.access_token + refresh.blacklist_family() + + with self.assertRaises(TokenError) as cm: + AccessToken(str(access)) + + self.assertIn("family", cm.exception.args[0]) + self.assertIn("blacklisted", cm.exception.args[0]) + + @override_api_settings( + TOKEN_FAMILY_CHECK_ON_ACCESS=True, + TOKEN_FAMILY_LIFETIME=timedelta(minutes=30), + ) + def test_access_token_will_not_validate_if_family_is_expired(self): + from rest_framework_simplejwt import tokens + + with override_api_settings( + ACCESS_TOKEN_LIFETIME=timedelta(minutes=35), + ): + reload(tokens) + + with freeze_time("2025-04-28 10:00:00"): + refresh = tokens.RefreshToken.for_user(self.user) + access = refresh.access_token + + with freeze_time("2025-04-28 10:30:01"): # move time forward 30 mins + 1 sec + with self.assertRaises(TokenError) as cm: + tokens.AccessToken(str(access)) + + self.assertIn("family", cm.exception.args[0]) + self.assertIn("expired", cm.exception.args[0]) + + # relaod module back to default aceess lifetime + with override_api_settings( + ACCESS_TOKEN_LIFETIME=api_settings.ACCESS_TOKEN_LIFETIME, + ): + reload(tokens) + + @override_api_settings(TOKEN_FAMILY_CHECK_ON_ACCESS=True) + def test_access_token_verifies_family_blacklist(self): + refresh = RefreshToken.for_user(self.user) + access = refresh.access_token + refresh.blacklist_family() + + with self.assertRaises(TokenError) as cm: + access.verify() + + self.assertIn("family", cm.exception.args[0]) + self.assertIn("blacklisted", cm.exception.args[0]) + + @override_api_settings( + TOKEN_FAMILY_CHECK_ON_ACCESS=True, TOKEN_FAMILY_LIFETIME=timedelta(minutes=30) + ) + def test_access_token_varifies_family_expiration(self): + access: AccessToken + + with freeze_time("2025-04-28 10:00:00"): + refresh = RefreshToken.for_user(self.user) + access = refresh.access_token + + with freeze_time("2025-04-28 10:30:01"): # move time forward 30 mins + 1 sec + with self.assertRaises(TokenError) as cm: + access.verify() + + self.assertIn("family", cm.exception.args[0]) + self.assertIn("expired", cm.exception.args[0]) + + +class TestFlushExpiredTokenFamiliesCommand(TestCase): + def setUp(self): + self.user = User.objects.create_user( + username="test_user", password="test_password" + ) + + @override_api_settings(TOKEN_FAMILY_LIFETIME=timedelta(minutes=30)) + def test_it_should_delete_any_expire_families(self): + amount_of_valid_families = 5 + amount_of_expired_families = 3 + amount_not_expired_blacklisted_families = 1 + amount_expired_blacklisted_families = 2 + + valid_family_tokens: list[RefreshToken] = [ + RefreshToken.for_user(self.user) for _ in range(amount_of_valid_families) + ] + + for i in range(amount_not_expired_blacklisted_families): + valid_family_tokens[i].blacklist_family() + + expired_family_tokens: list[RefreshToken] + with freeze_time(aware_utcnow() - timedelta(minutes=31)): + expired_family_tokens = [ + RefreshToken.for_user(self.user) + for _ in range(amount_of_expired_families) + ] + + for i in range(amount_expired_blacklisted_families): + expired_family_tokens[i].blacklist_family() + + self.assertEqual( + TokenFamily.objects.all().count(), + amount_of_expired_families + amount_of_valid_families, + ) + self.assertEqual( + BlacklistedTokenFamily.objects.all().count(), + amount_expired_blacklisted_families + + amount_not_expired_blacklisted_families, + ) + + call_command("flushexpiredfamilies") + + self.assertEqual(TokenFamily.objects.all().count(), amount_of_valid_families) + self.assertEqual( + BlacklistedTokenFamily.objects.all().count(), + amount_not_expired_blacklisted_families, + ) + + def test_family_will_not_remove_on_User_delete(self): + token: RefreshToken = RefreshToken.for_user(self.user) + token.blacklist_family() + + self.assertEqual(TokenFamily.objects.all().first().user, self.user) + + self.user.delete() + + self.assertEqual(TokenFamily.objects.all().count(), 1) + self.assertEqual(BlacklistedTokenFamily.objects.all().count(), 1) + self.assertIsNone(TokenFamily.objects.all().first().user) diff --git a/tests/test_views.py b/tests/test_views.py index 0f67f81b0..a158c7c9f 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -1,9 +1,11 @@ from datetime import timedelta +from importlib import reload from unittest import mock from unittest.mock import patch from django.contrib.auth import get_user_model from django.utils import timezone +from freezegun import freeze_time from rest_framework.test import APIRequestFactory from rest_framework_simplejwt import serializers @@ -178,6 +180,34 @@ def test_it_should_return_access_token_if_everything_ok(self): access["exp"], datetime_to_epoch(now + api_settings.ACCESS_TOKEN_LIFETIME) ) + def test_it_should_return_401_if_family_is_blacklisted(self): + token = RefreshToken.for_user(self.user) + res = self.view_post(data={"refresh": str(token)}) + + self.assertEqual(res.status_code, 200) + + token.blacklist_family() + + res = self.view_post(data={"refresh": str(token)}) + + self.assertEqual(res.status_code, 401) + self.assertIn("family", res.data.get("detail")) + self.assertIn("blacklisted", res.data.get("detail")) + + @override_api_settings(TOKEN_FAMILY_LIFETIME=timedelta(minutes=30)) + def test_it_should_return_401_if_family_is_expired(self): + with freeze_time(aware_utcnow() - timedelta(minutes=31)): + token = RefreshToken.for_user(self.user) + res = self.view_post(data={"refresh": str(token)}) + + self.assertEqual(res.status_code, 200) + + res = self.view_post(data={"refresh": str(token)}) + + self.assertEqual(res.status_code, 401) + self.assertIn("family", res.data.get("detail")) + self.assertIn("expired", res.data.get("detail")) + class TestTokenObtainSlidingView(APIViewTestCase): view_name = "token_obtain_sliding" @@ -380,6 +410,34 @@ def test_it_should_ignore_token_type(self): self.assertEqual(res.status_code, 200) self.assertEqual(len(res.data), 0) + def test_it_should_return_401_if_family_is_blacklisted(self): + token = RefreshToken.for_user(self.user) + res = self.view_post(data={"token": str(token)}) + + self.assertEqual(res.status_code, 200) + + token.blacklist_family() + + res = self.view_post(data={"token": str(token)}) + + self.assertEqual(res.status_code, 401) + self.assertIn("family", res.data.get("detail")) + self.assertIn("blacklisted", res.data.get("detail")) + + @override_api_settings(TOKEN_FAMILY_LIFETIME=timedelta(minutes=30)) + def test_it_should_return_401_if_family_is_expired(self): + with freeze_time(aware_utcnow() - timedelta(minutes=31)): + token = RefreshToken.for_user(self.user) + res = self.view_post(data={"token": str(token)}) + + self.assertEqual(res.status_code, 200) + + res = self.view_post(data={"token": str(token)}) + + self.assertEqual(res.status_code, 401) + self.assertIn("family", res.data.get("detail")) + self.assertIn("expired", res.data.get("detail")) + class TestTokenBlacklistView(APIViewTestCase): view_name = "token_blacklist" @@ -449,6 +507,101 @@ def test_it_should_return_401_if_token_is_blacklisted(self): self.assertEqual(res.status_code, 401) + def test_it_should_return_401_if_family_is_blacklisted(self): + token = RefreshToken.for_user(self.user) + token.blacklist_family() + + res = self.view_post(data={"refresh": str(token)}) + + self.assertEqual(res.status_code, 401) + self.assertIn("family", res.data.get("detail")) + self.assertIn("blacklisted", res.data.get("detail")) + + @override_api_settings(TOKEN_FAMILY_LIFETIME=timedelta(minutes=30)) + def test_it_should_return_401_if_family_is_expired(self): + with freeze_time(aware_utcnow() - timedelta(minutes=31)): + token = RefreshToken.for_user(self.user) + + res = self.view_post(data={"refresh": str(token)}) + + self.assertEqual(res.status_code, 401) + self.assertIn("family", res.data.get("detail")) + self.assertIn("expired", res.data.get("detail")) + + +class TestTokenFamilyBlacklistView(APIViewTestCase): + view_name = "token_family_blacklist" + + def setUp(self): + self.username = "test_user" + self.password = "test_password" + + self.user = User.objects.create_user( + username=self.username, + password=self.password, + ) + + def test_fields_missing(self): + res = self.view_post(data={}) + self.assertEqual(res.status_code, 400) + self.assertIn("refresh", res.data) + + def test_it_should_return_401_if_token_invalid(self): + token = RefreshToken() + del token["exp"] + + res = self.view_post(data={"refresh": str(token)}) + self.assertEqual(res.status_code, 401) + self.assertEqual(res.data["code"], "token_not_valid") + + token.set_exp(lifetime=-timedelta(seconds=1)) + + res = self.view_post(data={"refresh": str(token)}) + self.assertEqual(res.status_code, 401) + self.assertEqual(res.data["code"], "token_not_valid") + + def test_it_should_return_200_if_everything_ok(self): + token = RefreshToken.for_user(self.user) + res = self.view_post(data={"refresh": str(token)}) + + self.assertEqual(res.status_code, 200) + + def test_it_should_return_401_if_token_is_blacklisted(self): + token = RefreshToken() + token.blacklist() + + res = self.view_post(data={"refresh": str(token)}) + + self.assertEqual(res.status_code, 401) + self.assertIn("Token", res.data.get("detail")) + self.assertIn("blacklisted", res.data.get("detail")) + + def test_it_should_return_401_if_family_is_blacklisted(self): + token = RefreshToken.for_user(self.user) + + # here the family gets blacklisted + res = self.view_post(data={"refresh": str(token)}) + + self.assertEqual(res.status_code, 200) + + # here we should get the 401 + res = self.view_post(data={"refresh": str(token)}) + + self.assertEqual(res.status_code, 401) + self.assertIn("family", res.data.get("detail")) + self.assertIn("blacklisted", res.data.get("detail")) + + @override_api_settings(TOKEN_FAMILY_LIFETIME=timedelta(minutes=30)) + def test_it_should_return_401_if_family_is_expired(self): + with freeze_time(aware_utcnow() - timedelta(minutes=31)): + token = RefreshToken.for_user(self.user) + + res = self.view_post(data={"refresh": str(token)}) + + self.assertEqual(res.status_code, 401) + self.assertIn("family", res.data.get("detail")) + self.assertIn("expired", res.data.get("detail")) + class TestCustomTokenView(APIViewTestCase): def test_custom_view_class(self): diff --git a/tests/urls.py b/tests/urls.py index dc8edcac2..0e5e50953 100644 --- a/tests/urls.py +++ b/tests/urls.py @@ -17,5 +17,10 @@ ), re_path(r"^token/verify/$", jwt_views.token_verify, name="token_verify"), re_path(r"^token/blacklist/$", jwt_views.token_blacklist, name="token_blacklist"), + re_path( + r"^token/family/blacklist/$", + jwt_views.token_family_blacklist, + name="token_family_blacklist", + ), re_path(r"^test-view/$", views.test_view, name="test_view"), ]