Skip to content

Commit 387baf4

Browse files
author
IndominusByte
committed
add additional claims
1 parent 4161f7f commit 387baf4

File tree

2 files changed

+26
-5
lines changed

2 files changed

+26
-5
lines changed

fastapi_jwt_auth/auth_jwt.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,8 @@ def _create_token(
127127
algorithm: Optional[str] = None,
128128
headers: Optional[Dict] = None,
129129
issuer: Optional[str] = None,
130-
audience: Optional[Union[str,Sequence[str]]] = None
130+
audience: Optional[Union[str,Sequence[str]]] = None,
131+
user_claims: Optional[Dict] = {}
131132
) -> str:
132133
"""
133134
Create token for access_token and refresh_token (utf-8)
@@ -140,6 +141,7 @@ def _create_token(
140141
:param headers: valid dict for specifying additional headers in JWT header section
141142
:param issuer: expected issuer in the JWT
142143
:param audience: expected audience in the JWT
144+
:param user_claims: Custom claims to include in this token. This data must be dictionary
143145
144146
:return: Encoded token
145147
"""
@@ -152,6 +154,8 @@ def _create_token(
152154
raise TypeError("audience must be a string or sequence")
153155
if algorithm and not isinstance(algorithm, str):
154156
raise TypeError("algorithm must be a string")
157+
if user_claims and not isinstance(user_claims, dict):
158+
raise TypeError("user_claims must be a dictionary")
155159

156160
# Data section
157161
reserved_claims = {
@@ -185,7 +189,7 @@ def _create_token(
185189
raise
186190

187191
return jwt.encode(
188-
{**reserved_claims, **custom_claims},
192+
{**reserved_claims, **custom_claims, **user_claims},
189193
secret_key,
190194
algorithm=algorithm,
191195
headers=headers
@@ -256,7 +260,8 @@ def create_access_token(
256260
algorithm: Optional[str] = None,
257261
headers: Optional[Dict] = None,
258262
expires_time: Optional[Union[timedelta,int,bool]] = None,
259-
audience: Optional[Union[str,Sequence[str]]] = None
263+
audience: Optional[Union[str,Sequence[str]]] = None,
264+
user_claims: Optional[Dict] = {}
260265
) -> str:
261266
"""
262267
Create a access token with 15 minutes for expired time (default),
@@ -272,6 +277,7 @@ def create_access_token(
272277
algorithm=algorithm,
273278
headers=headers,
274279
audience=audience,
280+
user_claims=user_claims,
275281
issuer=self._encode_issuer
276282
)
277283

@@ -281,7 +287,8 @@ def create_refresh_token(
281287
algorithm: Optional[str] = None,
282288
headers: Optional[Dict] = None,
283289
expires_time: Optional[Union[timedelta,int,bool]] = None,
284-
audience: Optional[Union[str,Sequence[str]]] = None
290+
audience: Optional[Union[str,Sequence[str]]] = None,
291+
user_claims: Optional[Dict] = {}
285292
) -> str:
286293
"""
287294
Create a refresh token with 30 days for expired time (default),
@@ -295,7 +302,8 @@ def create_refresh_token(
295302
exp_time=self._get_expired_time("refresh",expires_time),
296303
algorithm=algorithm,
297304
headers=headers,
298-
audience=audience
305+
audience=audience,
306+
user_claims=user_claims
299307
)
300308

301309
def _get_csrf_token(self,encoded_token: str) -> str:

tests/test_create_token.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,3 +93,16 @@ def test_create_token_invalid_type_data_algorithm(Authorize):
9393

9494
with pytest.raises(TypeError,match=r"algorithm"):
9595
Authorize.create_refresh_token(subject=1,algorithm=1)
96+
97+
def test_create_token_invalid_user_claims(Authorize):
98+
with pytest.raises(TypeError,match=r"user_claims"):
99+
Authorize.create_access_token(subject=1,user_claims="asd")
100+
with pytest.raises(TypeError,match=r"user_claims"):
101+
Authorize.create_refresh_token(subject=1,user_claims="asd")
102+
103+
def test_create_valid_user_claims(Authorize):
104+
access_token = Authorize.create_access_token(subject=1,user_claims={"my_access":"yeah"})
105+
refresh_token = Authorize.create_refresh_token(subject=1,user_claims={"my_refresh":"hello"})
106+
107+
assert jwt.decode(access_token,"testing",algorithms="HS256")['my_access'] == "yeah"
108+
assert jwt.decode(refresh_token,"testing",algorithms="HS256")['my_refresh'] == "hello"

0 commit comments

Comments
 (0)