Skip to content

Added options to configure token type claim. #84

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions docs/configuration/general.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,16 @@
`authjwt_refresh_token_expires`
: How long an refresh token should live before it expires. This takes value `integer` *(seconds)* or
`datetime.timedelta`, and defaults to **30 days**. Can be set to `False` to disable expiration.

`authjwt_token_type_claim`
: Wether you want to add claim that identidies type of a token. Setting it to `False` is not recommended,
but it's useful for integration with other systems that use jwt. Defaults to `True`

`authjwt_access_token_type`
: String that will be placed in type claim to identify if token is an access token. Defaults to `access`

`authjwt_refresh_token_type`
: String that will be placed in type claim to identify if token is a refresh token. Defaults to `refresh`

`authjwt_token_type_claim_name`
: Name of claim where type of a token is located. Defaults to `type`
10 changes: 10 additions & 0 deletions fastapi_jwt_auth/auth_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ class AuthConfig:
_refresh_csrf_header_name = "X-CSRF-Token"
_csrf_methods = {'POST','PUT','PATCH','DELETE'}

# options to adjust token's type claim
_token_type_claim = True
_access_token_type = "access"
_refresh_token_type = "refresh"
_token_type_claim_name = "type"

@property
def jwt_in_cookies(self) -> bool:
return 'cookies' in self._token_location
Expand Down Expand Up @@ -91,6 +97,10 @@ def load_config(cls, settings: Callable[...,List[tuple]]) -> "AuthConfig":
cls._access_csrf_header_name = config.authjwt_access_csrf_header_name
cls._refresh_csrf_header_name = config.authjwt_refresh_csrf_header_name
cls._csrf_methods = config.authjwt_csrf_methods
cls._token_type_claim = config.authjwt_token_type_claim
cls._access_token_type = config.authjwt_access_token_type
cls._refresh_token_type = config.authjwt_refresh_token_type
cls._token_type_claim_name = config.authjwt_token_type_claim_name
except ValidationError:
raise
except Exception:
Expand Down
43 changes: 29 additions & 14 deletions fastapi_jwt_auth/auth_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,15 @@ def _create_token(
"nbf": self._get_int_from_datetime(datetime.now(timezone.utc)),
"jti": self._get_jwt_identifier()
}
token_types = {
'access': self._access_token_type,
'refresh': self._refresh_token_type,
}

custom_claims = {"type": type_token}
if self._token_type_claim:
custom_claims = {self._token_type_claim_name: token_types[type_token]}
else:
custom_claims = {}

# for access_token only fresh needed
if type_token == 'access':
Expand Down Expand Up @@ -579,10 +586,12 @@ def _verify_jwt_optional_in_request(self,token: str) -> None:

:param token: The encoded JWT
"""
if token: self._verifying_token(token)
if token:
self._verifying_token(token)
if self._token_type_claim:
if self.get_raw_jwt(token)[self._token_type_claim_name] != self._access_token_type:
raise AccessTokenRequired(status_code=422,message="Only access tokens are allowed")

if token and self.get_raw_jwt(token)['type'] != 'access':
raise AccessTokenRequired(status_code=422,message="Only access tokens are allowed")

def _verify_jwt_in_request(
self,
Expand Down Expand Up @@ -613,15 +622,20 @@ def _verify_jwt_in_request(
# verify jwt
issuer = self._decode_issuer if type_token == 'access' else None
self._verifying_token(token,issuer)
raw_jwt = self.get_raw_jwt(token)
if self._token_type_claim:
token_types = {
"access": self._access_token_type,
"refresh": self._refresh_token_type,
}
if raw_jwt[self._token_type_claim_name] != token_types[type_token]:
msg = "Only {} tokens are allowed".format(token_types[type_token])
if type_token == 'access':
raise AccessTokenRequired(status_code=422,message=msg)
if type_token == 'refresh':
raise RefreshTokenRequired(status_code=422,message=msg)

if self.get_raw_jwt(token)['type'] != type_token:
msg = "Only {} tokens are allowed".format(type_token)
if type_token == 'access':
raise AccessTokenRequired(status_code=422,message=msg)
if type_token == 'refresh':
raise RefreshTokenRequired(status_code=422,message=msg)

if fresh and not self.get_raw_jwt(token)['fresh']:
if fresh and not raw_jwt['fresh']:
raise FreshTokenRequired(status_code=401,message="Fresh token required")

def _verifying_token(self,encoded_token: str, issuer: Optional[str] = None) -> None:
Expand All @@ -632,8 +646,9 @@ def _verifying_token(self,encoded_token: str, issuer: Optional[str] = None) -> N
:param issuer: expected issuer in the JWT
"""
raw_token = self._verified_token(encoded_token,issuer)
if raw_token['type'] in self._denylist_token_checks:
self._check_token_is_revoked(raw_token)
if self._token_type_claim:
if raw_token[self._token_type_claim_name] in self._denylist_token_checks:
self._check_token_is_revoked(raw_token)

