Skip to content

Commit 99da6b8

Browse files
committed
Added options to configure token type claim.
Signed-off-by: Pavel Kirilin <win10@list.ru>
1 parent a6c0619 commit 99da6b8

File tree

6 files changed

+237
-14
lines changed

6 files changed

+237
-14
lines changed

docs/configuration/general.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,16 @@
4141
`authjwt_refresh_token_expires`
4242
: How long an refresh token should live before it expires. This takes value `integer` *(seconds)* or
4343
`datetime.timedelta`, and defaults to **30 days**. Can be set to `False` to disable expiration.
44+
45+
`authjwt_token_type_claim`
46+
: Wether you want to add claim that identidies type of a token. Setting it to `False` is not recommended,
47+
but it's useful for integration with other systems that use jwt. Defaults to `True`
48+
49+
`authjwt_access_token_type`
50+
: String that will be placed in type claim to identify if token is an access token. Defaults to `access`
51+
52+
`authjwt_refresh_token_type`
53+
: String that will be placed in type claim to identify if token is a refresh token. Defaults to `refresh`
54+
55+
`authjwt_token_type_claim_name`
56+
: Name of claim where type of a token is located. Defaults to `type`

fastapi_jwt_auth/auth_config.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,12 @@ class AuthConfig:
4444
_refresh_csrf_header_name = "X-CSRF-Token"
4545
_csrf_methods = {'POST','PUT','PATCH','DELETE'}
4646

47+
# options to adjust token's type claim
48+
_token_type_claim = True
49+
_access_token_type = "access"
50+
_refresh_token_type = "refresh"
51+
_token_type_claim_name = "type"
52+
4753
@property
4854
def jwt_in_cookies(self) -> bool:
4955
return 'cookies' in self._token_location
@@ -91,6 +97,10 @@ def load_config(cls, settings: Callable[...,List[tuple]]) -> "AuthConfig":
9197
cls._access_csrf_header_name = config.authjwt_access_csrf_header_name
9298
cls._refresh_csrf_header_name = config.authjwt_refresh_csrf_header_name
9399
cls._csrf_methods = config.authjwt_csrf_methods
100+
cls._token_type_claim = config.authjwt_token_type_claim
101+
cls._access_token_type = config.authjwt_access_token_type
102+
cls._refresh_token_type = config.authjwt_refresh_token_type
103+
cls._token_type_claim_name = config.authjwt_token_type_claim_name
94104
except ValidationError:
95105
raise
96106
except Exception:

fastapi_jwt_auth/auth_jwt.py

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,15 @@ def _create_token(
162162
"nbf": self._get_int_from_datetime(datetime.now(timezone.utc)),
163163
"jti": self._get_jwt_identifier()
164164
}
165+
token_types = {
166+
'access': self._access_token_type,
167+
'refresh': self._refresh_token_type,
168+
}
165169

166-
custom_claims = {"type": type_token}
170+
if self._token_type_claim:
171+
custom_claims = {self._token_type_claim_name: token_types[type_token]}
172+
else:
173+
custom_claims = {}
167174

