diff --git a/demo_project/api/api_v1/endpoints/graph.py b/demo_project/api/api_v1/endpoints/graph.py index 968e2368..96b47d22 100644 --- a/demo_project/api/api_v1/endpoints/graph.py +++ b/demo_project/api/api_v1/endpoints/graph.py @@ -1,11 +1,11 @@ from typing import Any import httpx +import jwt from demo_project.api.dependencies import azure_scheme from demo_project.core.config import settings from fastapi import APIRouter, Depends, Request from httpx import AsyncClient -from jose import jwt router = APIRouter() @@ -47,7 +47,7 @@ async def graph_world(request: Request) -> Any: # noqa: ANN401 # Return all the information to the end user return ( - {'claims': jwt.get_unverified_claims(token=request.state.user.access_token)} + {'claims': jwt.decode(request.state.user.access_token, options={'verify_signature': False})} | {'obo_response': obo_response.json()} | {'graph_response': graph} ) diff --git a/fastapi_azure_auth/auth.py b/fastapi_azure_auth/auth.py index 71c175c7..33db88b5 100644 --- a/fastapi_azure_auth/auth.py +++ b/fastapi_azure_auth/auth.py @@ -1,19 +1,30 @@ import inspect import logging -from typing import Any, Awaitable, Callable, Dict, Literal, Optional +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Literal, Optional from warnings import warn +import jwt from fastapi.exceptions import HTTPException from fastapi.security import OAuth2AuthorizationCodeBearer, SecurityScopes from fastapi.security.base import SecurityBase -from jose import jwt -from jose.exceptions import ExpiredSignatureError, JWTClaimsError, JWTError +from jwt.exceptions import ( + ExpiredSignatureError, + ImmatureSignatureError, + InvalidAudienceError, + InvalidIssuedAtError, + InvalidIssuerError, + InvalidTokenError, + MissingRequiredClaimError, +) from starlette.requests import Request from fastapi_azure_auth.exceptions import InvalidAuth from fastapi_azure_auth.openid_config import OpenIdConfig from fastapi_azure_auth.user import User -from fastapi_azure_auth.utils import is_guest +from fastapi_azure_auth.utils import get_unverified_claims, get_unverified_header, is_guest + +if TYPE_CHECKING: # pragma: no cover + from jwt.algorithms import AllowedPublicKeys log = logging.getLogger('fastapi_azure_auth') @@ -145,11 +156,13 @@ async def __call__(self, request: Request, security_scopes: SecurityScopes) -> O Extends call to also validate the token. """ try: - access_token = await self.oauth(request=request) + access_token = await self.extract_access_token(request) try: + if access_token is None: + raise Exception('No access token provided') # Extract header information of the token. - header: dict[str, str] = jwt.get_unverified_header(token=access_token) or {} - claims: dict[str, Any] = jwt.get_unverified_claims(token=access_token) or {} + header: dict[str, Any] = get_unverified_header(access_token) + claims: dict[str, Any] = get_unverified_claims(access_token) except Exception as error: log.warning('Malformed token received. %s. Error: %s', access_token, error, exc_info=True) raise InvalidAuth(detail='Invalid token format') from error @@ -180,6 +193,10 @@ async def __call__(self, request: Request, security_scopes: SecurityScopes) -> O try: if key := self.openid_config.signing_keys.get(header.get('kid', '')): # We require and validate all fields in an Azure AD token + required_claims = ['exp', 'aud', 'iat', 'nbf', 'sub'] + if self.validate_iss: + required_claims.append('iss') + options = { 'verify_signature': True, 'verify_aud': True, @@ -187,41 +204,29 @@ async def __call__(self, request: Request, security_scopes: SecurityScopes) -> O 'verify_exp': True, 'verify_nbf': True, 'verify_iss': self.validate_iss, - 'verify_sub': True, - 'verify_jti': True, - 'verify_at_hash': True, - 'require_aud': True, - 'require_iat': True, - 'require_exp': True, - 'require_nbf': True, - 'require_iss': self.validate_iss, - 'require_sub': True, - 'require_jti': False, - 'require_at_hash': False, - 'leeway': self.leeway, + 'require': required_claims, } # Validate token - token = jwt.decode( - access_token, - key=key, - algorithms=['RS256'], - audience=self.app_client_id if self.token_version == 2 else f'api://{self.app_client_id}', - issuer=iss, - options=options, - ) + token = self.validate(access_token=access_token, iss=iss, key=key, options=options) # Attach the user to the request. Can be accessed through `request.state.user` user: User = User( **{**token, 'claims': token, 'access_token': access_token, 'is_guest': user_is_guest} ) request.state.user = user return user - except JWTClaimsError as error: + except ( + InvalidAudienceError, + InvalidIssuerError, + InvalidIssuedAtError, + ImmatureSignatureError, + MissingRequiredClaimError, + ) as error: log.info('Token contains invalid claims. %s', error) raise InvalidAuth(detail='Token contains invalid claims') from error except ExpiredSignatureError as error: log.info('Token signature has expired. %s', error) raise InvalidAuth(detail='Token signature has expired') from error - except JWTError as error: + except InvalidTokenError as error: log.warning('Invalid token. Error: %s', error, exc_info=True) raise InvalidAuth(detail='Unable to validate token') from error except Exception as error: @@ -235,6 +240,32 @@ async def __call__(self, request: Request, security_scopes: SecurityScopes) -> O return None raise + async def extract_access_token(self, request: Request) -> Optional[str]: + """ + Extracts the access token from the request. + """ + return await self.oauth(request=request) + + def validate( + self, access_token: str, key: 'AllowedPublicKeys', iss: str, options: Dict[str, Any] + ) -> Dict[str, Any]: + """ + Validates the token using the provided key and options. + """ + alg = 'RS256' + aud = self.app_client_id if self.token_version == 2 else f'api://{self.app_client_id}' + return dict( + jwt.decode( + access_token, + key=key, + algorithms=[alg], + audience=aud, + issuer=iss, + leeway=self.leeway, + options=options, + ) + ) + class SingleTenantAzureAuthorizationCodeBearer(AzureAuthorizationCodeBearerBase): def __init__( diff --git a/fastapi_azure_auth/openid_config.py b/fastapi_azure_auth/openid_config.py index 6b29a641..2809f396 100644 --- a/fastapi_azure_auth/openid_config.py +++ b/fastapi_azure_auth/openid_config.py @@ -1,11 +1,13 @@ import logging from datetime import datetime, timedelta -from typing import Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional -from cryptography.hazmat.primitives.asymmetric.types import PublicKeyTypes as KeyTypes +import jwt from fastapi import HTTPException, status from httpx import AsyncClient -from jose import jwk + +if TYPE_CHECKING: # pragma: no cover + from jwt.algorithms import AllowedPublicKeys log = logging.getLogger('fastapi_azure_auth') @@ -27,7 +29,7 @@ def __init__( self.config_url = config_url self.authorization_endpoint: str - self.signing_keys: dict[str, KeyTypes] + self.signing_keys: dict[str, 'AllowedPublicKeys'] self.token_endpoint: str self.issuer: str @@ -98,6 +100,6 @@ def _load_keys(self, keys: List[Dict[str, Any]]) -> None: for key in keys: if key.get('use') == 'sig': # Only care about keys that are used for signatures, not encryption log.debug('Loading public key from certificate: %s', key) - cert_obj = jwk.construct(key, 'RS256') + cert_obj = jwt.PyJWK(key, 'RS256') if kid := key.get('kid'): # In case a key would not have a thumbprint we can match, we don't want it. - self.signing_keys[kid] = cert_obj + self.signing_keys[kid] = cert_obj.key diff --git a/fastapi_azure_auth/utils.py b/fastapi_azure_auth/utils.py index f189a6bf..8a9f0547 100644 --- a/fastapi_azure_auth/utils.py +++ b/fastapi_azure_auth/utils.py @@ -1,5 +1,7 @@ from typing import Any, Dict +import jwt + def is_guest(claims: Dict[str, Any]) -> bool: """ @@ -12,3 +14,17 @@ def is_guest(claims: Dict[str, Any]) -> bool: claims_iss: str = claims.get('iss', '') idp: str = claims.get('idp', claims_iss) return idp != claims_iss + + +def get_unverified_header(access_token: str) -> Dict[str, Any]: + """ + Get header from the access token without verifying the signature + """ + return dict(jwt.get_unverified_header(access_token)) + + +def get_unverified_claims(access_token: str) -> Dict[str, Any]: + """ + Get claims from the access token without verifying the signature + """ + return dict(jwt.decode(access_token, options={'verify_signature': False})) diff --git a/poetry.lock b/poetry.lock index 7781894c..4d807bec 100644 --- a/poetry.lock +++ b/poetry.lock @@ -634,24 +634,6 @@ files = [ {file = "distlib-0.3.8.tar.gz", hash = "sha256:1530ea13e350031b6312d8580ddb6b27a104275a31106523b8f123787f494f64"}, ] -[[package]] -name = "ecdsa" -version = "0.18.0" -description = "ECDSA cryptographic signature library (pure python)" -optional = false -python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" -files = [ - {file = "ecdsa-0.18.0-py2.py3-none-any.whl", hash = "sha256:80600258e7ed2f16b9aa1d7c295bd70194109ad5a30fdee0eaeefef1d4c559dd"}, - {file = "ecdsa-0.18.0.tar.gz", hash = "sha256:190348041559e21b22a1d65cee485282ca11a6f81d503fddb84d5017e9ed1e49"}, -] - -[package.dependencies] -six = ">=1.9.0" - -[package.extras] -gmpy = ["gmpy"] -gmpy2 = ["gmpy2"] - [[package]] name = "exceptiongroup" version = "1.2.0" @@ -1452,17 +1434,6 @@ files = [ [package.extras] tests = ["pytest"] -[[package]] -name = "pyasn1" -version = "0.5.1" -description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs (X.208)" -optional = false -python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" -files = [ - {file = "pyasn1-0.5.1-py2.py3-none-any.whl", hash = "sha256:4439847c58d40b1d0a573d07e3856e95333f1976294494c325775aeca506eb58"}, - {file = "pyasn1-0.5.1.tar.gz", hash = "sha256:6d391a96e59b23130a5cfa74d6fd7f388dbbe26cc8f1edf39fdddf08d9d6676c"}, -] - [[package]] name = "pycparser" version = "2.21" @@ -1618,6 +1589,23 @@ files = [ plugins = ["importlib-metadata"] windows-terminal = ["colorama (>=0.4.6)"] +[[package]] +name = "pyjwt" +version = "2.8.0" +description = "JSON Web Token implementation in Python" +optional = false +python-versions = ">=3.7" +files = [ + {file = "PyJWT-2.8.0-py3-none-any.whl", hash = "sha256:59127c392cc44c2da5bb3192169a91f429924e17aff6534d70fdc02ab3e04320"}, + {file = "PyJWT-2.8.0.tar.gz", hash = "sha256:57e28d156e3d5c10088e0c68abb90bfac3df82b40a71bd0daa20c65ccd5c23de"}, +] + +[package.extras] +crypto = ["cryptography (>=3.4.0)"] +dev = ["coverage[toml] (==5.0.4)", "cryptography (>=3.4.0)", "pre-commit", "pytest (>=6.0.0,<7.0.0)", "sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"] +docs = ["sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"] +tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"] + [[package]] name = "pytest" version = "7.4.4" @@ -1800,28 +1788,6 @@ files = [ [package.extras] cli = ["click (>=5.0)"] -[[package]] -name = "python-jose" -version = "3.3.0" -description = "JOSE implementation in Python" -optional = false -python-versions = "*" -files = [ - {file = "python-jose-3.3.0.tar.gz", hash = "sha256:55779b5e6ad599c6336191246e95eb2293a9ddebd555f796a65f838f07e5d78a"}, - {file = "python_jose-3.3.0-py2.py3-none-any.whl", hash = "sha256:9b1376b023f8b298536eedd47ae1089bcdb848f1535ab30555cd92002d78923a"}, -] - -[package.dependencies] -cryptography = {version = ">=3.4.0", optional = true, markers = "extra == \"cryptography\""} -ecdsa = "!=0.15" -pyasn1 = "*" -rsa = "*" - -[package.extras] -cryptography = ["cryptography (>=3.4.0)"] -pycrypto = ["pyasn1", "pycrypto (>=2.6.0,<2.7.0)"] -pycryptodome = ["pyasn1", "pycryptodome (>=3.3.1,<4.0.0)"] - [[package]] name = "pyyaml" version = "6.0.1" @@ -1847,6 +1813,7 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -2072,20 +2039,6 @@ files = [ {file = "rpds_py-0.18.0.tar.gz", hash = "sha256:42821446ee7a76f5d9f71f9e33a4fb2ffd724bb3e7f93386150b61a43115788d"}, ] -[[package]] -name = "rsa" -version = "4.9" -description = "Pure-Python RSA implementation" -optional = false -python-versions = ">=3.6,<4" -files = [ - {file = "rsa-4.9-py3-none-any.whl", hash = "sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7"}, - {file = "rsa-4.9.tar.gz", hash = "sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21"}, -] - -[package.dependencies] -pyasn1 = ">=0.1.3" - [[package]] name = "setuptools" version = "69.1.1" @@ -2417,4 +2370,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "862e33ab3ac6f74ec937e5263afc3ba11be46a71d1400761ec821adbe56fb22e" +content-hash = "958bce36839e94ad8f3129891ea4750c4dc01d526d4b18e0d6973621c6d81377" diff --git a/pyproject.toml b/pyproject.toml index 0f509cf2..15363665 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,8 +43,8 @@ classifiers = [ python = "^3.8" fastapi = ">0.68.0" cryptography = ">=40.0.1" -python-jose = {extras = ["cryptography"], version = "^3.3.0"} httpx = ">0.18.2" +pyjwt = "^2.8.0" [tool.poetry.group.dev.dependencies] diff --git a/tests/multi_tenant/test_multi_tenant.py b/tests/multi_tenant/test_multi_tenant.py index 75aa435f..496b704d 100644 --- a/tests/multi_tenant/test_multi_tenant.py +++ b/tests/multi_tenant/test_multi_tenant.py @@ -18,6 +18,7 @@ ) from fastapi_azure_auth import MultiTenantAzureAuthorizationCodeBearer +from fastapi_azure_auth.auth import AzureAuthorizationCodeBearerBase from fastapi_azure_auth.exceptions import InvalidAuth @@ -283,7 +284,7 @@ async def test_only_header(multi_tenant_app, mock_openid_and_keys): @pytest.mark.anyio async def test_exception_raised(multi_tenant_app, mock_openid_and_keys, mocker): - mocker.patch('fastapi_azure_auth.auth.jwt.decode', side_effect=ValueError('lol')) + mocker.patch.object(AzureAuthorizationCodeBearerBase, 'validate', side_effect=ValueError('lol')) async with AsyncClient( app=app, base_url='http://test', diff --git a/tests/multi_tenant_b2c/test_multi_tenant.py b/tests/multi_tenant_b2c/test_multi_tenant.py index 46ca5bd5..0af476f8 100644 --- a/tests/multi_tenant_b2c/test_multi_tenant.py +++ b/tests/multi_tenant_b2c/test_multi_tenant.py @@ -14,6 +14,7 @@ build_evil_access_token, ) +from fastapi_azure_auth.auth import AzureAuthorizationCodeBearerBase from fastapi_azure_auth.openid_config import OpenIdConfig @@ -120,15 +121,6 @@ async def test_no_keys_to_decode_with(multi_tenant_app, mock_openid_and_empty_ke assert response.json() == {'detail': 'Unable to verify token, no signing keys found'} -@pytest.mark.anyio -async def test_no_keys_to_decode_with(multi_tenant_app, mock_openid_and_empty_keys): - async with AsyncClient( - app=app, base_url='http://test', headers={'Authorization': 'Bearer ' + build_access_token(version=2)} - ) as ac: - response = await ac.get('api/v1/hello') - assert response.json() == {'detail': 'Unable to verify token, no signing keys found'} - - @pytest.mark.anyio async def test_normal_user_rejected(multi_tenant_app, mock_openid_and_keys): async with AsyncClient( @@ -259,7 +251,7 @@ async def test_only_header(multi_tenant_app, mock_openid_and_keys): @pytest.mark.anyio async def test_exception_raised(multi_tenant_app, mock_openid_and_keys, mocker): - mocker.patch('fastapi_azure_auth.auth.jwt.decode', side_effect=ValueError('lol')) + mocker.patch.object(AzureAuthorizationCodeBearerBase, 'validate', side_effect=ValueError('lol')) mocker.patch.object(OpenIdConfig, 'load_config', return_value=True) async with AsyncClient( app=app, diff --git a/tests/single_tenant/test_single_tenant_v1_v2_tokens.py b/tests/single_tenant/test_single_tenant_v1_v2_tokens.py index f45ffaab..0fc858b1 100644 --- a/tests/single_tenant/test_single_tenant_v1_v2_tokens.py +++ b/tests/single_tenant/test_single_tenant_v1_v2_tokens.py @@ -14,6 +14,8 @@ build_evil_access_token, ) +from fastapi_azure_auth.auth import AzureAuthorizationCodeBearerBase + def current_version(current_cases) -> int: return current_cases['single_tenant_app']['token_version'].params['version'] @@ -346,10 +348,23 @@ async def test_only_header(single_tenant_app, mock_openid_and_keys_v1_v2): assert response.json() == {'detail': 'Invalid token format'} +@pytest.mark.anyio +async def test_none_token(single_tenant_app, mock_openid_and_keys_v1_v2, mocker, current_cases): + test_version = current_version(current_cases) + mocker.patch.object(AzureAuthorizationCodeBearerBase, 'extract_access_token', return_value=None) + async with AsyncClient( + app=app, + base_url='http://test', + headers={'Authorization': 'Bearer ' + build_access_token_expired(version=test_version)}, + ) as ac: + response = await ac.get('api/v1/hello') + assert response.json() == {'detail': 'Invalid token format'} + + @pytest.mark.anyio async def test_exception_raised(single_tenant_app, mock_openid_and_keys_v1_v2, mocker, current_cases): test_version = current_version(current_cases) - mocker.patch('fastapi_azure_auth.auth.jwt.decode', side_effect=ValueError('lol')) + mocker.patch.object(AzureAuthorizationCodeBearerBase, 'validate', side_effect=ValueError('lol')) async with AsyncClient( app=app, base_url='http://test', diff --git a/tests/utils.py b/tests/utils.py index b8a1eb85..03e8a0ce 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,10 +1,10 @@ import time from typing import Optional +import jwt from cryptography.hazmat.backends import default_backend as crypto_default_backend -from cryptography.hazmat.primitives import serialization, serialization as crypto_serialization +from cryptography.hazmat.primitives import serialization as crypto_serialization from cryptography.hazmat.primitives.asymmetric import rsa -from jose import jwk, jwt def generate_private_key(): @@ -162,14 +162,10 @@ def build_openid_keys(empty_keys: bool = False, no_valid_keys: bool = False) -> 'use': 'sig', 'kid': 'dummythumbprint', 'x5t': 'dummythumbprint', - **jwk.construct( - signing_key_a.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.NoEncryption(), - ), - 'RS256', - ).to_dict(), + **jwt.algorithms.RSAAlgorithm.to_jwk( + signing_key_a, + as_dict=True, + ), } ] } @@ -180,31 +176,19 @@ def build_openid_keys(empty_keys: bool = False, no_valid_keys: bool = False) -> 'use': 'sig', 'kid': 'dummythumbprint', 'x5t': 'dummythumbprint', - **jwk.construct( - signing_key_a.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.NoEncryption(), - ), - 'RS256', - ) - .public_key() - .to_dict(), + **jwt.algorithms.RSAAlgorithm.to_jwk( + signing_key_a.public_key(), + as_dict=True, + ), }, { 'use': 'sig', 'kid': 'real thumbprint', 'x5t': 'real thumbprint', - **jwk.construct( - signing_key_b.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.NoEncryption(), - ), - 'RS256', - ) - .public_key() - .to_dict(), + **jwt.algorithms.RSAAlgorithm.to_jwk( + signing_key_b.public_key(), + as_dict=True, + ), }, ] }