From 99da6b802fee7cab84c4320957132e80364c5d2f Mon Sep 17 00:00:00 2001 From: Pavel Kirilin Date: Tue, 3 May 2022 13:02:17 +0400 Subject: [PATCH 1/2] Added options to configure token type claim. Signed-off-by: Pavel Kirilin --- docs/configuration/general.md | 13 +++ fastapi_jwt_auth/auth_config.py | 10 +++ fastapi_jwt_auth/auth_jwt.py | 43 +++++++--- fastapi_jwt_auth/config.py | 11 +++ tests/test_config.py | 29 +++++++ tests/test_token_types.py | 145 ++++++++++++++++++++++++++++++++ 6 files changed, 237 insertions(+), 14 deletions(-) create mode 100644 tests/test_token_types.py diff --git a/docs/configuration/general.md b/docs/configuration/general.md index c27ccc3..32c6215 100644 --- a/docs/configuration/general.md +++ b/docs/configuration/general.md @@ -41,3 +41,16 @@ `authjwt_refresh_token_expires` : How long an refresh token should live before it expires. This takes value `integer` *(seconds)* or `datetime.timedelta`, and defaults to **30 days**. Can be set to `False` to disable expiration. + +`authjwt_token_type_claim` +: Wether you want to add claim that identidies type of a token. Setting it to `False` is not recommended, + but it's useful for integration with other systems that use jwt. Defaults to `True` + +`authjwt_access_token_type` +: String that will be placed in type claim to identify if token is an access token. Defaults to `access` + +`authjwt_refresh_token_type` +: String that will be placed in type claim to identify if token is a refresh token. Defaults to `refresh` + +`authjwt_token_type_claim_name` +: Name of claim where type of a token is located. Defaults to `type` \ No newline at end of file diff --git a/fastapi_jwt_auth/auth_config.py b/fastapi_jwt_auth/auth_config.py index b259f2e..af11775 100644 --- a/fastapi_jwt_auth/auth_config.py +++ b/fastapi_jwt_auth/auth_config.py @@ -44,6 +44,12 @@ class AuthConfig: _refresh_csrf_header_name = "X-CSRF-Token" _csrf_methods = {'POST','PUT','PATCH','DELETE'} + # options to adjust token's type claim + _token_type_claim = True + _access_token_type = "access" + _refresh_token_type = "refresh" + _token_type_claim_name = "type" + @property def jwt_in_cookies(self) -> bool: return 'cookies' in self._token_location @@ -91,6 +97,10 @@ def load_config(cls, settings: Callable[...,List[tuple]]) -> "AuthConfig": cls._access_csrf_header_name = config.authjwt_access_csrf_header_name cls._refresh_csrf_header_name = config.authjwt_refresh_csrf_header_name cls._csrf_methods = config.authjwt_csrf_methods + cls._token_type_claim = config.authjwt_token_type_claim + cls._access_token_type = config.authjwt_access_token_type + cls._refresh_token_type = config.authjwt_refresh_token_type + cls._token_type_claim_name = config.authjwt_token_type_claim_name except ValidationError: raise except Exception: diff --git a/fastapi_jwt_auth/auth_jwt.py b/fastapi_jwt_auth/auth_jwt.py index 4110bdb..f6a8b1f 100644 --- a/fastapi_jwt_auth/auth_jwt.py +++ b/fastapi_jwt_auth/auth_jwt.py @@ -162,8 +162,15 @@ def _create_token( "nbf": self._get_int_from_datetime(datetime.now(timezone.utc)), "jti": self._get_jwt_identifier() } + token_types = { + 'access': self._access_token_type, + 'refresh': self._refresh_token_type, + } - custom_claims = {"type": type_token} + if self._token_type_claim: + custom_claims = {self._token_type_claim_name: token_types[type_token]} + else: + custom_claims = {} # for access_token only fresh needed if type_token == 'access': @@ -579,10 +586,12 @@ def _verify_jwt_optional_in_request(self,token: str) -> None: :param token: The encoded JWT """ - if token: self._verifying_token(token) + if token: + self._verifying_token(token) + if self._token_type_claim: + if self.get_raw_jwt(token)[self._token_type_claim_name] != self._access_token_type: + raise AccessTokenRequired(status_code=422,message="Only access tokens are allowed") - if token and self.get_raw_jwt(token)['type'] != 'access': - raise AccessTokenRequired(status_code=422,message="Only access tokens are allowed") def _verify_jwt_in_request( self, @@ -613,15 +622,20 @@ def _verify_jwt_in_request( # verify jwt issuer = self._decode_issuer if type_token == 'access' else None self._verifying_token(token,issuer) + raw_jwt = self.get_raw_jwt(token) + if self._token_type_claim: + token_types = { + "access": self._access_token_type, + "refresh": self._refresh_token_type, + } + if raw_jwt[self._token_type_claim_name] != token_types[type_token]: + msg = "Only {} tokens are allowed".format(token_types[type_token]) + if type_token == 'access': + raise AccessTokenRequired(status_code=422,message=msg) + if type_token == 'refresh': + raise RefreshTokenRequired(status_code=422,message=msg) - if self.get_raw_jwt(token)['type'] != type_token: - msg = "Only {} tokens are allowed".format(type_token) - if type_token == 'access': - raise AccessTokenRequired(status_code=422,message=msg) - if type_token == 'refresh': - raise RefreshTokenRequired(status_code=422,message=msg) - - if fresh and not self.get_raw_jwt(token)['fresh']: + if fresh and not raw_jwt['fresh']: raise FreshTokenRequired(status_code=401,message="Fresh token required") def _verifying_token(self,encoded_token: str, issuer: Optional[str] = None) -> None: @@ -632,8 +646,9 @@ def _verifying_token(self,encoded_token: str, issuer: Optional[str] = None) -> N :param issuer: expected issuer in the JWT """ raw_token = self._verified_token(encoded_token,issuer) - if raw_token['type'] in self._denylist_token_checks: - self._check_token_is_revoked(raw_token) + if self._token_type_claim: + if raw_token[self._token_type_claim_name] in self._denylist_token_checks: + self._check_token_is_revoked(raw_token) def _verified_token(self,encoded_token: str, issuer: Optional[str] = None) -> Dict[str,Union[str,int,bool]]: """ diff --git a/fastapi_jwt_auth/config.py b/fastapi_jwt_auth/config.py index c81b50c..d070ed9 100644 --- a/fastapi_jwt_auth/config.py +++ b/fastapi_jwt_auth/config.py @@ -43,6 +43,11 @@ class LoadConfig(BaseModel): authjwt_access_csrf_header_name: Optional[StrictStr] = "X-CSRF-Token" authjwt_refresh_csrf_header_name: Optional[StrictStr] = "X-CSRF-Token" authjwt_csrf_methods: Optional[Sequence[StrictStr]] = {'POST','PUT','PATCH','DELETE'} + # options to adjust token's type claim + authjwt_token_type_claim: Optional[StrictBool] = True + authjwt_access_token_type: Optional[StrictStr] = "access" + authjwt_refresh_token_type: Optional[StrictStr] = "refresh" + authjwt_token_type_claim_name: Optional[StrictStr] = "type" @validator('authjwt_access_token_expires') def validate_access_token_expires(cls, v): @@ -80,6 +85,12 @@ def validate_csrf_methods(cls, v): raise ValueError("The 'authjwt_csrf_methods' must be between http request methods") return v.upper() + @validator('authjwt_token_type_claim_name') + def validate_token_type_claim_name(cls, v): + if v.lower() in {'iss', 'sub', 'aud', 'exp', 'nbf', 'iat', 'jti'}: + raise ValueError("The 'authjwt_token_type_claim_name' can not override default JWT claims") + return v + class Config: min_anystr_length = 1 anystr_strip_whitespace = True diff --git a/tests/test_config.py b/tests/test_config.py index 05d2a85..f25dc7f 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -58,6 +58,10 @@ def test_default_config(): assert AuthJWT._access_csrf_header_name == "X-CSRF-Token" assert AuthJWT._refresh_csrf_header_name == "X-CSRF-Token" assert AuthJWT._csrf_methods == {'POST','PUT','PATCH','DELETE'} + assert AuthJWT._token_type_claim == True + assert AuthJWT._access_token_type == "access" + assert AuthJWT._refresh_token_type == "refresh" + assert AuthJWT._token_type_claim_name == "type" def test_token_expired_false(Authorize): class TokenFalse(BaseSettings): @@ -165,6 +169,11 @@ class Settings(BaseSettings): authjwt_access_csrf_header_name: str = "ACCESS-CSRF-Token" authjwt_refresh_csrf_header_name: str = "REFRESH-CSRF-Token" authjwt_csrf_methods: list = ['post'] + # options to adjust token's type claim + authjwt_token_type_claim: bool = True + authjwt_access_token_type: str = "access" + authjwt_refresh_token_type: str = "refresh" + authjwt_token_type_claim_name: str = "type" @AuthJWT.load_config def get_valid_settings(): @@ -204,6 +213,11 @@ def get_valid_settings(): assert AuthJWT._access_csrf_header_name == "ACCESS-CSRF-Token" assert AuthJWT._refresh_csrf_header_name == "REFRESH-CSRF-Token" assert AuthJWT._csrf_methods == ['POST'] + # options to adjust token's type claim + assert AuthJWT._token_type_claim + assert AuthJWT._access_token_type == "access" + assert AuthJWT._refresh_token_type == "refresh" + assert AuthJWT._token_type_claim_name == "type" with pytest.raises(TypeError,match=r"Config"): @AuthJWT.load_config @@ -401,3 +415,18 @@ def get_invalid_csrf_methods(): @AuthJWT.load_config def get_invalid_csrf_methods_value(): return [("authjwt_csrf_methods",['posts'])] + + with pytest.raises(ValidationError,match=r"authjwt_token_type_claim_name"): + @AuthJWT.load_config + def get_invalid_token_type_claim_name(): + return [("authjwt_token_type_claim_name", 'exp')] + + with pytest.raises(ValidationError,match=r"authjwt_token_type_claim_name"): + @AuthJWT.load_config + def get_invalid_token_type_claim_name(): + return [("authjwt_token_type_claim_name", 'iss')] + + with pytest.raises(ValidationError,match=r"authjwt_token_type_claim_name"): + @AuthJWT.load_config + def get_invalid_token_type_claim_name(): + return [("authjwt_token_type_claim_name", 'sub')] diff --git a/tests/test_token_types.py b/tests/test_token_types.py new file mode 100644 index 0000000..eb1413f --- /dev/null +++ b/tests/test_token_types.py @@ -0,0 +1,145 @@ +import jwt +import pytest +from fastapi import Depends, FastAPI, Request +from fastapi.responses import JSONResponse +from fastapi.testclient import TestClient +from pydantic import BaseSettings + +from fastapi_jwt_auth import AuthJWT +from fastapi_jwt_auth.exceptions import AuthJWTException + + +@pytest.fixture(scope="function") +def client() -> TestClient: + app = FastAPI() + + @app.exception_handler(AuthJWTException) + def authjwt_exception_handler(request: Request, exc: AuthJWTException): + return JSONResponse( + status_code=exc.status_code, content={"detail": exc.message} + ) + + @app.get("/protected") + def protected(Authorize: AuthJWT = Depends()): + Authorize.jwt_required() + return {"hello": "world"} + + @app.get("/semi_protected") + def protected(Authorize: AuthJWT = Depends()): + Authorize.jwt_optional() + return {"hello": "world"} + + @app.get("/refresh") + def refresher(Authorize: AuthJWT = Depends()): + Authorize.jwt_refresh_token_required() + return {"hello": "world"} + + client = TestClient(app) + return client + + +def test_custom_token_type_claim_validation( + client: TestClient, Authorize: AuthJWT +) -> None: + class TestConfig(BaseSettings): + authjwt_secret_key: str = "secret" + authjwt_token_type_claim_name: str = "custom_type" + + @AuthJWT.load_config + def test_config(): + return TestConfig() + + # Checking that created token has custom type claim + access = Authorize.create_access_token(subject="test") + assert jwt.decode(access, key="secret", algorithms=['HS256'])["custom_type"] == "access" + + # Checking that protected endpoint validates token correctly + response = client.get("/protected", headers={"Authorization": f"Bearer {access}"}) + assert response.status_code == 200 + assert response.json() == {"hello": "world"} + + # Checking that endpoint with optional protection validates token with + # custom type claim correctly. + response = client.get("/semi_protected", headers={"Authorization": f"Bearer {access}"}) + assert response.status_code == 200 + assert response.json() == {"hello": "world"} + + # Creating refresh token and checking if it has correct + # type claim. + refresh = Authorize.create_refresh_token(subject="test") + assert jwt.decode(refresh, key="secret", algorithms=['HS256'])["custom_type"] == "refresh" + + # Checking that refreshing with custom claim works. + response = client.get("/refresh", headers={"Authorization": f"Bearer {refresh}"}) + assert response.status_code == 200 + assert response.json() == {"hello": "world"} + + + +def test_custom_token_type_names_validation( + client: TestClient, Authorize: AuthJWT +) -> None: + class TestConfig(BaseSettings): + authjwt_secret_key: str = "secret" + authjwt_refresh_token_type: str = "refresh_custom" + authjwt_access_token_type: str = "access_custom" + + @AuthJWT.load_config + def test_config(): + return TestConfig() + + # Creating access token and checking that + # it has custom type + access = Authorize.create_access_token(subject="test") + assert jwt.decode(access, key="secret", algorithms=['HS256'])["type"] == "access_custom" + + # Checking that validation for custom type works as expected. + response = client.get("/protected", headers={"Authorization": f"Bearer {access}"}) + assert response.status_code == 200 + assert response.json() == {"hello": "world"} + + response = client.get("/semi_protected", headers={"Authorization": f"Bearer {access}"}) + assert response.status_code == 200 + assert response.json() == {"hello": "world"} + + # Creating refresh token and checking if it has correct type claim. + refresh = Authorize.create_refresh_token(subject="test") + assert jwt.decode(refresh, key="secret", algorithms=['HS256'])["type"] == "refresh_custom" + + # Checking that refreshing with custom type works. + response = client.get("/refresh", headers={"Authorization": f"Bearer {refresh}"}) + assert response.status_code == 200 + assert response.json() == {"hello": "world"} + + +def test_without_type_claims( + client: TestClient, Authorize: AuthJWT +) -> None: + class TestConfig(BaseSettings): + authjwt_secret_key: str = "secret" + authjwt_token_type_claim: bool = False + + @AuthJWT.load_config + def test_config(): + return TestConfig() + + # Creating access token and checking if it doesn't have type claim. + access = Authorize.create_access_token(subject="test") + assert "type" not in jwt.decode(access, key="secret", algorithms=['HS256']) + + response = client.get("/protected", headers={"Authorization": f"Bearer {access}"}) + assert response.status_code == 200 + assert response.json() == {"hello": "world"} + + response = client.get("/semi_protected", headers={"Authorization": f"Bearer {access}"}) + assert response.status_code == 200 + assert response.json() == {"hello": "world"} + + # Creating refresh token and checking if it doesn't have type claim. + refresh = Authorize.create_refresh_token(subject="test") + assert "type" not in jwt.decode(refresh, key="secret", algorithms=['HS256']) + + # Checking that refreshing without type works. + response = client.get("/refresh", headers={"Authorization": f"Bearer {refresh}"}) + assert response.status_code == 200 + assert response.json() == {"hello": "world"} From cee8d2f098b86caa81cbbfeb63a6329655ff78b5 Mon Sep 17 00:00:00 2001 From: Pavel Kirilin Date: Tue, 3 May 2022 14:06:14 +0400 Subject: [PATCH 2/2] Fix for some tests. Signed-off-by: Pavel Kirilin --- tests/conftest.py | 16 +++++++++++ tests/test_create_token.py | 23 +++++++++------- tests/test_decode_token.py | 33 ++++++++++++++++------- tests/test_token_types.py | 55 +++++++++++++++++++++----------------- 4 files changed, 84 insertions(+), 43 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 68974b7..7367bab 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,22 @@ import pytest from fastapi_jwt_auth import AuthJWT +from fastapi_jwt_auth.config import LoadConfig @pytest.fixture(scope="module") def Authorize(): return AuthJWT() + + +@pytest.fixture(autouse=True) +def reset_config(): + """ + Resets config to default to + guarantee that config is unchanged after test. + """ + yield + @AuthJWT.load_config + def default_conf(): + return LoadConfig( + authjwt_secret_key="secret", + authjwt_cookie_samesite='strict', + ) \ No newline at end of file diff --git a/tests/test_create_token.py b/tests/test_create_token.py index 7d7f27d..5df164e 100644 --- a/tests/test_create_token.py +++ b/tests/test_create_token.py @@ -3,7 +3,10 @@ from pydantic import BaseSettings from datetime import timedelta, datetime, timezone -def test_create_access_token(Authorize): + + +@pytest.fixture() +def test_settings() -> None: class Settings(BaseSettings): AUTHJWT_SECRET_KEY: str = "testing" AUTHJWT_ACCESS_TOKEN_EXPIRES: int = 2 @@ -13,6 +16,8 @@ class Settings(BaseSettings): def get_settings(): return Settings() +def test_create_access_token(Authorize, test_settings): + with pytest.raises(TypeError,match=r"missing 1 required positional argument"): Authorize.create_access_token() @@ -25,7 +30,7 @@ def get_settings(): with pytest.raises(ValueError,match=r"dictionary update sequence element"): Authorize.create_access_token(subject=1,headers="test") -def test_create_refresh_token(Authorize): +def test_create_refresh_token(Authorize, test_settings): with pytest.raises(TypeError,match=r"missing 1 required positional argument"): Authorize.create_refresh_token() @@ -35,7 +40,7 @@ def test_create_refresh_token(Authorize): with pytest.raises(ValueError,match=r"dictionary update sequence element"): Authorize.create_refresh_token(subject=1,headers="test") -def test_create_dynamic_access_token_expires(Authorize): +def test_create_dynamic_access_token_expires(Authorize, test_settings): expires_time = int(datetime.now(timezone.utc).timestamp()) + 90 token = Authorize.create_access_token(subject=1,expires_time=90) assert jwt.decode(token,"testing",algorithms="HS256")['exp'] == expires_time @@ -54,7 +59,7 @@ def test_create_dynamic_access_token_expires(Authorize): with pytest.raises(TypeError,match=r"expires_time"): Authorize.create_access_token(subject=1,expires_time="test") -def test_create_dynamic_refresh_token_expires(Authorize): +def test_create_dynamic_refresh_token_expires(Authorize, test_settings): expires_time = int(datetime.now(timezone.utc).timestamp()) + 90 token = Authorize.create_refresh_token(subject=1,expires_time=90) assert jwt.decode(token,"testing",algorithms="HS256")['exp'] == expires_time @@ -73,34 +78,34 @@ def test_create_dynamic_refresh_token_expires(Authorize): with pytest.raises(TypeError,match=r"expires_time"): Authorize.create_refresh_token(subject=1,expires_time="test") -def test_create_token_invalid_type_data_audience(Authorize): +def test_create_token_invalid_type_data_audience(Authorize, test_settings): with pytest.raises(TypeError,match=r"audience"): Authorize.create_access_token(subject=1,audience=1) with pytest.raises(TypeError,match=r"audience"): Authorize.create_refresh_token(subject=1,audience=1) -def test_create_token_invalid_algorithm(Authorize): +def test_create_token_invalid_algorithm(Authorize, test_settings): with pytest.raises(ValueError,match=r"Algorithm"): Authorize.create_access_token(subject=1,algorithm="test") with pytest.raises(ValueError,match=r"Algorithm"): Authorize.create_refresh_token(subject=1,algorithm="test") -def test_create_token_invalid_type_data_algorithm(Authorize): +def test_create_token_invalid_type_data_algorithm(Authorize, test_settings): with pytest.raises(TypeError,match=r"algorithm"): Authorize.create_access_token(subject=1,algorithm=1) with pytest.raises(TypeError,match=r"algorithm"): Authorize.create_refresh_token(subject=1,algorithm=1) -def test_create_token_invalid_user_claims(Authorize): +def test_create_token_invalid_user_claims(Authorize, test_settings): with pytest.raises(TypeError,match=r"user_claims"): Authorize.create_access_token(subject=1,user_claims="asd") with pytest.raises(TypeError,match=r"user_claims"): Authorize.create_refresh_token(subject=1,user_claims="asd") -def test_create_valid_user_claims(Authorize): +def test_create_valid_user_claims(Authorize, test_settings): access_token = Authorize.create_access_token(subject=1,user_claims={"my_access":"yeah"}) refresh_token = Authorize.create_refresh_token(subject=1,user_claims={"my_refresh":"hello"}) diff --git a/tests/test_decode_token.py b/tests/test_decode_token.py index 5344d48..d5b06eb 100644 --- a/tests/test_decode_token.py +++ b/tests/test_decode_token.py @@ -49,6 +49,19 @@ def default_access_token(): 'fresh': True, } +@pytest.fixture() +def test_settings() -> None: + class TestSettings(BaseSettings): + AUTHJWT_SECRET_KEY: str = "secret-key" + AUTHJWT_ACCESS_TOKEN_EXPIRES: int = 1 + AUTHJWT_REFRESH_TOKEN_EXPIRES: int = 1 + AUTHJWT_DECODE_LEEWAY: int = 2 + + @AuthJWT.load_config + def load(): + return TestSettings() + + @pytest.fixture(scope='function') def encoded_token(default_access_token): return jwt.encode(default_access_token,'secret-key',algorithm='HS256').decode('utf-8') @@ -111,23 +124,23 @@ def get_settings_two(): assert response.status_code == 200 assert response.json() == {'hello':'world'} -def test_get_raw_token(client,default_access_token,encoded_token): +def test_get_raw_token(client,default_access_token,encoded_token,test_settings): response = client.get('/raw_token',headers={"Authorization":f"Bearer {encoded_token}"}) assert response.status_code == 200 assert response.json() == default_access_token -def test_get_raw_jwt(default_access_token,encoded_token,Authorize): +def test_get_raw_jwt(default_access_token,encoded_token,Authorize,test_settings): assert Authorize.get_raw_jwt(encoded_token) == default_access_token -def test_get_jwt_jti(client,default_access_token,encoded_token,Authorize): +def test_get_jwt_jti(client,default_access_token,encoded_token,Authorize,test_settings): assert Authorize.get_jti(encoded_token=encoded_token) == default_access_token['jti'] -def test_get_jwt_subject(client,default_access_token,encoded_token): +def test_get_jwt_subject(client,default_access_token,encoded_token,test_settings): response = client.get('/get_subject',headers={"Authorization":f"Bearer {encoded_token}"}) assert response.status_code == 200 assert response.json() == default_access_token['sub'] -def test_invalid_jwt_issuer(client,Authorize): +def test_invalid_jwt_issuer(client,Authorize,test_settings): # No issuer claim expected or provided - OK token = Authorize.create_access_token(subject='test') response = client.get('/protected',headers={'Authorization':f"Bearer {token}"}) @@ -154,7 +167,7 @@ def test_invalid_jwt_issuer(client,Authorize): AuthJWT._encode_issuer = None @pytest.mark.parametrize("token_aud",['foo', ['bar'], ['foo', 'bar', 'baz']]) -def test_valid_aud(client,Authorize,token_aud): +def test_valid_aud(client,Authorize,token_aud,test_settings): AuthJWT._decode_audience = ['foo','bar'] access_token = Authorize.create_access_token(subject=1,audience=token_aud) @@ -171,7 +184,7 @@ def test_valid_aud(client,Authorize,token_aud): AuthJWT._decode_audience = None @pytest.mark.parametrize("token_aud",['bar', ['bar'], ['bar', 'baz']]) -def test_invalid_aud_and_missing_aud(client,Authorize,token_aud): +def test_invalid_aud_and_missing_aud(client,Authorize,token_aud,test_settings): AuthJWT._decode_audience = 'foo' access_token = Authorize.create_access_token(subject=1,audience=token_aud) @@ -187,7 +200,7 @@ def test_invalid_aud_and_missing_aud(client,Authorize,token_aud): if token_aud == ['bar','baz']: AuthJWT._decode_audience = None -def test_invalid_decode_algorithms(client,Authorize): +def test_invalid_decode_algorithms(client,Authorize,test_settings): class SettingsAlgorithms(BaseSettings): authjwt_secret_key: str = "secret" authjwt_decode_algorithms: list = ['HS384','RS256'] @@ -203,7 +216,7 @@ def get_settings_algorithms(): AuthJWT._decode_algorithms = None -def test_valid_asymmetric_algorithms(client,Authorize): +def test_valid_asymmetric_algorithms(client,Authorize,test_settings): hs256_token = Authorize.create_access_token(subject=1) DIR = os.path.abspath(os.path.dirname(__file__)) @@ -236,7 +249,7 @@ def get_settings_asymmetric(): assert response.status_code == 200 assert response.json() == {'hello':'world'} -def test_invalid_asymmetric_algorithms(client,Authorize): +def test_invalid_asymmetric_algorithms(client,Authorize,test_settings): class SettingsAsymmetricOne(BaseSettings): authjwt_algorithm: str = "RS256" diff --git a/tests/test_token_types.py b/tests/test_token_types.py index eb1413f..e5d827b 100644 --- a/tests/test_token_types.py +++ b/tests/test_token_types.py @@ -1,24 +1,16 @@ import jwt import pytest -from fastapi import Depends, FastAPI, Request -from fastapi.responses import JSONResponse +from fastapi import Depends, FastAPI from fastapi.testclient import TestClient from pydantic import BaseSettings from fastapi_jwt_auth import AuthJWT -from fastapi_jwt_auth.exceptions import AuthJWTException @pytest.fixture(scope="function") def client() -> TestClient: app = FastAPI() - @app.exception_handler(AuthJWTException) - def authjwt_exception_handler(request: Request, exc: AuthJWTException): - return JSONResponse( - status_code=exc.status_code, content={"detail": exc.message} - ) - @app.get("/protected") def protected(Authorize: AuthJWT = Depends()): Authorize.jwt_required() @@ -51,7 +43,10 @@ def test_config(): # Checking that created token has custom type claim access = Authorize.create_access_token(subject="test") - assert jwt.decode(access, key="secret", algorithms=['HS256'])["custom_type"] == "access" + assert ( + jwt.decode(access, key="secret", algorithms=["HS256"])["custom_type"] + == "access" + ) # Checking that protected endpoint validates token correctly response = client.get("/protected", headers={"Authorization": f"Bearer {access}"}) @@ -60,14 +55,19 @@ def test_config(): # Checking that endpoint with optional protection validates token with # custom type claim correctly. - response = client.get("/semi_protected", headers={"Authorization": f"Bearer {access}"}) + response = client.get( + "/semi_protected", headers={"Authorization": f"Bearer {access}"} + ) assert response.status_code == 200 assert response.json() == {"hello": "world"} - # Creating refresh token and checking if it has correct + # Creating refresh token and checking if it has correct # type claim. refresh = Authorize.create_refresh_token(subject="test") - assert jwt.decode(refresh, key="secret", algorithms=['HS256'])["custom_type"] == "refresh" + assert ( + jwt.decode(refresh, key="secret", algorithms=["HS256"])["custom_type"] + == "refresh" + ) # Checking that refreshing with custom claim works. response = client.get("/refresh", headers={"Authorization": f"Bearer {refresh}"}) @@ -75,7 +75,6 @@ def test_config(): assert response.json() == {"hello": "world"} - def test_custom_token_type_names_validation( client: TestClient, Authorize: AuthJWT ) -> None: @@ -88,23 +87,31 @@ class TestConfig(BaseSettings): def test_config(): return TestConfig() - # Creating access token and checking that + # Creating access token and checking that # it has custom type access = Authorize.create_access_token(subject="test") - assert jwt.decode(access, key="secret", algorithms=['HS256'])["type"] == "access_custom" + assert ( + jwt.decode(access, key="secret", algorithms=["HS256"])["type"] + == "access_custom" + ) # Checking that validation for custom type works as expected. response = client.get("/protected", headers={"Authorization": f"Bearer {access}"}) assert response.status_code == 200 assert response.json() == {"hello": "world"} - response = client.get("/semi_protected", headers={"Authorization": f"Bearer {access}"}) + response = client.get( + "/semi_protected", headers={"Authorization": f"Bearer {access}"} + ) assert response.status_code == 200 assert response.json() == {"hello": "world"} # Creating refresh token and checking if it has correct type claim. refresh = Authorize.create_refresh_token(subject="test") - assert jwt.decode(refresh, key="secret", algorithms=['HS256'])["type"] == "refresh_custom" + assert ( + jwt.decode(refresh, key="secret", algorithms=["HS256"])["type"] + == "refresh_custom" + ) # Checking that refreshing with custom type works. response = client.get("/refresh", headers={"Authorization": f"Bearer {refresh}"}) @@ -112,9 +119,7 @@ def test_config(): assert response.json() == {"hello": "world"} -def test_without_type_claims( - client: TestClient, Authorize: AuthJWT -) -> None: +def test_without_type_claims(client: TestClient, Authorize: AuthJWT) -> None: class TestConfig(BaseSettings): authjwt_secret_key: str = "secret" authjwt_token_type_claim: bool = False @@ -125,19 +130,21 @@ def test_config(): # Creating access token and checking if it doesn't have type claim. access = Authorize.create_access_token(subject="test") - assert "type" not in jwt.decode(access, key="secret", algorithms=['HS256']) + assert "type" not in jwt.decode(access, key="secret", algorithms=["HS256"]) response = client.get("/protected", headers={"Authorization": f"Bearer {access}"}) assert response.status_code == 200 assert response.json() == {"hello": "world"} - response = client.get("/semi_protected", headers={"Authorization": f"Bearer {access}"}) + response = client.get( + "/semi_protected", headers={"Authorization": f"Bearer {access}"} + ) assert response.status_code == 200 assert response.json() == {"hello": "world"} # Creating refresh token and checking if it doesn't have type claim. refresh = Authorize.create_refresh_token(subject="test") - assert "type" not in jwt.decode(refresh, key="secret", algorithms=['HS256']) + assert "type" not in jwt.decode(refresh, key="secret", algorithms=["HS256"]) # Checking that refreshing without type works. response = client.get("/refresh", headers={"Authorization": f"Bearer {refresh}"})