Skip to content

Commit 6e67c7e

Browse files
committed
migrate from python-jose to pyjwt
1 parent 562882b commit 6e67c7e

File tree

10 files changed

+115
-131
lines changed

10 files changed

+115
-131
lines changed

demo_project/api/api_v1/endpoints/graph.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
from typing import Any
22

33
import httpx
4+
import jwt
45
from demo_project.api.dependencies import azure_scheme
56
from demo_project.core.config import settings
67
from fastapi import APIRouter, Depends, Request
78
from httpx import AsyncClient
8-
from jose import jwt
99

1010
router = APIRouter()
1111

@@ -47,7 +47,7 @@ async def graph_world(request: Request) -> Any: # noqa: ANN401
4747

4848
# Return all the information to the end user
4949
return (
50-
{'claims': jwt.get_unverified_claims(token=request.state.user.access_token)}
50+
{'claims': jwt.decode(request.state.user.access_token, options={'verify_signature': False})}
5151
| {'obo_response': obo_response.json()}
5252
| {'graph_response': graph}
5353
)

fastapi_azure_auth/auth.py

Lines changed: 49 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,30 @@
33
from typing import Any, Awaitable, Callable, Dict, Literal, Optional
44
from warnings import warn
55

6+
import jwt
67
from fastapi.exceptions import HTTPException
78
from fastapi.security import OAuth2AuthorizationCodeBearer, SecurityScopes
89
from fastapi.security.base import SecurityBase
9-
from jose import jwt
10-
from jose.exceptions import ExpiredSignatureError, JWTClaimsError, JWTError
10+
from jwt.exceptions import (
11+
ExpiredSignatureError,
12+
InvalidTokenError,
13+
InvalidAudienceError,
14+
InvalidIssuerError,
15+
InvalidIssuedAtError,
16+
ImmatureSignatureError,
17+
InvalidAlgorithmError,
18+
MissingRequiredClaimError,
19+
)
1120
from starlette.requests import Request
1221

1322
from fastapi_azure_auth.exceptions import InvalidAuth
1423
from fastapi_azure_auth.openid_config import OpenIdConfig
1524
from fastapi_azure_auth.user import User
16-
from fastapi_azure_auth.utils import is_guest
25+
from fastapi_azure_auth.utils import (
26+
is_guest,
27+
get_unverified_header,
28+
get_unverified_claims,
29+
)
1730

1831
log = logging.getLogger('fastapi_azure_auth')
1932