168175
# for access_token only fresh needed
169176
if type_token == 'access':
@@ -579,10 +586,12 @@ def _verify_jwt_optional_in_request(self,token: str) -> None:
579586
580587
:param token: The encoded JWT
581588
"""
582-
if token: self._verifying_token(token)
589+
if token:
590+
self._verifying_token(token)
591+
if self._token_type_claim:
592+
if self.get_raw_jwt(token)[self._token_type_claim_name] != self._access_token_type:
593+
raise AccessTokenRequired(status_code=422,message="Only access tokens are allowed")
583594

584-
if token and self.get_raw_jwt(token)['type'] != 'access':
585-
raise AccessTokenRequired(status_code=422,message="Only access tokens are allowed")
586595

587596
def _verify_jwt_in_request(
588597
self,
@@ -613,15 +622,20 @@ def _verify_jwt_in_request(
613622
# verify jwt
614623
issuer = self._decode_issuer if type_token == 'access' else None
615624
self._verifying_token(token,issuer)
625+
raw_jwt = self.get_raw_jwt(token)
626+
if self._token_type_claim:
627+
token_types = {
628+
"access": self._access_token_type,
629+
"refresh": self._refresh_token_type,
630+
}
631+
if raw_jwt[self._token_type_claim_name] != token_types[type_token]:
632+
msg = "Only {} tokens are allowed".format(token_types[type_token])
633+
if type_token == 'access':
634+
raise AccessTokenRequired(status_code=422,message=msg)
635+
if type_token == 'refresh':
636+
raise RefreshTokenRequired(status_code=422,message=msg)
616637

617-
if self.get_raw_jwt(token)['type'] != type_token:
618-
msg = "Only {} tokens are allowed".format(type_token)
619-
if type_token == 'access':
620-
raise AccessTokenRequired(status_code=422,message=msg)
621-
if type_token == 'refresh':
622-
raise RefreshTokenRequired(status_code=422,message=msg)
623-
624-
if fresh and not self.get_raw_jwt(token)['fresh']:
638+
if fresh and not raw_jwt['fresh']:
625639
raise FreshTokenRequired(status_code=401,message="Fresh token required")
626640

627641
def _verifying_token(self,encoded_token: str, issuer: Optional[str] = None) -> None:
@@ -632,8 +646,9 @@ def _verifying_token(self,encoded_token: str, issuer: Optional[str] = None) -> N
632646
:param issuer: expected issuer in the JWT
633647
"""
634648
raw_token = self._verified_token(encoded_token,issuer)
635-
if raw_token['type'] in self._denylist_token_checks:
636-
self._check_token_is_revoked(raw_token)
649+
if self._token_type_claim:
650+
if raw_token[self._token_type_claim_name] in self._denylist_token_checks:
651+
self._check_token_is_revoked(raw_token)
637652

