Skip to content

Commit 2f20d55

Browse files
authored
Merge pull request #194 from dvdalilue/main
2 parents 562882b + 2b1298c commit 2f20d55

File tree

10 files changed

+140
-146
lines changed

10 files changed

+140
-146
lines changed

demo_project/api/api_v1/endpoints/graph.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
from typing import Any
22

33
import httpx
4+
import jwt
45
from demo_project.api.dependencies import azure_scheme
56
from demo_project.core.config import settings
67
from fastapi import APIRouter, Depends, Request
78
from httpx import AsyncClient
8-
from jose import jwt
99

1010
router = APIRouter()
1111

@@ -47,7 +47,7 @@ async def graph_world(request: Request) -> Any: # noqa: ANN401
4747

4848
# Return all the information to the end user
4949
return (
50-
{'claims': jwt.get_unverified_claims(token=request.state.user.access_token)}
50+
{'claims': jwt.decode(request.state.user.access_token, options={'verify_signature': False})}
5151
| {'obo_response': obo_response.json()}
5252
| {'graph_response': graph}
5353
)

fastapi_azure_auth/auth.py

Lines changed: 60 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,30 @@
11
import inspect
22
import logging
3-
from typing import Any, Awaitable, Callable, Dict, Literal, Optional
3+
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Literal, Optional
44
from warnings import warn
55

6+
import jwt
67
from fastapi.exceptions import HTTPException
78
from fastapi.security import OAuth2AuthorizationCodeBearer, SecurityScopes
89
from fastapi.security.base import SecurityBase
9-
from jose import jwt
10-
from jose.exceptions import ExpiredSignatureError, JWTClaimsError, JWTError
10+
from jwt.exceptions import (
11+
ExpiredSignatureError,
12+
ImmatureSignatureError,
13+
InvalidAudienceError,
14+
InvalidIssuedAtError,
15+
InvalidIssuerError,
16+
InvalidTokenError,
17+
MissingRequiredClaimError,
18+
)
1119
from starlette.requests import Request
1220

1321
from fastapi_azure_auth.exceptions import InvalidAuth
1422
from fastapi_azure_auth.openid_config import OpenIdConfig
1523
from fastapi_azure_auth.user import User
16-
from fastapi_azure_auth.utils import is_guest
24+
from fastapi_azure_auth.utils import get_unverified_claims, get_unverified_header, is_guest
25+
26+
if TYPE_CHECKING: # pragma: no cover
27+
from jwt.algorithms import AllowedPublicKeys
1728

1829
log = logging.getLogger('fastapi_azure_auth')
1930

