diff --git a/fastapi_jwt_auth/auth_config.py b/fastapi_jwt_auth/auth_config.py index b259f2e..9fe4fa6 100644 --- a/fastapi_jwt_auth/auth_config.py +++ b/fastapi_jwt_auth/auth_config.py @@ -3,6 +3,7 @@ from typing import Callable, List from datetime import timedelta + class AuthConfig: _token = None _token_location = {'headers'} @@ -17,7 +18,7 @@ class AuthConfig: _decode_issuer = None _decode_audience = None _denylist_enabled = False - _denylist_token_checks = {'access','refresh'} + _denylist_token_checks = {'access', 'refresh'} _header_name = "Authorization" _header_type = "Bearer" _token_in_denylist_callback = None @@ -42,7 +43,7 @@ class AuthConfig: _refresh_csrf_cookie_path = "/" _access_csrf_header_name = "X-CSRF-Token" _refresh_csrf_header_name = "X-CSRF-Token" - _csrf_methods = {'POST','PUT','PATCH','DELETE'} + _csrf_methods = {'POST', 'PUT', 'PATCH', 'DELETE'} @property def jwt_in_cookies(self) -> bool: @@ -53,9 +54,9 @@ def jwt_in_headers(self) -> bool: return 'headers' in self._token_location @classmethod - def load_config(cls, settings: Callable[...,List[tuple]]) -> "AuthConfig": + def load_config(cls, settings: Callable[..., List[tuple]]) -> "AuthConfig": try: - config = LoadConfig(**{key.lower():value for key,value in settings()}) + config = LoadConfig(**{key.lower(): value for key, value in settings()}) cls._token_location = config.authjwt_token_location cls._secret_key = config.authjwt_secret_key @@ -97,7 +98,7 @@ def load_config(cls, settings: Callable[...,List[tuple]]) -> "AuthConfig": raise TypeError("Config must be pydantic 'BaseSettings' or list of tuple") @classmethod - def token_in_denylist_loader(cls, callback: Callable[...,bool]) -> "AuthConfig": + def token_in_denylist_loader(cls, callback: Callable[..., bool]) -> "AuthConfig": """ This decorator sets the callback function that will be called when a protected endpoint is accessed and will check if the JWT has been diff --git a/fastapi_jwt_auth/auth_jwt.py b/fastapi_jwt_auth/auth_jwt.py index 4110bdb..b829b94 100644 --- a/fastapi_jwt_auth/auth_jwt.py +++ b/fastapi_jwt_auth/auth_jwt.py @@ -15,8 +15,9 @@ FreshTokenRequired ) + class AuthJWT(AuthConfig): - def __init__(self,req: Request = None, res: Response = None): + def __init__(self, req: Request = None, res: Response = None): """ Get jwt header from incoming request or get request and response object if jwt in the cookie @@ -36,7 +37,7 @@ def __init__(self,req: Request = None, res: Response = None): auth = req.headers.get(self._header_name.lower()) if auth: self._get_jwt_from_headers(auth) - def _get_jwt_from_headers(self,auth: str) -> "AuthJWT": + def _get_jwt_from_headers(self, auth: str) -> "AuthJWT": """ Get token from the headers @@ -51,19 +52,19 @@ def _get_jwt_from_headers(self,auth: str) -> "AuthJWT": # : if len(parts) != 1: msg = "Bad {} header. Expected value ''".format(header_name) - raise InvalidHeaderError(status_code=422,message=msg) + raise InvalidHeaderError(status_code=422, message=msg) self._token = parts[0] else: # : - if not re.match(r"{}\s".format(header_type),auth) or len(parts) != 2: - msg = "Bad {} header. Expected value '{} '".format(header_name,header_type) - raise InvalidHeaderError(status_code=422,message=msg) + if not re.match(r"{}\s".format(header_type), auth) or len(parts) != 2: + msg = "Bad {} header. Expected value '{} '".format(header_name, header_type) + raise InvalidHeaderError(status_code=422, message=msg) self._token = parts[1] def _get_jwt_identifier(self) -> str: return str(uuid.uuid4()) - def _get_int_from_datetime(self,value: datetime) -> int: + def _get_int_from_datetime(self, value: datetime) -> int: """ :param value: datetime with or without timezone, if don't contains timezone it will managed as it is UTC @@ -82,7 +83,7 @@ def _get_secret_key(self, algorithm: str, process: str) -> str: :return: plain text or RSA depends on algorithm """ - symmetric_algorithms, asymmetric_algorithms = {"HS256","HS384","HS512"}, requires_cryptography + symmetric_algorithms, asymmetric_algorithms = {"HS256", "HS384", "HS512"}, requires_cryptography if algorithm not in symmetric_algorithms and algorithm not in asymmetric_algorithms: raise ValueError("Algorithm {} could not be found".format(algorithm)) @@ -117,16 +118,16 @@ def _get_secret_key(self, algorithm: str, process: str) -> str: return self._public_key def _create_token( - self, - subject: Union[str,int], - type_token: str, - exp_time: Optional[int], - fresh: Optional[bool] = False, - algorithm: Optional[str] = None, - headers: Optional[Dict] = None, - issuer: Optional[str] = None, - audience: Optional[Union[str,Sequence[str]]] = None, - user_claims: Optional[Dict] = {} + self, + subject: Union[str, int], + type_token: str, + exp_time: Optional[int], + fresh: Optional[bool] = False, + algorithm: Optional[str] = None, + headers: Optional[Dict] = None, + issuer: Optional[str] = None, + audience: Optional[Union[str, Sequence[str]]] = None, + user_claims: Optional[Dict] = {} ) -> str: """ Create token for access_token and refresh_token (utf-8) @@ -144,7 +145,7 @@ def _create_token( :return: Encoded token """ # Validation type data - if not isinstance(subject, (str,int)): + if not isinstance(subject, (str, int)): raise TypeError("subject must be a string or integer") if not isinstance(fresh, bool): raise TypeError("fresh must be a boolean") @@ -182,7 +183,7 @@ def _create_token( algorithm = algorithm or self._algorithm try: - secret_key = self._get_secret_key(algorithm,"encode") + secret_key = self._get_secret_key(algorithm, "encode") except Exception: raise @@ -193,13 +194,14 @@ def _create_token( headers=headers ).decode('utf-8') + def _has_token_in_denylist_callback(self) -> bool: """ Return True if token denylist callback set """ return self._token_in_denylist_callback is not None - def _check_token_is_revoked(self, raw_token: Dict[str,Union[str,int,bool]]) -> None: + def _check_token_is_revoked(self, raw_token: Dict[str, Union[str, int, bool]]) -> None: """ Ensure that AUTHJWT_DENYLIST_ENABLED is true and callback regulated, and then call function denylist callback with passing decode JWT, if true @@ -210,17 +212,17 @@ def _check_token_is_revoked(self, raw_token: Dict[str,Union[str,int,bool]]) -> N if not self._has_token_in_denylist_callback(): raise RuntimeError("A token_in_denylist_callback must be provided via " - "the '@AuthJWT.token_in_denylist_loader' if " - "authjwt_denylist_enabled is 'True'") + "the '@AuthJWT.token_in_denylist_loader' if " + "authjwt_denylist_enabled is 'True'") if self._token_in_denylist_callback.__func__(raw_token): - raise RevokedTokenError(status_code=401,message="Token has been revoked") + raise RevokedTokenError(status_code=401, message="Token has been revoked") def _get_expired_time( - self, - type_token: str, - expires_time: Optional[Union[timedelta,int,bool]] = None - ) -> Union[None,int]: + self, + type_token: str, + expires_time: Optional[Union[timedelta, int, bool]] = None + ) -> Union[None, int]: """ Dynamic token expired, if expires_time is False exp claim not created @@ -229,7 +231,7 @@ def _get_expired_time( :return: duration exp claim jwt """ - if expires_time and not isinstance(expires_time, (timedelta,int,bool)): + if expires_time and not isinstance(expires_time, (timedelta, int, bool)): raise TypeError("expires_time must be between timedelta, int, bool") if expires_time is not False: @@ -252,14 +254,14 @@ def _get_expired_time( return None def create_access_token( - self, - subject: Union[str,int], - fresh: Optional[bool] = False, - algorithm: Optional[str] = None, - headers: Optional[Dict] = None, - expires_time: Optional[Union[timedelta,int,bool]] = None, - audience: Optional[Union[str,Sequence[str]]] = None, - user_claims: Optional[Dict] = {} + self, + subject: Union[str, int], + fresh: Optional[bool] = False, + algorithm: Optional[str] = None, + headers: Optional[Dict] = None, + expires_time: Optional[Union[timedelta, int, bool]] = None, + audience: Optional[Union[str, Sequence[str]]] = None, + user_claims: Optional[Dict] = {} ) -> str: """ Create a access token with 15 minutes for expired time (default), @@ -270,7 +272,7 @@ def create_access_token( return self._create_token( subject=subject, type_token="access", - exp_time=self._get_expired_time("access",expires_time), + exp_time=self._get_expired_time("access", expires_time), fresh=fresh, algorithm=algorithm, headers=headers, @@ -280,13 +282,13 @@ def create_access_token( ) def create_refresh_token( - self, - subject: Union[str,int], - algorithm: Optional[str] = None, - headers: Optional[Dict] = None, - expires_time: Optional[Union[timedelta,int,bool]] = None, - audience: Optional[Union[str,Sequence[str]]] = None, - user_claims: Optional[Dict] = {} + self, + subject: Union[str, int], + algorithm: Optional[str] = None, + headers: Optional[Dict] = None, + expires_time: Optional[Union[timedelta, int, bool]] = None, + audience: Optional[Union[str, Sequence[str]]] = None, + user_claims: Optional[Dict] = {} ) -> str: """ Create a refresh token with 30 days for expired time (default), @@ -297,14 +299,14 @@ def create_refresh_token( return self._create_token( subject=subject, type_token="refresh", - exp_time=self._get_expired_time("refresh",expires_time), + exp_time=self._get_expired_time("refresh", expires_time), algorithm=algorithm, headers=headers, audience=audience, user_claims=user_claims ) - def _get_csrf_token(self,encoded_token: str) -> str: + def _get_csrf_token(self, encoded_token: str) -> str: """ Returns the CSRF double submit token from an encoded JWT. @@ -314,10 +316,10 @@ def _get_csrf_token(self,encoded_token: str) -> str: return self._verified_token(encoded_token)['csrf'] def set_access_cookies( - self, - encoded_access_token: str, - response: Optional[Response] = None, - max_age: Optional[int] = None + self, + encoded_access_token: str, + response: Optional[Response] = None, + max_age: Optional[int] = None ) -> None: """ Configures the response to set access token in a cookie. @@ -332,9 +334,9 @@ def set_access_cookies( "set_access_cookies() called without 'authjwt_token_location' configured to use cookies" ) - if max_age and not isinstance(max_age,int): + if max_age and not isinstance(max_age, int): raise TypeError("max_age must be a integer") - if response and not isinstance(response,Response): + if response and not isinstance(response, Response): raise TypeError("The response must be an object response FastAPI") response = response or self._response @@ -365,10 +367,10 @@ def set_access_cookies( ) def set_refresh_cookies( - self, - encoded_refresh_token: str, - response: Optional[Response] = None, - max_age: Optional[int] = None + self, + encoded_refresh_token: str, + response: Optional[Response] = None, + max_age: Optional[int] = None ) -> None: """ Configures the response to set refresh token in a cookie. @@ -383,9 +385,9 @@ def set_refresh_cookies( "set_refresh_cookies() called without 'authjwt_token_location' configured to use cookies" ) - if max_age and not isinstance(max_age,int): + if max_age and not isinstance(max_age, int): raise TypeError("max_age must be a integer") - if response and not isinstance(response,Response): + if response and not isinstance(response, Response): raise TypeError("The response must be an object response FastAPI") response = response or self._response @@ -415,7 +417,7 @@ def set_refresh_cookies( samesite=self._cookie_samesite ) - def unset_jwt_cookies(self,response: Optional[Response] = None) -> None: + def unset_jwt_cookies(self, response: Optional[Response] = None) -> None: """ Unset (delete) all jwt stored in a cookie @@ -424,7 +426,7 @@ def unset_jwt_cookies(self,response: Optional[Response] = None) -> None: self.unset_access_cookies(response) self.unset_refresh_cookies(response) - def unset_access_cookies(self,response: Optional[Response] = None) -> None: + def unset_access_cookies(self, response: Optional[Response] = None) -> None: """ Remove access token and access CSRF double submit from the response cookies @@ -435,7 +437,7 @@ def unset_access_cookies(self,response: Optional[Response] = None) -> None: "unset_access_cookies() called without 'authjwt_token_location' configured to use cookies" ) - if response and not isinstance(response,Response): + if response and not isinstance(response, Response): raise TypeError("The response must be an object response FastAPI") response = response or self._response @@ -453,7 +455,7 @@ def unset_access_cookies(self,response: Optional[Response] = None) -> None: domain=self._cookie_domain ) - def unset_refresh_cookies(self,response: Optional[Response] = None) -> None: + def unset_refresh_cookies(self, response: Optional[Response] = None) -> None: """ Remove refresh token and refresh CSRF double submit from the response cookies @@ -464,7 +466,7 @@ def unset_refresh_cookies(self,response: Optional[Response] = None) -> None: "unset_refresh_cookies() called without 'authjwt_token_location' configured to use cookies" ) - if response and not isinstance(response,Response): + if response and not isinstance(response, Response): raise TypeError("The response must be an object response FastAPI") response = response or self._response @@ -483,9 +485,9 @@ def unset_refresh_cookies(self,response: Optional[Response] = None) -> None: ) def _verify_and_get_jwt_optional_in_cookies( - self, - request: Union[Request,WebSocket], - csrf_token: Optional[str] = None, + self, + request: Union[Request, WebSocket], + csrf_token: Optional[str] = None, ) -> "AuthJWT": """ Optionally check if cookies have a valid access token. if an access token present in @@ -495,7 +497,7 @@ def _verify_and_get_jwt_optional_in_cookies( :param request: for identity get cookies from HTTP or WebSocket :param csrf_token: the CSRF double submit token """ - if not isinstance(request,(Request,WebSocket)): + if not isinstance(request, (Request, WebSocket)): raise TypeError("request must be an instance of 'Request' or 'WebSocket'") cookie_key = self._access_cookie_key @@ -505,7 +507,7 @@ def _verify_and_get_jwt_optional_in_cookies( if cookie and self._cookie_csrf_protect and not csrf_token: if isinstance(request, WebSocket) or request.method in self._csrf_methods: - raise CSRFError(status_code=401,message="Missing CSRF Token") + raise CSRFError(status_code=401, message="Missing CSRF Token") # set token from cookie and verify jwt self._token = cookie @@ -516,16 +518,16 @@ def _verify_and_get_jwt_optional_in_cookies( if decoded_token and self._cookie_csrf_protect and csrf_token: if isinstance(request, WebSocket) or request.method in self._csrf_methods: if 'csrf' not in decoded_token: - raise JWTDecodeError(status_code=422,message="Missing claim: csrf") - if not hmac.compare_digest(csrf_token,decoded_token['csrf']): - raise CSRFError(status_code=401,message="CSRF double submit tokens do not match") + raise JWTDecodeError(status_code=422, message="Missing claim: csrf") + if not hmac.compare_digest(csrf_token, decoded_token['csrf']): + raise CSRFError(status_code=401, message="CSRF double submit tokens do not match") def _verify_and_get_jwt_in_cookies( - self, - type_token: str, - request: Union[Request,WebSocket], - csrf_token: Optional[str] = None, - fresh: Optional[bool] = False, + self, + type_token: str, + request: Union[Request, WebSocket], + csrf_token: Optional[str] = None, + fresh: Optional[bool] = False, ) -> "AuthJWT": """ Check if cookies have a valid access or refresh token. if an token present in @@ -537,9 +539,9 @@ def _verify_and_get_jwt_in_cookies( :param csrf_token: the CSRF double submit token :param fresh: check freshness token if True """ - if type_token not in ['access','refresh']: + if type_token not in ['access', 'refresh']: raise ValueError("type_token must be between 'access' or 'refresh'") - if not isinstance(request,(Request,WebSocket)): + if not isinstance(request, (Request, WebSocket)): raise TypeError("request must be an instance of 'Request' or 'WebSocket'") if type_token == 'access': @@ -554,26 +556,26 @@ def _verify_and_get_jwt_in_cookies( csrf_token = request.headers.get(self._refresh_csrf_header_name) if not cookie: - raise MissingTokenError(status_code=401,message="Missing cookie {}".format(cookie_key)) + raise MissingTokenError(status_code=401, message="Missing cookie {}".format(cookie_key)) if self._cookie_csrf_protect and not csrf_token: if isinstance(request, WebSocket) or request.method in self._csrf_methods: - raise CSRFError(status_code=401,message="Missing CSRF Token") + raise CSRFError(status_code=401, message="Missing CSRF Token") # set token from cookie and verify jwt self._token = cookie - self._verify_jwt_in_request(self._token,type_token,'cookies',fresh) + self._verify_jwt_in_request(self._token, type_token, 'cookies', fresh) decoded_token = self.get_raw_jwt() if self._cookie_csrf_protect and csrf_token: if isinstance(request, WebSocket) or request.method in self._csrf_methods: if 'csrf' not in decoded_token: - raise JWTDecodeError(status_code=422,message="Missing claim: csrf") - if not hmac.compare_digest(csrf_token,decoded_token['csrf']): - raise CSRFError(status_code=401,message="CSRF double submit tokens do not match") + raise JWTDecodeError(status_code=422, message="Missing claim: csrf") + if not hmac.compare_digest(csrf_token, decoded_token['csrf']): + raise CSRFError(status_code=401, message="CSRF double submit tokens do not match") - def _verify_jwt_optional_in_request(self,token: str) -> None: + def _verify_jwt_optional_in_request(self, token: str) -> None: """ Optionally check if this request has a valid access token @@ -582,14 +584,14 @@ def _verify_jwt_optional_in_request(self,token: str) -> None: if token: self._verifying_token(token) if token and self.get_raw_jwt(token)['type'] != 'access': - raise AccessTokenRequired(status_code=422,message="Only access tokens are allowed") + raise AccessTokenRequired(status_code=422, message="Only access tokens are allowed") def _verify_jwt_in_request( - self, - token: str, - type_token: str, - token_from: str, - fresh: Optional[bool] = False + self, + token: str, + type_token: str, + token_from: str, + fresh: Optional[bool] = False ) -> None: """ Ensure that the requester has a valid token. this also check the freshness of the access token @@ -599,43 +601,44 @@ def _verify_jwt_in_request( :param token_from: indicate token from headers cookies, websocket :param fresh: check freshness token if True """ - if type_token not in ['access','refresh']: + if type_token not in ['access', 'refresh']: raise ValueError("type_token must be between 'access' or 'refresh'") - if token_from not in ['headers','cookies','websocket']: + if token_from not in ['headers', 'cookies', 'websocket']: raise ValueError("token_from must be between 'headers', 'cookies', 'websocket'") if not token: if token_from == 'headers': - raise MissingTokenError(status_code=401,message="Missing {} Header".format(self._header_name)) + raise MissingTokenError(status_code=401, message="Missing {} Header".format(self._header_name)) if token_from == 'websocket': - raise MissingTokenError(status_code=1008,message="Missing {} token from Query or Path".format(type_token)) + raise MissingTokenError(status_code=1008, + message="Missing {} token from Query or Path".format(type_token)) # verify jwt issuer = self._decode_issuer if type_token == 'access' else None - self._verifying_token(token,issuer) + self._verifying_token(token, issuer) if self.get_raw_jwt(token)['type'] != type_token: msg = "Only {} tokens are allowed".format(type_token) if type_token == 'access': - raise AccessTokenRequired(status_code=422,message=msg) + raise AccessTokenRequired(status_code=422, message=msg) if type_token == 'refresh': - raise RefreshTokenRequired(status_code=422,message=msg) + raise RefreshTokenRequired(status_code=422, message=msg) if fresh and not self.get_raw_jwt(token)['fresh']: - raise FreshTokenRequired(status_code=401,message="Fresh token required") + raise FreshTokenRequired(status_code=401, message="Fresh token required") - def _verifying_token(self,encoded_token: str, issuer: Optional[str] = None) -> None: + def _verifying_token(self, encoded_token: str, issuer: Optional[str] = None) -> None: """ Verified token and check if token is revoked :param encoded_token: token hash :param issuer: expected issuer in the JWT """ - raw_token = self._verified_token(encoded_token,issuer) + raw_token = self._verified_token(encoded_token, issuer) if raw_token['type'] in self._denylist_token_checks: self._check_token_is_revoked(raw_token) - def _verified_token(self,encoded_token: str, issuer: Optional[str] = None) -> Dict[str,Union[str,int,bool]]: + def _verified_token(self, encoded_token: str, issuer: Optional[str] = None) -> Dict[str, Union[str, int, bool]]: """ Verified token and catch all error from jwt package and return decode token @@ -649,10 +652,10 @@ def _verified_token(self,encoded_token: str, issuer: Optional[str] = None) -> Di try: unverified_headers = self.get_unverified_jwt_headers(encoded_token) except Exception as err: - raise InvalidHeaderError(status_code=422,message=str(err)) + raise InvalidHeaderError(status_code=422, message=str(err)) try: - secret_key = self._get_secret_key(unverified_headers['alg'],"decode") + secret_key = self._get_secret_key(unverified_headers['alg'], "decode") except Exception: raise @@ -666,14 +669,14 @@ def _verified_token(self,encoded_token: str, issuer: Optional[str] = None) -> Di algorithms=algorithms ) except Exception as err: - raise JWTDecodeError(status_code=422,message=str(err)) + raise JWTDecodeError(status_code=422, message=str(err)) def jwt_required( - self, - auth_from: str = "request", - token: Optional[str] = None, - websocket: Optional[WebSocket] = None, - csrf_token: Optional[str] = None, + self, + auth_from: str = "request", + token: Optional[str] = None, + websocket: Optional[WebSocket] = None, + csrf_token: Optional[str] = None, ) -> None: """ Only access token can access this function @@ -686,27 +689,29 @@ def jwt_required( its must be passing csrf_token manually and can achieve by Query Url or Path """ if auth_from == "websocket": - if websocket: self._verify_and_get_jwt_in_cookies('access',websocket,csrf_token) - else: self._verify_jwt_in_request(token,'access','websocket') + if websocket: + self._verify_and_get_jwt_in_cookies('access', websocket, csrf_token) + else: + self._verify_jwt_in_request(token, 'access', 'websocket') if auth_from == "request": if len(self._token_location) == 2: if self._token and self.jwt_in_headers: - self._verify_jwt_in_request(self._token,'access','headers') + self._verify_jwt_in_request(self._token, 'access', 'headers') if not self._token and self.jwt_in_cookies: - self._verify_and_get_jwt_in_cookies('access',self._request) + self._verify_and_get_jwt_in_cookies('access', self._request) else: if self.jwt_in_headers: - self._verify_jwt_in_request(self._token,'access','headers') + self._verify_jwt_in_request(self._token, 'access', 'headers') if self.jwt_in_cookies: - self._verify_and_get_jwt_in_cookies('access',self._request) + self._verify_and_get_jwt_in_cookies('access', self._request) def jwt_optional( - self, - auth_from: str = "request", - token: Optional[str] = None, - websocket: Optional[WebSocket] = None, - csrf_token: Optional[str] = None, + self, + auth_from: str = "request", + token: Optional[str] = None, + websocket: Optional[WebSocket] = None, + csrf_token: Optional[str] = None, ) -> None: """ If an access token in present in the request you can get data from get_raw_jwt() or get_jwt_subject(), @@ -721,8 +726,10 @@ def jwt_optional( its must be passing csrf_token manually and can achieve by Query Url or Path """ if auth_from == "websocket": - if websocket: self._verify_and_get_jwt_optional_in_cookies(websocket,csrf_token) - else: self._verify_jwt_optional_in_request(token) + if websocket: + self._verify_and_get_jwt_optional_in_cookies(websocket, csrf_token) + else: + self._verify_jwt_optional_in_request(token) if auth_from == "request": if len(self._token_location) == 2: @@ -737,11 +744,11 @@ def jwt_optional( self._verify_and_get_jwt_optional_in_cookies(self._request) def jwt_refresh_token_required( - self, - auth_from: str = "request", - token: Optional[str] = None, - websocket: Optional[WebSocket] = None, - csrf_token: Optional[str] = None, + self, + auth_from: str = "request", + token: Optional[str] = None, + websocket: Optional[WebSocket] = None, + csrf_token: Optional[str] = None, ) -> None: """ This function will ensure that the requester has a valid refresh token @@ -754,27 +761,29 @@ def jwt_refresh_token_required( its must be passing csrf_token manually and can achieve by Query Url or Path """ if auth_from == "websocket": - if websocket: self._verify_and_get_jwt_in_cookies('refresh',websocket,csrf_token) - else: self._verify_jwt_in_request(token,'refresh','websocket') + if websocket: + self._verify_and_get_jwt_in_cookies('refresh', websocket, csrf_token) + else: + self._verify_jwt_in_request(token, 'refresh', 'websocket') if auth_from == "request": if len(self._token_location) == 2: if self._token and self.jwt_in_headers: - self._verify_jwt_in_request(self._token,'refresh','headers') + self._verify_jwt_in_request(self._token, 'refresh', 'headers') if not self._token and self.jwt_in_cookies: - self._verify_and_get_jwt_in_cookies('refresh',self._request) + self._verify_and_get_jwt_in_cookies('refresh', self._request) else: if self.jwt_in_headers: - self._verify_jwt_in_request(self._token,'refresh','headers') + self._verify_jwt_in_request(self._token, 'refresh', 'headers') if self.jwt_in_cookies: - self._verify_and_get_jwt_in_cookies('refresh',self._request) + self._verify_and_get_jwt_in_cookies('refresh', self._request) def fresh_jwt_required( - self, - auth_from: str = "request", - token: Optional[str] = None, - websocket: Optional[WebSocket] = None, - csrf_token: Optional[str] = None, + self, + auth_from: str = "request", + token: Optional[str] = None, + websocket: Optional[WebSocket] = None, + csrf_token: Optional[str] = None, ) -> None: """ This function will ensure that the requester has a valid access token and fresh token @@ -787,22 +796,24 @@ def fresh_jwt_required( its must be passing csrf_token manually and can achieve by Query Url or Path """ if auth_from == "websocket": - if websocket: self._verify_and_get_jwt_in_cookies('access',websocket,csrf_token,True) - else: self._verify_jwt_in_request(token,'access','websocket',True) + if websocket: + self._verify_and_get_jwt_in_cookies('access', websocket, csrf_token, True) + else: + self._verify_jwt_in_request(token, 'access', 'websocket', True) if auth_from == "request": if len(self._token_location) == 2: if self._token and self.jwt_in_headers: - self._verify_jwt_in_request(self._token,'access','headers',True) + self._verify_jwt_in_request(self._token, 'access', 'headers', True) if not self._token and self.jwt_in_cookies: - self._verify_and_get_jwt_in_cookies('access',self._request,fresh=True) + self._verify_and_get_jwt_in_cookies('access', self._request, fresh=True) else: if self.jwt_in_headers: - self._verify_jwt_in_request(self._token,'access','headers',True) + self._verify_jwt_in_request(self._token, 'access', 'headers', True) if self.jwt_in_cookies: - self._verify_and_get_jwt_in_cookies('access',self._request,fresh=True) + self._verify_and_get_jwt_in_cookies('access', self._request, fresh=True) - def get_raw_jwt(self,encoded_token: Optional[str] = None) -> Optional[Dict[str,Union[str,int,bool]]]: + def get_raw_jwt(self, encoded_token: Optional[str] = None) -> Optional[Dict[str, Union[str, int, bool]]]: """ this will return the python dictionary which has all of the claims of the JWT that is accessing the endpoint. If no JWT is currently present, return None instead @@ -816,7 +827,7 @@ def get_raw_jwt(self,encoded_token: Optional[str] = None) -> Optional[Dict[str,U return self._verified_token(token) return None - def get_jti(self,encoded_token: str) -> str: + def get_jti(self, encoded_token: str) -> str: """ Returns the JTI (unique identifier) of an encoded JWT @@ -825,7 +836,7 @@ def get_jti(self,encoded_token: str) -> str: """ return self._verified_token(encoded_token)['jti'] - def get_jwt_subject(self) -> Optional[Union[str,int]]: + def get_jwt_subject(self) -> Optional[Union[str, int]]: """ this will return the subject of the JWT that is accessing this endpoint. If no JWT is present, `None` is returned instead. @@ -836,7 +847,7 @@ def get_jwt_subject(self) -> Optional[Union[str,int]]: return self._verified_token(self._token)['sub'] return None - def get_unverified_jwt_headers(self,encoded_token: Optional[str] = None) -> dict: + def get_unverified_jwt_headers(self, encoded_token: Optional[str] = None) -> dict: """ Returns the Headers of an encoded JWT without verifying the actual signature of JWT diff --git a/fastapi_jwt_auth/config.py b/fastapi_jwt_auth/config.py index c81b50c..2fd4eb3 100644 --- a/fastapi_jwt_auth/config.py +++ b/fastapi_jwt_auth/config.py @@ -2,12 +2,12 @@ from typing import Optional, Union, Sequence, List from pydantic import ( BaseModel, - validator, StrictBool, StrictInt, - StrictStr + StrictStr, field_validator ) + class LoadConfig(BaseModel): authjwt_token_location: Optional[Sequence[StrictStr]] = {'headers'} authjwt_secret_key: Optional[StrictStr] = None @@ -15,16 +15,16 @@ class LoadConfig(BaseModel): authjwt_private_key: Optional[StrictStr] = None authjwt_algorithm: Optional[StrictStr] = "HS256" authjwt_decode_algorithms: Optional[List[StrictStr]] = None - authjwt_decode_leeway: Optional[Union[StrictInt,timedelta]] = 0 + authjwt_decode_leeway: Optional[Union[StrictInt, timedelta]] = 0 authjwt_encode_issuer: Optional[StrictStr] = None authjwt_decode_issuer: Optional[StrictStr] = None - authjwt_decode_audience: Optional[Union[StrictStr,Sequence[StrictStr]]] = None + authjwt_decode_audience: Optional[Union[StrictStr, Sequence[StrictStr]]] = None authjwt_denylist_enabled: Optional[StrictBool] = False - authjwt_denylist_token_checks: Optional[Sequence[StrictStr]] = {'access','refresh'} + authjwt_denylist_token_checks: Optional[Sequence[StrictStr]] = {'access', 'refresh'} authjwt_header_name: Optional[StrictStr] = "Authorization" authjwt_header_type: Optional[StrictStr] = "Bearer" - authjwt_access_token_expires: Optional[Union[StrictBool,StrictInt,timedelta]] = timedelta(minutes=15) - authjwt_refresh_token_expires: Optional[Union[StrictBool,StrictInt,timedelta]] = timedelta(days=30) + authjwt_access_token_expires: Optional[Union[StrictBool, StrictInt, timedelta]] = timedelta(minutes=15) + authjwt_refresh_token_expires: Optional[Union[StrictBool, StrictInt, timedelta]] = timedelta(days=30) # option for create cookies authjwt_access_cookie_key: Optional[StrictStr] = "access_token_cookie" authjwt_refresh_cookie_key: Optional[StrictStr] = "refresh_token_cookie" @@ -42,44 +42,48 @@ class LoadConfig(BaseModel): authjwt_refresh_csrf_cookie_path: Optional[StrictStr] = "/" authjwt_access_csrf_header_name: Optional[StrictStr] = "X-CSRF-Token" authjwt_refresh_csrf_header_name: Optional[StrictStr] = "X-CSRF-Token" - authjwt_csrf_methods: Optional[Sequence[StrictStr]] = {'POST','PUT','PATCH','DELETE'} + authjwt_csrf_methods: Optional[Sequence[StrictStr]] = {'POST', 'PUT', 'PATCH', 'DELETE'} - @validator('authjwt_access_token_expires') + @field_validator('authjwt_access_token_expires') def validate_access_token_expires(cls, v): if v is True: raise ValueError("The 'authjwt_access_token_expires' only accept value False (bool)") return v - @validator('authjwt_refresh_token_expires') + @field_validator('authjwt_refresh_token_expires') def validate_refresh_token_expires(cls, v): if v is True: raise ValueError("The 'authjwt_refresh_token_expires' only accept value False (bool)") return v - @validator('authjwt_denylist_token_checks', each_item=True) + @field_validator('authjwt_denylist_token_checks') def validate_denylist_token_checks(cls, v): - if v not in ['access','refresh']: - raise ValueError("The 'authjwt_denylist_token_checks' must be between 'access' or 'refresh'") + for item in v: + if item not in ['access', 'refresh']: + raise ValueError("The 'authjwt_denylist_token_checks' must be between 'access' or 'refresh'") return v - @validator('authjwt_token_location', each_item=True) + @field_validator('authjwt_token_location') def validate_token_location(cls, v): - if v not in ['headers','cookies']: - raise ValueError("The 'authjwt_token_location' must be between 'headers' or 'cookies'") + for location in v: + if location not in ['headers', 'cookies']: + raise ValueError("The 'authjwt_token_location' must be between 'headers' or 'cookies'") return v - @validator('authjwt_cookie_samesite') + @field_validator('authjwt_cookie_samesite') def validate_cookie_samesite(cls, v): - if v not in ['strict','lax','none']: - raise ValueError("The 'authjwt_cookie_samesite' must be between 'strict', 'lax', 'none'") + for item in v: + if item not in ['strict', 'lax', 'none']: + raise ValueError("The 'authjwt_cookie_samesite' must be between 'strict', 'lax', 'none'") return v - @validator('authjwt_csrf_methods', each_item=True) + @field_validator('authjwt_csrf_methods') def validate_csrf_methods(cls, v): - if v.upper() not in {"GET", "HEAD", "POST", "PUT", "DELETE", "PATCH"}: - raise ValueError("The 'authjwt_csrf_methods' must be between http request methods") + for item in v: + if item.upper() not in {"GET", "HEAD", "POST", "PUT", "DELETE", "PATCH"}: + raise ValueError("The 'authjwt_csrf_methods' must be between http request methods") return v.upper() class Config: - min_anystr_length = 1 - anystr_strip_whitespace = True + str_min_length = 1 + str_strip_whitespace = True diff --git a/fastapi_jwt_auth/exceptions.py b/fastapi_jwt_auth/exceptions.py index 590423c..99f4a84 100644 --- a/fastapi_jwt_auth/exceptions.py +++ b/fastapi_jwt_auth/exceptions.py @@ -4,69 +4,85 @@ class AuthJWTException(Exception): """ pass + class InvalidHeaderError(AuthJWTException): """ An error getting jwt in header or jwt header information from a request """ - def __init__(self,status_code: int, message: str): + + def __init__(self, status_code: int, message: str): self.status_code = status_code self.message = message + class JWTDecodeError(AuthJWTException): """ An error decoding a JWT """ - def __init__(self,status_code: int, message: str): + + def __init__(self, status_code: int, message: str): self.status_code = status_code self.message = message + class CSRFError(AuthJWTException): """ An error with CSRF protection """ - def __init__(self,status_code: int, message: str): + + def __init__(self, status_code: int, message: str): self.status_code = status_code self.message = message + class MissingTokenError(AuthJWTException): """ Error raised when token not found """ - def __init__(self,status_code: int, message: str): + + def __init__(self, status_code: int, message: str): self.status_code = status_code self.message = message + class RevokedTokenError(AuthJWTException): """ Error raised when a revoked token attempt to access a protected endpoint """ - def __init__(self,status_code: int, message: str): + + def __init__(self, status_code: int, message: str): self.status_code = status_code self.message = message + class AccessTokenRequired(AuthJWTException): """ Error raised when a valid, non-access JWT attempt to access an endpoint protected by jwt_required, jwt_optional, fresh_jwt_required """ - def __init__(self,status_code: int, message: str): + + def __init__(self, status_code: int, message: str): self.status_code = status_code self.message = message + class RefreshTokenRequired(AuthJWTException): """ Error raised when a valid, non-refresh JWT attempt to access an endpoint protected by jwt_refresh_token_required """ - def __init__(self,status_code: int, message: str): + + def __init__(self, status_code: int, message: str): self.status_code = status_code self.message = message + class FreshTokenRequired(AuthJWTException): """ Error raised when a valid, non-fresh JWT attempt to access an endpoint protected by fresh_jwt_required """ - def __init__(self,status_code: int, message: str): + + def __init__(self, status_code: int, message: str): self.status_code = status_code self.message = message