Skip to content

Commit 1d3325b

Browse files
author
IndominusByte
committed
add support to websocket
1 parent 1fc929b commit 1d3325b

File tree

1 file changed

+72
-55
lines changed

1 file changed

+72
-55
lines changed

fastapi_jwt_auth/auth_jwt.py

Lines changed: 72 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22
from jwt.algorithms import requires_cryptography, has_crypto
33
from datetime import datetime, timezone, timedelta
44
from typing import Optional, Dict, Union, Sequence
5-
from types import GeneratorType
6-
from fastapi import Request, Response
5+
from fastapi import Request, Response, WebSocket
76
from fastapi_jwt_auth.auth_config import AuthConfig
87
from fastapi_jwt_auth.exceptions import (
98
InvalidHeaderError,
@@ -148,9 +147,9 @@ def _create_token(
148147
# Validation type data
149148
if not isinstance(subject, (str,int)):
150149
raise TypeError("subject must be a string or integer")
151-
if not isinstance(fresh, (bool)):
150+
if not isinstance(fresh, bool):
152151
raise TypeError("fresh must be a boolean")
153-
if audience and not isinstance(audience, (str, list, tuple, set, frozenset, GeneratorType)):
152+
if audience and not isinstance(audience, (str, list, tuple, set, frozenset)):
154153
raise TypeError("audience must be a string or sequence")
155154
if algorithm and not isinstance(algorithm, str):
156155
raise TypeError("algorithm must be a string")
@@ -484,89 +483,104 @@ def unset_refresh_cookies(self,response: Optional[Response] = None) -> None:
484483
domain=self._cookie_domain
485484
)
486485

487-
def _verify_and_get_jwt_optional_in_cookies(self) -> "AuthJWT":
486+
def verify_and_get_jwt_optional_in_cookies(
487+
self,
488+
request: Union[Request,WebSocket],
489+
csrf_token: Optional[str] = None,
490+
) -> "AuthJWT":
488491
"""
489492
Optionally check if cookies have a valid access token. if an access token present in
490-
cookies property _token will set. raises exception error when an access token is invalid
493+
cookies, self._token will set. raises exception error when an access token is invalid
491494
and doesn't match with CSRF token double submit
495+
496+
:param request: for identity get cookies from HTTP or WebSocket
497+
:param csrf_token: get csrf token if a request from WebSocket
492498
"""
499+
if not isinstance(request,(Request,WebSocket)):
500+
raise TypeError("request must be an instance of 'Request' or 'WebSocket'")
501+
493502
cookie_key = self._access_cookie_key
494-
cookie = self._request.cookies.get(cookie_key)
495-
csrf_cookie = self._request.headers.get(self._access_csrf_header_name)
503+
cookie = request.cookies.get(cookie_key)
504+
if not isinstance(request, WebSocket):
505+
csrf_token = request.headers.get(self._access_csrf_header_name)
496506

497-
if (
498-
cookie and
499-
self._cookie_csrf_protect and
500-
self._request.method in self._csrf_methods and
501-
not csrf_cookie
502-
):
503-
raise CSRFError(status_code=401,message="Missing CSRF Token")
507+
if cookie and self._cookie_csrf_protect and not csrf_token:
508+
if isinstance(request, WebSocket) or request.method in self._csrf_methods:
509+
raise CSRFError(status_code=401,message="Missing CSRF Token")
504510

505511
# set token from cookie and verify jwt
506512
self._token = cookie
507513
self.verify_jwt_optional_in_request(self._token)
508514

509515
decoded_token = self.get_raw_jwt()
510516

511-
if (
512-
self._cookie_csrf_protect and
513-
self._request.method in self._csrf_methods and
514-
csrf_cookie and
515-
decoded_token
516-
):
517-
if 'csrf' not in decoded_token:
518-
raise JWTDecodeError(status_code=422,message="Missing claim: csrf")
519-
if not hmac.compare_digest(csrf_cookie,decoded_token['csrf']):
520-
raise CSRFError(status_code=401,message="CSRF double submit tokens do not match")
521-
522-
def _verify_and_get_jwt_in_cookies(
517+
if decoded_token and self._cookie_csrf_protect and csrf_token:
518+
if isinstance(request, WebSocket) or request.method in self._csrf_methods:
519+
if 'csrf' not in decoded_token:
520+
raise JWTDecodeError(status_code=422,message="Missing claim: csrf")
521+
if not hmac.compare_digest(csrf_token,decoded_token['csrf']):
522+
raise CSRFError(status_code=401,message="CSRF double submit tokens do not match")
523+
524+
def verify_and_get_jwt_in_cookies(
523525
self,
524526
type_token: str,
525-
fresh: Optional[bool] = False
527+
request: Union[Request,WebSocket],
528+
csrf_token: Optional[str] = None,
529+
fresh: Optional[bool] = False,
526530
) -> "AuthJWT":
527531
"""
528532
Check if cookies have a valid access or refresh token. if an token present in
529-
cookies property _token will set. raises exception error when an access or refresh token
533+
cookies, self._token will set. raises exception error when an access or refresh token
530534
is invalid and doesn't match with CSRF token double submit
531535
532536
:param type_token: indicate token is access or refresh token
537+
:param request: for identity get cookies from HTTP or WebSocket
538+
:param csrf_token: get csrf token if a request from WebSocket
533539
:param fresh: check freshness token if True
534540
"""
541+
if type_token not in ['access','refresh']:
542+
raise ValueError("type_token must be between 'access' or 'refresh'")
543+
if not isinstance(request,(Request,WebSocket)):
544+
raise TypeError("request must be an instance of 'Request' or 'WebSocket'")
545+
535546
if type_token == 'access':
536547
cookie_key = self._access_cookie_key
537-
cookie = self._request.cookies.get(cookie_key)
538-
csrf_cookie = self._request.headers.get(self._access_csrf_header_name)
548+
cookie = request.cookies.get(cookie_key)
549+
if not isinstance(request, WebSocket):
550+
csrf_token = request.headers.get(self._access_csrf_header_name)
539551
if type_token == 'refresh':
540552
cookie_key = self._refresh_cookie_key
541-
cookie = self._request.cookies.get(cookie_key)
542-
csrf_cookie = self._request.headers.get(self._refresh_csrf_header_name)
553+
cookie = request.cookies.get(cookie_key)
554+
if not isinstance(request, WebSocket):
555+
csrf_token = request.headers.get(self._refresh_csrf_header_name)
543556

544557
if not cookie:
545558
raise MissingCookieError(status_code=401,message="Missing cookie {}".format(cookie_key))
546559

547-
if self._cookie_csrf_protect and self._request.method in self._csrf_methods and not csrf_cookie:
548-
raise CSRFError(status_code=401,message="Missing CSRF Token")
560+
if self._cookie_csrf_protect and not csrf_token:
561+
if isinstance(request, WebSocket) or request.method in self._csrf_methods:
562+
raise CSRFError(status_code=401,message="Missing CSRF Token")
549563

550564
# set token from cookie and verify jwt
551565
self._token = cookie
552566
self.verify_jwt_in_request(self._token,type_token,'cookies',fresh)
553567

554568
decoded_token = self.get_raw_jwt()
555569

556-
if self._cookie_csrf_protect and self._request.method in self._csrf_methods and csrf_cookie:
557-
if 'csrf' not in decoded_token:
558-
raise JWTDecodeError(status_code=422,message="Missing claim: csrf")
559-
if not hmac.compare_digest(csrf_cookie,decoded_token['csrf']):
560-
raise CSRFError(status_code=401,message="CSRF double submit tokens do not match")
570+
if self._cookie_csrf_protect and csrf_token:
571+
if isinstance(request, WebSocket) or request.method in self._csrf_methods:
572+
if 'csrf' not in decoded_token:
573+
raise JWTDecodeError(status_code=422,message="Missing claim: csrf")
574+
if not hmac.compare_digest(csrf_token,decoded_token['csrf']):
575+
raise CSRFError(status_code=401,message="CSRF double submit tokens do not match")
561576

562577
def verify_jwt_optional_in_request(self,token: str) -> None:
563578
"""
564579
Optionally check if this request has a valid access token
565580
566581
:param token: The encoded JWT
567582
"""
568-
if token:
569-
self._verifying_token(token)
583+
if token: self._verifying_token(token)
570584

571585
if token and self.get_raw_jwt(token)['type'] != 'access':
572586
raise AccessTokenRequired(status_code=422,message="Only access tokens are allowed")
@@ -583,17 +597,20 @@ def verify_jwt_in_request(
583597
584598
:param token: The encoded JWT
585599
:param type_token: indicate token is access or refresh token
586-
:param token_from: indicate token from headers or cookies
600+
:param token_from: indicate token from headers cookies, websocket
587601
:param fresh: check freshness token if True
588602
"""
589-
issuer = self._decode_issuer if type_token == 'access' else None
590-
591-
if token:
592-
self._verifying_token(token,issuer)
603+
if type_token not in ['access','refresh']:
604+
raise ValueError("type_token must be between 'access' or 'refresh'")
605+
if token_from not in ['headers','cookies','websocket']:
606+
raise ValueError("token_from must be between 'headers', 'cookies', 'websocket'")
593607

594608
if not token and token_from == 'headers':
595609
raise MissingHeaderError(status_code=401,message="Missing {} Header".format(self._header_name))
596610

611+
issuer = self._decode_issuer if type_token == 'access' else None
612+
self._verifying_token(token,issuer)
613+
597614
if self.get_raw_jwt(token)['type'] != type_token:
598615
msg = "Only {} tokens are allowed".format(type_token)
599616
if type_token == 'access':
@@ -656,12 +673,12 @@ def jwt_required(self) -> None:
656673
if self._token and self.jwt_in_headers:
657674
self.verify_jwt_in_request(self._token,'access','headers')
658675
if not self._token and self.jwt_in_cookies:
659-
self._verify_and_get_jwt_in_cookies('access')
676+
self.verify_and_get_jwt_in_cookies('access',self._request)
660677
else:
661678
if self.jwt_in_headers:
662679
self.verify_jwt_in_request(self._token,'access','headers')
663680
if self.jwt_in_cookies:
664-
self._verify_and_get_jwt_in_cookies('access')
681+
self.verify_and_get_jwt_in_cookies('access',self._request)
665682

666683
def jwt_optional(self) -> None:
667684
"""
@@ -673,12 +690,12 @@ def jwt_optional(self) -> None:
673690
if self._token and self.jwt_in_headers:
674691
self.verify_jwt_optional_in_request(self._token)
675692
if not self._token and self.jwt_in_cookies:
676-
self._verify_and_get_jwt_optional_in_cookies()
693+
self.verify_and_get_jwt_optional_in_cookies(self._request)
677694
else:
678695
if self.jwt_in_headers:
679696
self.verify_jwt_optional_in_request(self._token)
680697
if self.jwt_in_cookies:
681-
self._verify_and_get_jwt_optional_in_cookies()
698+
self.verify_and_get_jwt_optional_in_cookies(self._request)
682699

683700
def jwt_refresh_token_required(self) -> None:
684701
"""
@@ -688,12 +705,12 @@ def jwt_refresh_token_required(self) -> None:
688705
if self._token and self.jwt_in_headers:
689706
self.verify_jwt_in_request(self._token,'refresh','headers')
690707
if not self._token and self.jwt_in_cookies:
691-
self._verify_and_get_jwt_in_cookies('refresh')
708+
self.verify_and_get_jwt_in_cookies('refresh',self._request)
692709
else:
693710
if self.jwt_in_headers:
694711
self.verify_jwt_in_request(self._token,'refresh','headers')
695712
if self.jwt_in_cookies:
696-
self._verify_and_get_jwt_in_cookies('refresh')
713+
self.verify_and_get_jwt_in_cookies('refresh',self._request)
697714

698715
def fresh_jwt_required(self) -> None:
699716
"""
@@ -703,12 +720,12 @@ def fresh_jwt_required(self) -> None:
703720
if self._token and self.jwt_in_headers:
704721
self.verify_jwt_in_request(self._token,'access','headers',True)
705722
if not self._token and self.jwt_in_cookies:
706-
self._verify_and_get_jwt_in_cookies('access',True)
723+
self.verify_and_get_jwt_in_cookies('access',self._request,fresh=True)
707724
else:
708725
if self.jwt_in_headers:
709726
self.verify_jwt_in_request(self._token,'access','headers',True)
710727
if self.jwt_in_cookies:
711-
self._verify_and_get_jwt_in_cookies('access',True)
728+
self.verify_and_get_jwt_in_cookies('access',self._request,fresh=True)
712729

713730
def get_raw_jwt(self,encoded_token: Optional[str] = None) -> Optional[Dict[str,Union[str,int,bool]]]:
714731
"""

0 commit comments

Comments
 (0)