@@ -145,11 +156,13 @@ async def __call__(self, request: Request, security_scopes: SecurityScopes) -> O
145156
Extends call to also validate the token.
146157
"""
147158
try:
148-
access_token = await self.oauth(request=request)
159+
access_token = await self.extract_access_token(request)
149160
try:
161+
if access_token is None:
162+
raise Exception('No access token provided')
150163
# Extract header information of the token.
151-
header: dict[str, str] = jwt.get_unverified_header(token=access_token) or {}
152-
claims: dict[str, Any] = jwt.get_unverified_claims(token=access_token) or {}
164+
header: dict[str, Any] = get_unverified_header(access_token)
165+
claims: dict[str, Any] = get_unverified_claims(access_token)
153166
except Exception as error:
154167
log.warning('Malformed token received. %s. Error: %s', access_token, error, exc_info=True)
155168
raise InvalidAuth(detail='Invalid token format') from error
@@ -180,48 +193,40 @@ async def __call__(self, request: Request, security_scopes: SecurityScopes) -> O
180193
try:
181194
if key := self.openid_config.signing_keys.get(header.get('kid', '')):
182195
# We require and validate all fields in an Azure AD token
196+
required_claims = ['exp', 'aud', 'iat', 'nbf', 'sub']
197+
if self.validate_iss:
198+
required_claims.append('iss')
199+
183200
options = {
184201
'verify_signature': True,
185202
'verify_aud': True,
186203
'verify_iat': True,
187204
'verify_exp': True,
188205
'verify_nbf': True,
189206
'verify_iss': self.validate_iss,
190-
'verify_sub': True,
191-
'verify_jti': True,
192-
'verify_at_hash': True,
193-
'require_aud': True,
194-
'require_iat': True,
195-
'require_exp': True,
196-
'require_nbf': True,
197-
'require_iss': self.validate_iss,
198-
'require_sub': True,
199-
'require_jti': False,
200-
'require_at_hash': False,
201-
'leeway': self.leeway,
207+
'require': required_claims,
202208
}
203209
# Validate token
204-
token = jwt.decode(
205-
access_token,
206-
key=key,
207-
algorithms=['RS256'],
208-
audience=self.app_client_id if self.token_version == 2 else f'api://{self.app_client_id}',
209-
issuer=iss,
210-
options=options,
211-
)
210+
token = self.validate(access_token=access_token, iss=iss, key=key, options=options)
212211
# Attach the user to the request. Can be accessed through `request.state.user`
213212
user: User = User(
214213
**{**token, 'claims': token, 'access_token': access_token, 'is_guest': user_is_guest}
215214
)
216215
request.state.user = user
217216
return user
218-
except JWTClaimsError as error:
217+
except (
218+
InvalidAudienceError,
219+
InvalidIssuerError,
220+
InvalidIssuedAtError,
221+
ImmatureSignatureError,
222+
MissingRequiredClaimError,
223+
) as error:
219224
log.info('Token contains invalid claims. %s', error)
220225
raise InvalidAuth(detail='Token contains invalid claims') from error
221226
except ExpiredSignatureError as error:
222227
log.info('Token signature has expired. %s', error)
223228
raise InvalidAuth(detail='Token signature has expired') from error
224-
except JWTError as error:
229+
except InvalidTokenError as error:
225230
log.warning('Invalid token. Error: %s', error, exc_info=True)
226231
raise InvalidAuth(detail='Unable to validate token') from error
227232
except Exception as error:
@@ -235,6 +240,32 @@ async def __call__(self, request: Request, security_scopes: SecurityScopes) -> O
235240
return None
236241
raise
237242

243+
async def extract_access_token(self, request: Request) -> Optional[str]:
244+
"""
245+
Extracts the access token from the request.
246+
"""
247+
return await self.oauth(request=request)
248+
249+
def validate(
250+
self, access_token: str, key: 'AllowedPublicKeys', iss: str, options: Dict[str, Any]
251+
) -> Dict[str, Any]:
252+
"""
253+
Validates the token using the provided key and options.
254+
"""
255+
alg = 'RS256'
256+
aud = self.app_client_id if self.token_version == 2 else f'api://{self.app_client_id}'
257+
return dict(
258+
jwt.decode(
259+
access_token,
260+
key=key,
261+
algorithms=[alg],
262+
audience=aud,
263+
issuer=iss,
264+
leeway=self.leeway,
265+
options=options,
266+
)
267+
)
268+
238269

239270
class SingleTenantAzureAuthorizationCodeBearer(AzureAuthorizationCodeBearerBase):
240271
def __init__(

fastapi_azure_auth/openid_config.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import logging
22
from datetime import datetime, timedelta
3-
from typing import Any, Dict, List, Optional
3+
from typing import TYPE_CHECKING, Any, Dict, List, Optional
44

5-
from cryptography.hazmat.primitives.asymmetric.types import PublicKeyTypes as KeyTypes
5+
import jwt
66
from fastapi import HTTPException, status
77
from httpx import AsyncClient
8-
from jose import jwk
8+
9+
if TYPE_CHECKING: # pragma: no cover
10+
from jwt.algorithms import AllowedPublicKeys
911

1012
log = logging.getLogger('fastapi_azure_auth')
1113

@@ -27,7 +29,7 @@ def __init__(
2729
self.config_url = config_url
2830

2931
self.authorization_endpoint: str
30-
self.signing_keys: dict[str, KeyTypes]
32+
self.signing_keys: dict[str, 'AllowedPublicKeys']
3133
self.token_endpoint: str
3234
self.issuer: str
3335

@@ -98,6 +100,6 @@ def _load_keys(self, keys: List[Dict[str, Any]]) -> None:
98100
for key in keys:
99101
if key.get('use') == 'sig': # Only care about keys that are used for signatures, not encryption
100102
log.debug('Loading public key from certificate: %s', key)
101-
cert_obj = jwk.construct(key, 'RS256')
103+
cert_obj = jwt.PyJWK(key, 'RS256')
102104
if kid := key.get('kid'): # In case a key would not have a thumbprint we can match, we don't want it.
103-
self.signing_keys[kid] = cert_obj
105+
self.signing_keys[kid] = cert_obj.key

fastapi_azure_auth/utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from typing import Any, Dict
22

3+
import jwt
4+
35

46
def is_guest(claims: Dict[str, Any]) -> bool:
57
"""
@@ -12,3 +14,17 @@ def is_guest(claims: Dict[str, Any]) -> bool:
1214
claims_iss: str = claims.get('iss', '')
1315
idp: str = claims.get('idp', claims_iss)
1416
return idp != claims_iss
17+
18+
19+
def get_unverified_header(access_token: str) -> Dict[str, Any]:
20+
"""
21+
Get header from the access token without verifying the signature
22+
"""
23+
return dict(jwt.get_unverified_header(access_token))
24+
25+
26+
def get_unverified_claims(access_token: str) -> Dict[str, Any]:
27+
"""
28+
Get claims from the access token without verifying the signature
29+
"""
30+
return dict(jwt.decode(access_token, options={'verify_signature': False}))

poetry.lock

Lines changed: 19 additions & 66 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ classifiers = [
4343
python = "^3.8"
4444
fastapi = ">0.68.0"
4545
cryptography = ">=40.0.1"
46-
python-jose = {extras = ["cryptography"], version = "^3.3.0"}
4746
httpx = ">0.18.2"
47+
pyjwt = "^2.8.0"
4848

4949

5050
[tool.poetry.group.dev.dependencies]

tests/multi_tenant/test_multi_tenant.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
)
1919

2020
from fastapi_azure_auth import MultiTenantAzureAuthorizationCodeBearer
21+
from fastapi_azure_auth.auth import AzureAuthorizationCodeBearerBase
2122
from fastapi_azure_auth.exceptions import InvalidAuth
2223

2324

@@ -283,7 +284,7 @@ async def test_only_header(multi_tenant_app, mock_openid_and_keys):
283284

284285
@pytest.mark.anyio
285286
async def test_exception_raised(multi_tenant_app, mock_openid_and_keys, mocker):
286-
mocker.patch('fastapi_azure_auth.auth.jwt.decode', side_effect=ValueError('lol'))
287+
mocker.patch.object(AzureAuthorizationCodeBearerBase, 'validate', side_effect=ValueError('lol'))
287288
async with AsyncClient(
288289
app=app,
289290
base_url='http://test',

0 commit comments

Comments
 (0)