@@ -148,8 +161,8 @@ async def __call__(self, request: Request, security_scopes: SecurityScopes) -> O
148161
access_token = await self.oauth(request=request)
149162
try:
150163
# Extract header information of the token.
151-
header: dict[str, str] = jwt.get_unverified_header(token=access_token) or {}
152-
claims: dict[str, Any] = jwt.get_unverified_claims(token=access_token) or {}
164+
header: dict[str, Any] = get_unverified_header(access_token)
165+
claims: dict[str, Any] = get_unverified_claims(access_token)
153166
except Exception as error:
154167
log.warning('Malformed token received. %s. Error: %s', access_token, error, exc_info=True)
155168
raise InvalidAuth(detail='Invalid token format') from error
@@ -180,48 +193,44 @@ async def __call__(self, request: Request, security_scopes: SecurityScopes) -> O
180193
try:
181194
if key := self.openid_config.signing_keys.get(header.get('kid', '')):
182195
# We require and validate all fields in an Azure AD token
196+
required_claims = ['exp', 'aud', 'iat', 'nbf', 'sub']
197+
if self.validate_iss:
198+
required_claims.append('iss')
199+
183200
options = {
184201
'verify_signature': True,
185202
'verify_aud': True,
186203
'verify_iat': True,
187204
'verify_exp': True,
188205
'verify_nbf': True,
189206
'verify_iss': self.validate_iss,
190-
'verify_sub': True,
191-
'verify_jti': True,
192-
'verify_at_hash': True,
193-
'require_aud': True,
194-
'require_iat': True,
195-
'require_exp': True,
196-
'require_nbf': True,
197-
'require_iss': self.validate_iss,
198-
'require_sub': True,
199-
'require_jti': False,
200-
'require_at_hash': False,
201-
'leeway': self.leeway,
202207
}
203208
# Validate token
204-
token = jwt.decode(
205-
access_token,
209+
token = self.validate(
210+
access_token=access_token,
211+
iss=iss,
206212
key=key,
207-
algorithms=['RS256'],
208-
audience=self.app_client_id if self.token_version == 2 else f'api://{self.app_client_id}',
209-
issuer=iss,
210-
options=options,
211-
)
213+
options=options)
212214
# Attach the user to the request. Can be accessed through `request.state.user`
213215
user: User = User(
214216
**{**token, 'claims': token, 'access_token': access_token, 'is_guest': user_is_guest}
215217
)
216218
request.state.user = user
217219
return user
218-
except JWTClaimsError as error:
220+
except (
221+
InvalidAudienceError,
222+
InvalidIssuerError,
223+
InvalidIssuedAtError,
224+
ImmatureSignatureError,
225+
InvalidAlgorithmError,
226+
MissingRequiredClaimError
227+
) as error:
219228
log.info('Token contains invalid claims. %s', error)
220229
raise InvalidAuth(detail='Token contains invalid claims') from error
221230
except ExpiredSignatureError as error:
222231
log.info('Token signature has expired. %s', error)
223232
raise InvalidAuth(detail='Token signature has expired') from error
224-
except JWTError as error:
233+
except InvalidTokenError as error:
225234
log.warning('Invalid token. Error: %s', error, exc_info=True)
226235
raise InvalidAuth(detail='Unable to validate token') from error
227236
except Exception as error:
@@ -235,6 +244,20 @@ async def __call__(self, request: Request, security_scopes: SecurityScopes) -> O
235244
return None
236245
raise
237246

247+
def validate(self, access_token: str, key: str, iss: str, options: Dict[str, Any]) -> Dict[str, Any]:
248+
alg = 'RS256'
249+
aud = self.app_client_id if self.token_version == 2 else f'api://{self.app_client_id}'
250+
return jwt.decode(
251+
access_token,
252+
key=key,
253+
algorithms=[alg],
254+
audience=aud,
255+
issuer=iss,
256+
leeway=self.leeway,
257+
options=options,
258+
)
259+
260+
238261