638653
def _verified_token(self,encoded_token: str, issuer: Optional[str] = None) -> Dict[str,Union[str,int,bool]]:
639654
"""

fastapi_jwt_auth/config.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,11 @@ class LoadConfig(BaseModel):
4343
authjwt_access_csrf_header_name: Optional[StrictStr] = "X-CSRF-Token"
4444
authjwt_refresh_csrf_header_name: Optional[StrictStr] = "X-CSRF-Token"
4545
authjwt_csrf_methods: Optional[Sequence[StrictStr]] = {'POST','PUT','PATCH','DELETE'}
46+
# options to adjust token's type claim
47+
authjwt_token_type_claim: Optional[StrictBool] = True
48+
authjwt_access_token_type: Optional[StrictStr] = "access"
49+
authjwt_refresh_token_type: Optional[StrictStr] = "refresh"
50+
authjwt_token_type_claim_name: Optional[StrictStr] = "type"
4651

4752
@validator('authjwt_access_token_expires')
4853
def validate_access_token_expires(cls, v):
@@ -80,6 +85,12 @@ def validate_csrf_methods(cls, v):
8085
raise ValueError("The 'authjwt_csrf_methods' must be between http request methods")
8186
return v.upper()
8287

88+
@validator('authjwt_token_type_claim_name')
89+
def validate_token_type_claim_name(cls, v):
90+
if v.lower() in {'iss', 'sub', 'aud', 'exp', 'nbf', 'iat', 'jti'}:
91+
raise ValueError("The 'authjwt_token_type_claim_name' can not override default JWT claims")
92+
return v
93+
8394
class Config:
8495
min_anystr_length = 1
8596
anystr_strip_whitespace = True

tests/test_config.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,10 @@ def test_default_config():
5858
assert AuthJWT._access_csrf_header_name == "X-CSRF-Token"
5959
assert AuthJWT._refresh_csrf_header_name == "X-CSRF-Token"
6060
assert AuthJWT._csrf_methods == {'POST','PUT','PATCH','DELETE'}
61+
assert AuthJWT._token_type_claim == True
62+
assert AuthJWT._access_token_type == "access"
63+
assert AuthJWT._refresh_token_type == "refresh"
64+
assert AuthJWT._token_type_claim_name == "type"
6165

6266
def test_token_expired_false(Authorize):
6367
class TokenFalse(BaseSettings):
@@ -165,6 +169,11 @@ class Settings(BaseSettings):
165169
authjwt_access_csrf_header_name: str = "ACCESS-CSRF-Token"
166170
authjwt_refresh_csrf_header_name: str = "REFRESH-CSRF-Token"
167171
authjwt_csrf_methods: list = ['post']
172+
# options to adjust token's type claim
173+
authjwt_token_type_claim: bool = True
174+
authjwt_access_token_type: str = "access"
175+
authjwt_refresh_token_type: str = "refresh"
176+
authjwt_token_type_claim_name: str = "type"
168177

169178
@AuthJWT.load_config
170179
def get_valid_settings():
@@ -204,6 +213,11 @@ def get_valid_settings():
204213
assert AuthJWT._access_csrf_header_name == "ACCESS-CSRF-Token"
205214
assert AuthJWT._refresh_csrf_header_name == "REFRESH-CSRF-Token"
206215
assert AuthJWT._csrf_methods == ['POST']
216+
# options to adjust token's type claim
217+
assert AuthJWT._token_type_claim
218+
assert AuthJWT._access_token_type == "access"
219+
assert AuthJWT._refresh_token_type == "refresh"
220+
assert AuthJWT._token_type_claim_name == "type"
207221

208222
with pytest.raises(TypeError,match=r"Config"):
209223
@AuthJWT.load_config
@@ -401,3 +415,18 @@ def get_invalid_csrf_methods():
401415
@AuthJWT.load_config
402416
def get_invalid_csrf_methods_value():
403417
return [("authjwt_csrf_methods",['posts'])]
418+
419+
with pytest.raises(ValidationError,match=r"authjwt_token_type_claim_name"):
420+
@AuthJWT.load_config
421+
def get_invalid_token_type_claim_name():
422+
return [("authjwt_token_type_claim_name", 'exp')]
423+
424+
with pytest.raises(ValidationError,match=r"authjwt_token_type_claim_name"):
425+
@AuthJWT.load_config
426+
def get_invalid_token_type_claim_name():
427+
return [("authjwt_token_type_claim_name", 'iss')]
428+
429+
with pytest.raises(ValidationError,match=r"authjwt_token_type_claim_name"):
430+
@AuthJWT.load_config
431+
def get_invalid_token_type_claim_name():
432+
return [("authjwt_token_type_claim_name", 'sub')]

tests/test_token_types.py

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
import jwt
2+
import pytest
3+
from fastapi import Depends, FastAPI, Request
4+
from fastapi.responses import JSONResponse
5+
from fastapi.testclient import TestClient
6+
from pydantic import BaseSettings
7+
8+
from fastapi_jwt_auth import AuthJWT
9+
from fastapi_jwt_auth.exceptions import AuthJWTException
10+
11+
12+
@pytest.fixture(scope="function")
13+
def client() -> TestClient:
14+
app = FastAPI()
15+
16+
@app.exception_handler(AuthJWTException)
17+
def authjwt_exception_handler(request: Request, exc: AuthJWTException):
18+
return JSONResponse(
19+
status_code=exc.status_code, content={"detail": exc.message}
20+
)
21+
22+
@app.get("/protected")
23+
def protected(Authorize: AuthJWT = Depends()):
24+
Authorize.jwt_required()
25+
return {"hello": "world"}
26+
27+
@app.get("/semi_protected")
28+
def protected(Authorize: AuthJWT = Depends()):
29+
Authorize.jwt_optional()
30+
return {"hello": "world"}
31+
32+
@app.get("/refresh")
33+
def refresher(Authorize: AuthJWT = Depends()):
34+
Authorize.jwt_refresh_token_required()
35+
return {"hello": "world"}
36+
37+
client = TestClient(app)
38+
return client
39+
40+
41+
def test_custom_token_type_claim_validation(
42+
client: TestClient, Authorize: AuthJWT
43+
) -> None:
44+
class TestConfig(BaseSettings):
45+
authjwt_secret_key: str = "secret"
46+
authjwt_token_type_claim_name: str = "custom_type"
47+
48+
@AuthJWT.load_config
49+
def test_config():
50+
return TestConfig()
51+
52+
# Checking that created token has custom type claim
53+
access = Authorize.create_access_token(subject="test")
54+
assert jwt.decode(access, key="secret", algorithms=['HS256'])["custom_type"] == "access"
55+
56+
# Checking that protected endpoint validates token correctly
57+
response = client.get("/protected", headers={"Authorization": f"Bearer {access}"})
58+
assert response.status_code == 200
59+
assert response.json() == {"hello": "world"}
60+
61+
# Checking that endpoint with optional protection validates token with
62+
# custom type claim correctly.
63+
response = client.get("/semi_protected", headers={"Authorization": f"Bearer {access}"})
64+
assert response.status_code == 200
65+
assert response.json() == {"hello": "world"}
66+
67+
# Creating refresh token and checking if it has correct
68+
# type claim.
69+
refresh = Authorize.create_refresh_token(subject="test")
70+
assert jwt.decode(refresh, key="secret", algorithms=['HS256'])["custom_type"] == "refresh"
71+
72+
# Checking that refreshing with custom claim works.
73+
response = client.get("/refresh", headers={"Authorization": f"Bearer {refresh}"})
74+
assert response.status_code == 200
75+
assert response.json() == {"hello": "world"}
76+
77+
78+
79+
def test_custom_token_type_names_validation(
80+
client: TestClient, Authorize: AuthJWT
81+
) -> None:
82+
class TestConfig(BaseSettings):
83+
authjwt_secret_key: str = "secret"
84+
authjwt_refresh_token_type: str = "refresh_custom"
85+
authjwt_access_token_type: str = "access_custom"
86+
87+
@AuthJWT.load_config
88+
def test_config():
89+
return TestConfig()
90+
91+
# Creating access token and checking that
92+
# it has custom type
93+
access = Authorize.create_access_token(subject="test")
94+
assert jwt.decode(access, key="secret", algorithms=['HS256'])["type"] == "access_custom"
95+
96+
# Checking that validation for custom type works as expected.
97+
response = client.get("/protected", headers={"Authorization": f"Bearer {access}"})
98+
assert response.status_code == 200
99+
assert response.json() == {"hello": "world"}
100+
101+
response = client.get("/semi_protected", headers={"Authorization": f"Bearer {access}"})
102+
assert response.status_code == 200
103+
assert response.json() == {"hello": "world"}
104+
105+
# Creating refresh token and checking if it has correct type claim.
106+
refresh = Authorize.create_refresh_token(subject="test")
107+
assert jwt.decode(refresh, key="secret", algorithms=['HS256'])["type"] == "refresh_custom"
108+
109+
# Checking that refreshing with custom type works.
110+
response = client.get("/refresh", headers={"Authorization": f"Bearer {refresh}"})
111+
assert response.status_code == 200
112+
assert response.json() == {"hello": "world"}
113+
114+
115+
def test_without_type_claims(
116+
client: TestClient, Authorize: AuthJWT
117+
) -> None:
118+
class TestConfig(BaseSettings):
119+
authjwt_secret_key: str = "secret"
120+
authjwt_token_type_claim: bool = False
121+
122+
@AuthJWT.load_config
123+
def test_config():
124+
return TestConfig()
125+
126+
# Creating access token and checking if it doesn't have type claim.
127+
access = Authorize.create_access_token(subject="test")
128+
assert "type" not in jwt.decode(access, key="secret", algorithms=['HS256'])
129+
130+
response = client.get("/protected", headers={"Authorization": f"Bearer {access}"})
131+
assert response.status_code == 200
132+
assert response.json() == {"hello": "world"}
133+
134+
response = client.get("/semi_protected", headers={"Authorization": f"Bearer {access}"})
135+
assert response.status_code == 200
136+
assert response.json() == {"hello": "world"}
137+
138+
# Creating refresh token and checking if it doesn't have type claim.
139+
refresh = Authorize.create_refresh_token(subject="test")
140+
assert "type" not in jwt.decode(refresh, key="secret", algorithms=['HS256'])
141+
142+
# Checking that refreshing without type works.
143+
response = client.get("/refresh", headers={"Authorization": f"Bearer {refresh}"})
144+
assert response.status_code == 200
145+
assert response.json() == {"hello": "world"}

0 commit comments

Comments
 (0)