1
1
import inspect
2
2
import logging
3
- from typing import Any , Awaitable , Callable , Dict , Literal , Optional
3
+ from typing import TYPE_CHECKING , Any , Awaitable , Callable , Dict , Literal , Optional
4
4
from warnings import warn
5
5
6
+ import jwt
6
7
from fastapi .exceptions import HTTPException
7
8
from fastapi .security import OAuth2AuthorizationCodeBearer , SecurityScopes
8
9
from fastapi .security .base import SecurityBase
9
- from jose import jwt
10
- from jose .exceptions import ExpiredSignatureError , JWTClaimsError , JWTError
10
+ from jwt .exceptions import (
11
+ ExpiredSignatureError ,
12
+ ImmatureSignatureError ,
13
+ InvalidAudienceError ,
14
+ InvalidIssuedAtError ,
15
+ InvalidIssuerError ,
16
+ InvalidTokenError ,
17
+ MissingRequiredClaimError ,
18
+ )
11
19
from starlette .requests import Request
12
20
13
21
from fastapi_azure_auth .exceptions import InvalidAuth
14
22
from fastapi_azure_auth .openid_config import OpenIdConfig
15
23
from fastapi_azure_auth .user import User
16
- from fastapi_azure_auth .utils import is_guest
24
+ from fastapi_azure_auth .utils import get_unverified_claims , get_unverified_header , is_guest
25
+
26
+ if TYPE_CHECKING : # pragma: no cover
27
+ from jwt .algorithms import AllowedPublicKeys
17
28
18
29
log = logging .getLogger ('fastapi_azure_auth' )
19
30
@@ -145,11 +156,13 @@ async def __call__(self, request: Request, security_scopes: SecurityScopes) -> O
145
156
Extends call to also validate the token.
146
157
"""
147
158
try :
148
- access_token = await self .oauth ( request = request )
159
+ access_token = await self .extract_access_token ( request )
149
160
try :
161
+ if access_token is None :
162
+ raise Exception ('No access token provided' )
150
163
# Extract header information of the token.
151
- header : dict [str , str ] = jwt . get_unverified_header (token = access_token ) or {}
152
- claims : dict [str , Any ] = jwt . get_unverified_claims (token = access_token ) or {}
164
+ header : dict [str , Any ] = get_unverified_header (access_token )
165
+ claims : dict [str , Any ] = get_unverified_claims (access_token )
153
166
except Exception as error :
154
167
log .warning ('Malformed token received. %s. Error: %s' , access_token , error , exc_info = True )
155
168
raise InvalidAuth (detail = 'Invalid token format' ) from error
@@ -180,48 +193,40 @@ async def __call__(self, request: Request, security_scopes: SecurityScopes) -> O
180
193
try :
181
194
if key := self .openid_config .signing_keys .get (header .get ('kid' , '' )):
182
195
# We require and validate all fields in an Azure AD token
196
+ required_claims = ['exp' , 'aud' , 'iat' , 'nbf' , 'sub' ]
197
+ if self .validate_iss :
198
+ required_claims .append ('iss' )
199
+
183
200
options = {
184
201
'verify_signature' : True ,
185
202
'verify_aud' : True ,
186
203
'verify_iat' : True ,
187
204
'verify_exp' : True ,
188
205
'verify_nbf' : True ,
189
206
'verify_iss' : self .validate_iss ,
190
- 'verify_sub' : True ,
191
- 'verify_jti' : True ,
192
- 'verify_at_hash' : True ,
193
- 'require_aud' : True ,
194
- 'require_iat' : True ,
195
- 'require_exp' : True ,
196
- 'require_nbf' : True ,
197
- 'require_iss' : self .validate_iss ,
198
- 'require_sub' : True ,
199
- 'require_jti' : False ,
200
- 'require_at_hash' : False ,
201
- 'leeway' : self .leeway ,
207
+ 'require' : required_claims ,
202
208
}
203
209
# Validate token
204
- token = jwt .decode (
205
- access_token ,
206
- key = key ,
207
- algorithms = ['RS256' ],
208
- audience = self .app_client_id if self .token_version == 2 else f'api://{ self .app_client_id } ' ,
209
- issuer = iss ,
210
- options = options ,
211
- )
210
+ token = self .validate (access_token = access_token , iss = iss , key = key , options = options )
212
211
# Attach the user to the request. Can be accessed through `request.state.user`
213
212
user : User = User (
214
213
** {** token , 'claims' : token , 'access_token' : access_token , 'is_guest' : user_is_guest }
215
214
)
216
215
request .state .user = user
217
216
return user
218
- except JWTClaimsError as error :
217
+ except (
218
+ InvalidAudienceError ,
219
+ InvalidIssuerError ,
220
+ InvalidIssuedAtError ,
221
+ ImmatureSignatureError ,
222
+ MissingRequiredClaimError ,
223
+ ) as error :
219
224
log .info ('Token contains invalid claims. %s' , error )
220
225
raise InvalidAuth (detail = 'Token contains invalid claims' ) from error
221
226
except ExpiredSignatureError as error :
222
227
log .info ('Token signature has expired. %s' , error )
223
228
raise InvalidAuth (detail = 'Token signature has expired' ) from error
224
- except JWTError as error :
229
+ except InvalidTokenError as error :
225
230
log .warning ('Invalid token. Error: %s' , error , exc_info = True )
226
231
raise InvalidAuth (detail = 'Unable to validate token' ) from error
227
232
except Exception as error :
@@ -235,6 +240,32 @@ async def __call__(self, request: Request, security_scopes: SecurityScopes) -> O
235
240
return None
236
241
raise
237
242
243
+ async def extract_access_token (self , request : Request ) -> Optional [str ]:
244
+ """
245
+ Extracts the access token from the request.
246
+ """
247
+ return await self .oauth (request = request )
248
+
249
+ def validate (
250
+ self , access_token : str , key : 'AllowedPublicKeys' , iss : str , options : Dict [str , Any ]
251
+ ) -> Dict [str , Any ]:
252
+ """
253
+ Validates the token using the provided key and options.
254
+ """
255
+ alg = 'RS256'
256
+ aud = self .app_client_id if self .token_version == 2 else f'api://{ self .app_client_id } '
257
+ return dict (
258
+ jwt .decode (
259
+ access_token ,
260
+ key = key ,
261
+ algorithms = [alg ],
262
+ audience = aud ,
263
+ issuer = iss ,
264
+ leeway = self .leeway ,
265
+ options = options ,
266
+ )
267
+ )
268
+
238
269
239
270
class SingleTenantAzureAuthorizationCodeBearer (AzureAuthorizationCodeBearerBase ):
240
271
def __init__ (
0 commit comments