2
2
from jwt .algorithms import requires_cryptography , has_crypto
3
3
from datetime import datetime , timezone , timedelta
4
4
from typing import Optional , Dict , Union , Sequence
5
- from types import GeneratorType
6
- from fastapi import Request , Response
5
+ from fastapi import Request , Response , WebSocket
7
6
from fastapi_jwt_auth .auth_config import AuthConfig
8
7
from fastapi_jwt_auth .exceptions import (
9
8
InvalidHeaderError ,
@@ -148,9 +147,9 @@ def _create_token(
148
147
# Validation type data
149
148
if not isinstance (subject , (str ,int )):
150
149
raise TypeError ("subject must be a string or integer" )
151
- if not isinstance (fresh , ( bool ) ):
150
+ if not isinstance (fresh , bool ):
152
151
raise TypeError ("fresh must be a boolean" )
153
- if audience and not isinstance (audience , (str , list , tuple , set , frozenset , GeneratorType )):
152
+ if audience and not isinstance (audience , (str , list , tuple , set , frozenset )):
154
153
raise TypeError ("audience must be a string or sequence" )
155
154
if algorithm and not isinstance (algorithm , str ):
156
155
raise TypeError ("algorithm must be a string" )
@@ -484,89 +483,104 @@ def unset_refresh_cookies(self,response: Optional[Response] = None) -> None:
484
483
domain = self ._cookie_domain
485
484
)
486
485
487
- def _verify_and_get_jwt_optional_in_cookies (self ) -> "AuthJWT" :
486
+ def verify_and_get_jwt_optional_in_cookies (
487
+ self ,
488
+ request : Union [Request ,WebSocket ],
489
+ csrf_token : Optional [str ] = None ,
490
+ ) -> "AuthJWT" :
488
491
"""
489
492
Optionally check if cookies have a valid access token. if an access token present in
490
- cookies property _token will set. raises exception error when an access token is invalid
493
+ cookies, self. _token will set. raises exception error when an access token is invalid
491
494
and doesn't match with CSRF token double submit
495
+
496
+ :param request: for identity get cookies from HTTP or WebSocket
497
+ :param csrf_token: get csrf token if a request from WebSocket
492
498
"""
499
+ if not isinstance (request ,(Request ,WebSocket )):
500
+ raise TypeError ("request must be an instance of 'Request' or 'WebSocket'" )
501
+
493
502
cookie_key = self ._access_cookie_key
494
- cookie = self ._request .cookies .get (cookie_key )
495
- csrf_cookie = self ._request .headers .get (self ._access_csrf_header_name )
503
+ cookie = request .cookies .get (cookie_key )
504
+ if not isinstance (request , WebSocket ):
505
+ csrf_token = request .headers .get (self ._access_csrf_header_name )
496
506
497
- if (
498
- cookie and
499
- self ._cookie_csrf_protect and
500
- self ._request .method in self ._csrf_methods and
501
- not csrf_cookie
502
- ):
503
- raise CSRFError (status_code = 401 ,message = "Missing CSRF Token" )
507
+ if cookie and self ._cookie_csrf_protect and not csrf_token :
508
+ if isinstance (request , WebSocket ) or request .method in self ._csrf_methods :
509
+ raise CSRFError (status_code = 401 ,message = "Missing CSRF Token" )
504
510
505
511
# set token from cookie and verify jwt
506
512
self ._token = cookie
507
513
self .verify_jwt_optional_in_request (self ._token )
508
514
509
515
decoded_token = self .get_raw_jwt ()
510
516
511
- if (
512
- self ._cookie_csrf_protect and
513
- self ._request .method in self ._csrf_methods and
514
- csrf_cookie and
515
- decoded_token
516
- ):
517
- if 'csrf' not in decoded_token :
518
- raise JWTDecodeError (status_code = 422 ,message = "Missing claim: csrf" )
519
- if not hmac .compare_digest (csrf_cookie ,decoded_token ['csrf' ]):
520
- raise CSRFError (status_code = 401 ,message = "CSRF double submit tokens do not match" )
521
-
522
- def _verify_and_get_jwt_in_cookies (
517
+ if decoded_token and self ._cookie_csrf_protect and csrf_token :
518
+ if isinstance (request , WebSocket ) or request .method in self ._csrf_methods :
519
+ if 'csrf' not in decoded_token :
520
+ raise JWTDecodeError (status_code = 422 ,message = "Missing claim: csrf" )
521
+ if not hmac .compare_digest (csrf_token ,decoded_token ['csrf' ]):
522
+ raise CSRFError (status_code = 401 ,message = "CSRF double submit tokens do not match" )
523
+
524
+ def verify_and_get_jwt_in_cookies (
523
525
self ,
524
526
type_token : str ,
525
- fresh : Optional [bool ] = False
527
+ request : Union [Request ,WebSocket ],
528
+ csrf_token : Optional [str ] = None ,
529
+ fresh : Optional [bool ] = False ,
526
530
) -> "AuthJWT" :
527
531
"""
528
532
Check if cookies have a valid access or refresh token. if an token present in
529
- cookies property _token will set. raises exception error when an access or refresh token
533
+ cookies, self. _token will set. raises exception error when an access or refresh token
530
534
is invalid and doesn't match with CSRF token double submit
531
535
532
536
:param type_token: indicate token is access or refresh token
537
+ :param request: for identity get cookies from HTTP or WebSocket
538
+ :param csrf_token: get csrf token if a request from WebSocket
533
539
:param fresh: check freshness token if True
534
540
"""
541
+ if type_token not in ['access' ,'refresh' ]:
542
+ raise ValueError ("type_token must be between 'access' or 'refresh'" )
543
+ if not isinstance (request ,(Request ,WebSocket )):
544
+ raise TypeError ("request must be an instance of 'Request' or 'WebSocket'" )
545
+
535
546
if type_token == 'access' :
536
547
cookie_key = self ._access_cookie_key
537
- cookie = self ._request .cookies .get (cookie_key )
538
- csrf_cookie = self ._request .headers .get (self ._access_csrf_header_name )
548
+ cookie = request .cookies .get (cookie_key )
549
+ if not isinstance (request , WebSocket ):
550
+ csrf_token = request .headers .get (self ._access_csrf_header_name )
539
551
if type_token == 'refresh' :
540
552
cookie_key = self ._refresh_cookie_key
541
- cookie = self ._request .cookies .get (cookie_key )
542
- csrf_cookie = self ._request .headers .get (self ._refresh_csrf_header_name )
553
+ cookie = request .cookies .get (cookie_key )
554
+ if not isinstance (request , WebSocket ):
555
+ csrf_token = request .headers .get (self ._refresh_csrf_header_name )
543
556
544
557
if not cookie :
545
558
raise MissingCookieError (status_code = 401 ,message = "Missing cookie {}" .format (cookie_key ))
546
559
547
- if self ._cookie_csrf_protect and self ._request .method in self ._csrf_methods and not csrf_cookie :
548
- raise CSRFError (status_code = 401 ,message = "Missing CSRF Token" )
560
+ if self ._cookie_csrf_protect and not csrf_token :
561
+ if isinstance (request , WebSocket ) or request .method in self ._csrf_methods :
562
+ raise CSRFError (status_code = 401 ,message = "Missing CSRF Token" )
549
563
550
564
# set token from cookie and verify jwt
551
565
self ._token = cookie
552
566
self .verify_jwt_in_request (self ._token ,type_token ,'cookies' ,fresh )
553
567
554
568
decoded_token = self .get_raw_jwt ()
555
569
556
- if self ._cookie_csrf_protect and self ._request .method in self ._csrf_methods and csrf_cookie :
557
- if 'csrf' not in decoded_token :
558
- raise JWTDecodeError (status_code = 422 ,message = "Missing claim: csrf" )
559
- if not hmac .compare_digest (csrf_cookie ,decoded_token ['csrf' ]):
560
- raise CSRFError (status_code = 401 ,message = "CSRF double submit tokens do not match" )
570
+ if self ._cookie_csrf_protect and csrf_token :
571
+ if isinstance (request , WebSocket ) or request .method in self ._csrf_methods :
572
+ if 'csrf' not in decoded_token :
573
+ raise JWTDecodeError (status_code = 422 ,message = "Missing claim: csrf" )
574
+ if not hmac .compare_digest (csrf_token ,decoded_token ['csrf' ]):
575
+ raise CSRFError (status_code = 401 ,message = "CSRF double submit tokens do not match" )
561
576
562
577
def verify_jwt_optional_in_request (self ,token : str ) -> None :
563
578
"""
564
579
Optionally check if this request has a valid access token
565
580
566
581
:param token: The encoded JWT
567
582
"""
568
- if token :
569
- self ._verifying_token (token )
583
+ if token : self ._verifying_token (token )
570
584
571
585
if token and self .get_raw_jwt (token )['type' ] != 'access' :
572
586
raise AccessTokenRequired (status_code = 422 ,message = "Only access tokens are allowed" )
@@ -583,17 +597,20 @@ def verify_jwt_in_request(
583
597
584
598
:param token: The encoded JWT
585
599
:param type_token: indicate token is access or refresh token
586
- :param token_from: indicate token from headers or cookies
600
+ :param token_from: indicate token from headers cookies, websocket
587
601
:param fresh: check freshness token if True
588
602
"""
589
- issuer = self . _decode_issuer if type_token == 'access' else None
590
-
591
- if token :
592
- self . _verifying_token ( token , issuer )
603
+ if type_token not in [ 'access' , 'refresh' ]:
604
+ raise ValueError ( "type_token must be between 'access' or 'refresh'" )
605
+ if token_from not in [ 'headers' , 'cookies' , 'websocket' ] :
606
+ raise ValueError ( "token_from must be between 'headers', 'cookies', 'websocket'" )
593
607
594
608
if not token and token_from == 'headers' :
595
609
raise MissingHeaderError (status_code = 401 ,message = "Missing {} Header" .format (self ._header_name ))
596
610
611
+ issuer = self ._decode_issuer if type_token == 'access' else None
612
+ self ._verifying_token (token ,issuer )
613
+
597
614
if self .get_raw_jwt (token )['type' ] != type_token :
598
615
msg = "Only {} tokens are allowed" .format (type_token )
599
616
if type_token == 'access' :
@@ -656,12 +673,12 @@ def jwt_required(self) -> None:
656
673
if self ._token and self .jwt_in_headers :
657
674
self .verify_jwt_in_request (self ._token ,'access' ,'headers' )
658
675
if not self ._token and self .jwt_in_cookies :
659
- self ._verify_and_get_jwt_in_cookies ('access' )
676
+ self .verify_and_get_jwt_in_cookies ('access' , self . _request )
660
677
else :
661
678
if self .jwt_in_headers :
662
679
self .verify_jwt_in_request (self ._token ,'access' ,'headers' )
663
680
if self .jwt_in_cookies :
664
- self ._verify_and_get_jwt_in_cookies ('access' )
681
+ self .verify_and_get_jwt_in_cookies ('access' , self . _request )
665
682
666
683
def jwt_optional (self ) -> None :
667
684
"""
@@ -673,12 +690,12 @@ def jwt_optional(self) -> None:
673
690
if self ._token and self .jwt_in_headers :
674
691
self .verify_jwt_optional_in_request (self ._token )
675
692
if not self ._token and self .jwt_in_cookies :
676
- self ._verify_and_get_jwt_optional_in_cookies ( )
693
+ self .verify_and_get_jwt_optional_in_cookies ( self . _request )
677
694
else :
678
695
if self .jwt_in_headers :
679
696
self .verify_jwt_optional_in_request (self ._token )
680
697
if self .jwt_in_cookies :
681
- self ._verify_and_get_jwt_optional_in_cookies ( )
698
+ self .verify_and_get_jwt_optional_in_cookies ( self . _request )
682
699
683
700
def jwt_refresh_token_required (self ) -> None :
684
701
"""
@@ -688,12 +705,12 @@ def jwt_refresh_token_required(self) -> None:
688
705
if self ._token and self .jwt_in_headers :
689
706
self .verify_jwt_in_request (self ._token ,'refresh' ,'headers' )
690
707
if not self ._token and self .jwt_in_cookies :
691
- self ._verify_and_get_jwt_in_cookies ('refresh' )
708
+ self .verify_and_get_jwt_in_cookies ('refresh' , self . _request )
692
709
else :
693
710
if self .jwt_in_headers :
694
711
self .verify_jwt_in_request (self ._token ,'refresh' ,'headers' )
695
712
if self .jwt_in_cookies :
696
- self ._verify_and_get_jwt_in_cookies ('refresh' )
713
+ self .verify_and_get_jwt_in_cookies ('refresh' , self . _request )
697
714
698
715
def fresh_jwt_required (self ) -> None :
699
716
"""
@@ -703,12 +720,12 @@ def fresh_jwt_required(self) -> None:
703
720
if self ._token and self .jwt_in_headers :
704
721
self .verify_jwt_in_request (self ._token ,'access' ,'headers' ,True )
705
722
if not self ._token and self .jwt_in_cookies :
706
- self ._verify_and_get_jwt_in_cookies ('access' ,True )
723
+ self .verify_and_get_jwt_in_cookies ('access' ,self . _request , fresh = True )
707
724
else :
708
725
if self .jwt_in_headers :
709
726
self .verify_jwt_in_request (self ._token ,'access' ,'headers' ,True )
710
727
if self .jwt_in_cookies :
711
- self ._verify_and_get_jwt_in_cookies ('access' ,True )
728
+ self .verify_and_get_jwt_in_cookies ('access' ,self . _request , fresh = True )
712
729
713
730
def get_raw_jwt (self ,encoded_token : Optional [str ] = None ) -> Optional [Dict [str ,Union [str ,int ,bool ]]]:
714
731
"""
0 commit comments