239262
class SingleTenantAzureAuthorizationCodeBearer(AzureAuthorizationCodeBearerBase):
240263
def __init__(

fastapi_azure_auth/openid_config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
from datetime import datetime, timedelta
33
from typing import Any, Dict, List, Optional
44

5+
import jwt
56
from cryptography.hazmat.primitives.asymmetric.types import PublicKeyTypes as KeyTypes
67
from fastapi import HTTPException, status
78
from httpx import AsyncClient
8-
from jose import jwk
99

1010
log = logging.getLogger('fastapi_azure_auth')
1111

@@ -98,6 +98,6 @@ def _load_keys(self, keys: List[Dict[str, Any]]) -> None:
9898
for key in keys:
9999
if key.get('use') == 'sig': # Only care about keys that are used for signatures, not encryption
100100
log.debug('Loading public key from certificate: %s', key)
101-
cert_obj = jwk.construct(key, 'RS256')
101+
cert_obj = jwt.PyJWK(key, 'RS256')
102102
if kid := key.get('kid'): # In case a key would not have a thumbprint we can match, we don't want it.
103-
self.signing_keys[kid] = cert_obj
103+
self.signing_keys[kid] = cert_obj.key

fastapi_azure_auth/utils.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from typing import Any, Dict
22

3+
import jwt
4+
35

46
def is_guest(claims: Dict[str, Any]) -> bool:
57
"""
@@ -12,3 +14,21 @@ def is_guest(claims: Dict[str, Any]) -> bool:
1214
claims_iss: str = claims.get('iss', '')
1315
idp: str = claims.get('idp', claims_iss)
1416
return idp != claims_iss
17+
18+
19+
def get_unverified_header(access_token: str | None) -> Dict[str, Any]:
20+
"""
21+
Get header from the access token without verifying the signature
22+
"""
23+
if access_token is None:
24+
return {}
25+
return jwt.get_unverified_header(access_token)
26+
27+
28+
def get_unverified_claims(access_token: str | None) -> Dict[str, Any]:
29+
"""
30+
Get claims from the access token without verifying the signature
31+
"""
32+
if access_token is None:
33+
return {}
34+
return jwt.decode(access_token, options={'verify_signature': False})

poetry.lock

Lines changed: 19 additions & 66 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ classifiers = [
4343
python = "^3.8"
4444
fastapi = ">0.68.0"
4545
cryptography = ">=40.0.1"
46-
python-jose = {extras = ["cryptography"], version = "^3.3.0"}
4746
httpx = ">0.18.2"
47+
pyjwt = "^2.8.0"
4848

4949

5050
[tool.poetry.group.dev.dependencies]

tests/multi_tenant/test_multi_tenant.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
)
1919

2020
from fastapi_azure_auth import MultiTenantAzureAuthorizationCodeBearer
21+
from fastapi_azure_auth.auth import AzureAuthorizationCodeBearerBase
2122
from fastapi_azure_auth.exceptions import InvalidAuth
2223

2324

@@ -283,7 +284,7 @@ async def test_only_header(multi_tenant_app, mock_openid_and_keys):
283284

284285
@pytest.mark.anyio
285286
async def test_exception_raised(multi_tenant_app, mock_openid_and_keys, mocker):
286-
mocker.patch('fastapi_azure_auth.auth.jwt.decode', side_effect=ValueError('lol'))
287+
mocker.patch.object(AzureAuthorizationCodeBearerBase, 'validate', side_effect=ValueError('lol'))
287288
async with AsyncClient(
288289
app=app,
289290
base_url='http://test',

tests/multi_tenant_b2c/test_multi_tenant.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
build_evil_access_token,
1515
)
1616

17+
from fastapi_azure_auth.auth import AzureAuthorizationCodeBearerBase
1718
from fastapi_azure_auth.openid_config import OpenIdConfig
1819

1920

@@ -259,7 +260,7 @@ async def test_only_header(multi_tenant_app, mock_openid_and_keys):
259260

260261
@pytest.mark.anyio
261262
async def test_exception_raised(multi_tenant_app, mock_openid_and_keys, mocker):
262-
mocker.patch('fastapi_azure_auth.auth.jwt.decode', side_effect=ValueError('lol'))
263+
mocker.patch.object(AzureAuthorizationCodeBearerBase, 'validate', side_effect=ValueError('lol'))
263264
mocker.patch.object(OpenIdConfig, 'load_config', return_value=True)
264265
async with AsyncClient(
265266
app=app,

tests/single_tenant/test_single_tenant_v1_v2_tokens.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
build_evil_access_token,
1515
)
1616

17+
from fastapi_azure_auth.auth import AzureAuthorizationCodeBearerBase
18+
1719

1820
def current_version(current_cases) -> int:
1921
return current_cases['single_tenant_app']['token_version'].params['version']
@@ -349,7 +351,7 @@ async def test_only_header(single_tenant_app, mock_openid_and_keys_v1_v2):
349351
@pytest.mark.anyio
350352
async def test_exception_raised(single_tenant_app, mock_openid_and_keys_v1_v2, mocker, current_cases):
351353
test_version = current_version(current_cases)
352-
mocker.patch('fastapi_azure_auth.auth.jwt.decode', side_effect=ValueError('lol'))
354+
mocker.patch.object(AzureAuthorizationCodeBearerBase, 'validate', side_effect=ValueError('lol'))
353355
async with AsyncClient(
354356
app=app,
355357
base_url='http://test',

0 commit comments

Comments
 (0)