def _verified_token(self,encoded_token: str, issuer: Optional[str] = None) -> Dict[str,Union[str,int,bool]]:
"""
Expand Down
11 changes: 11 additions & 0 deletions fastapi_jwt_auth/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ class LoadConfig(BaseModel):
authjwt_access_csrf_header_name: Optional[StrictStr] = "X-CSRF-Token"
authjwt_refresh_csrf_header_name: Optional[StrictStr] = "X-CSRF-Token"
authjwt_csrf_methods: Optional[Sequence[StrictStr]] = {'POST','PUT','PATCH','DELETE'}
# options to adjust token's type claim
authjwt_token_type_claim: Optional[StrictBool] = True
authjwt_access_token_type: Optional[StrictStr] = "access"
authjwt_refresh_token_type: Optional[StrictStr] = "refresh"
authjwt_token_type_claim_name: Optional[StrictStr] = "type"

@validator('authjwt_access_token_expires')
def validate_access_token_expires(cls, v):
Expand Down Expand Up @@ -80,6 +85,12 @@ def validate_csrf_methods(cls, v):
raise ValueError("The 'authjwt_csrf_methods' must be between http request methods")
return v.upper()

@validator('authjwt_token_type_claim_name')
def validate_token_type_claim_name(cls, v):
if v.lower() in {'iss', 'sub', 'aud', 'exp', 'nbf', 'iat', 'jti'}:
raise ValueError("The 'authjwt_token_type_claim_name' can not override default JWT claims")
return v

class Config:
min_anystr_length = 1
anystr_strip_whitespace = True
16 changes: 16 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,22 @@
import pytest
from fastapi_jwt_auth import AuthJWT
from fastapi_jwt_auth.config import LoadConfig

@pytest.fixture(scope="module")
def Authorize():
return AuthJWT()


@pytest.fixture(autouse=True)
def reset_config():
"""
Resets config to default to
guarantee that config is unchanged after test.
"""
yield
@AuthJWT.load_config
def default_conf():
return LoadConfig(
authjwt_secret_key="secret",
authjwt_cookie_samesite='strict',
)
29 changes: 29 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ def test_default_config():
assert AuthJWT._access_csrf_header_name == "X-CSRF-Token"
assert AuthJWT._refresh_csrf_header_name == "X-CSRF-Token"
assert AuthJWT._csrf_methods == {'POST','PUT','PATCH','DELETE'}
assert AuthJWT._token_type_claim == True
assert AuthJWT._access_token_type == "access"
assert AuthJWT._refresh_token_type == "refresh"
assert AuthJWT._token_type_claim_name == "type"

def test_token_expired_false(Authorize):
class TokenFalse(BaseSettings):
Expand Down Expand Up @@ -165,6 +169,11 @@ class Settings(BaseSettings):
authjwt_access_csrf_header_name: str = "ACCESS-CSRF-Token"
authjwt_refresh_csrf_header_name: str = "REFRESH-CSRF-Token"
authjwt_csrf_methods: list = ['post']
# options to adjust token's type claim
authjwt_token_type_claim: bool = True
authjwt_access_token_type: str = "access"
authjwt_refresh_token_type: str = "refresh"
authjwt_token_type_claim_name: str = "type"

@AuthJWT.load_config
def get_valid_settings():
Expand Down Expand Up @@ -204,6 +213,11 @@ def get_valid_settings():
assert AuthJWT._access_csrf_header_name == "ACCESS-CSRF-Token"
assert AuthJWT._refresh_csrf_header_name == "REFRESH-CSRF-Token"
assert AuthJWT._csrf_methods == ['POST']
# options to adjust token's type claim
assert AuthJWT._token_type_claim
assert AuthJWT._access_token_type == "access"
assert AuthJWT._refresh_token_type == "refresh"
assert AuthJWT._token_type_claim_name == "type"

with pytest.raises(TypeError,match=r"Config"):
@AuthJWT.load_config
Expand Down Expand Up @@ -401,3 +415,18 @@ def get_invalid_csrf_methods():
@AuthJWT.load_config
def get_invalid_csrf_methods_value():
return [("authjwt_csrf_methods",['posts'])]

with pytest.raises(ValidationError,match=r"authjwt_token_type_claim_name"):
@AuthJWT.load_config
def get_invalid_token_type_claim_name():
return [("authjwt_token_type_claim_name", 'exp')]

with pytest.raises(ValidationError,match=r"authjwt_token_type_claim_name"):
@AuthJWT.load_config
def get_invalid_token_type_claim_name():
return [("authjwt_token_type_claim_name", 'iss')]

with pytest.raises(ValidationError,match=r"authjwt_token_type_claim_name"):
@AuthJWT.load_config
def get_invalid_token_type_claim_name():
return [("authjwt_token_type_claim_name", 'sub')]
23 changes: 14 additions & 9 deletions tests/test_create_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
from pydantic import BaseSettings
from datetime import timedelta, datetime, timezone

def test_create_access_token(Authorize):


@pytest.fixture()
def test_settings() -> None:
class Settings(BaseSettings):
AUTHJWT_SECRET_KEY: str = "testing"
AUTHJWT_ACCESS_TOKEN_EXPIRES: int = 2
Expand All @@ -13,6 +16,8 @@ class Settings(BaseSettings):
def get_settings():
return Settings()

def test_create_access_token(Authorize, test_settings):

with pytest.raises(TypeError,match=r"missing 1 required positional argument"):
Authorize.create_access_token()

Expand All @@ -25,7 +30,7 @@ def get_settings():
with pytest.raises(ValueError,match=r"dictionary update sequence element"):
Authorize.create_access_token(subject=1,headers="test")

def test_create_refresh_token(Authorize):
def test_create_refresh_token(Authorize, test_settings):
with pytest.raises(TypeError,match=r"missing 1 required positional argument"):
Authorize.create_refresh_token()

Expand All @@ -35,7 +40,7 @@ def test_create_refresh_token(Authorize):
with pytest.raises(ValueError,match=r"dictionary update sequence element"):
Authorize.create_refresh_token(subject=1,headers="test")

def test_create_dynamic_access_token_expires(Authorize):
def test_create_dynamic_access_token_expires(Authorize, test_settings):
expires_time = int(datetime.now(timezone.utc).timestamp()) + 90
token = Authorize.create_access_token(subject=1,expires_time=90)
assert jwt.decode(token,"testing",algorithms="HS256")['exp'] == expires_time
Expand All @@ -54,7 +59,7 @@ def test_create_dynamic_access_token_expires(Authorize):
with pytest.raises(TypeError,match=r"expires_time"):
Authorize.create_access_token(subject=1,expires_time="test")

def test_create_dynamic_refresh_token_expires(Authorize):
def test_create_dynamic_refresh_token_expires(Authorize, test_settings):
expires_time = int(datetime.now(timezone.utc).timestamp()) + 90
token = Authorize.create_refresh_token(subject=1,expires_time=90)
assert jwt.decode(token,"testing",algorithms="HS256")['exp'] == expires_time
Expand All @@ -73,34 +78,34 @@ def test_create_dynamic_refresh_token_expires(Authorize):
with pytest.raises(TypeError,match=r"expires_time"):
Authorize.create_refresh_token(subject=1,expires_time="test")

def test_create_token_invalid_type_data_audience(Authorize):
def test_create_token_invalid_type_data_audience(Authorize, test_settings):
with pytest.raises(TypeError,match=r"audience"):
Authorize.create_access_token(subject=1,audience=1)

with pytest.raises(TypeError,match=r"audience"):
Authorize.create_refresh_token(subject=1,audience=1)

def test_create_token_invalid_algorithm(Authorize):
def test_create_token_invalid_algorithm(Authorize, test_settings):
with pytest.raises(ValueError,match=r"Algorithm"):
Authorize.create_access_token(subject=1,algorithm="test")

with pytest.raises(ValueError,match=r"Algorithm"):
Authorize.create_refresh_token(subject=1,algorithm="test")

def test_create_token_invalid_type_data_algorithm(Authorize):
def test_create_token_invalid_type_data_algorithm(Authorize, test_settings):
with pytest.raises(TypeError,match=r"algorithm"):
Authorize.create_access_token(subject=1,algorithm=1)

with pytest.raises(TypeError,match=r"algorithm"):
Authorize.create_refresh_token(subject=1,algorithm=1)

def test_create_token_invalid_user_claims(Authorize):
def test_create_token_invalid_user_claims(Authorize, test_settings):
with pytest.raises(TypeError,match=r"user_claims"):
Authorize.create_access_token(subject=1,user_claims="asd")
with pytest.raises(TypeError,match=r"user_claims"):
Authorize.create_refresh_token(subject=1,user_claims="asd")

def test_create_valid_user_claims(Authorize):
def test_create_valid_user_claims(Authorize, test_settings):
access_token = Authorize.create_access_token(subject=1,user_claims={"my_access":"yeah"})
refresh_token = Authorize.create_refresh_token(subject=1,user_claims={"my_refresh":"hello"})

Expand Down
Loading