From 7539808c50c34b1d78ba915f103471a18df04065 Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Thu, 22 May 2025 13:19:57 +0530 Subject: [PATCH 1/8] feat: Cache google id tokens --- packages/toolbox-core/pyproject.toml | 1 + packages/toolbox-core/requirements.txt | 1 + .../src/toolbox_core/auth_methods.py | 158 ++++- .../toolbox-core/tests/test_auth_methods.py | 548 ++++++++++++------ 4 files changed, 506 insertions(+), 202 deletions(-) diff --git a/packages/toolbox-core/pyproject.toml b/packages/toolbox-core/pyproject.toml index 6a918b4e..90e593dd 100644 --- a/packages/toolbox-core/pyproject.toml +++ b/packages/toolbox-core/pyproject.toml @@ -12,6 +12,7 @@ authors = [ dependencies = [ "pydantic>=2.7.0,<3.0.0", "aiohttp>=3.8.6,<4.0.0", + "PyJWT>=2.0.0,<3.0.0", ] classifiers = [ diff --git a/packages/toolbox-core/requirements.txt b/packages/toolbox-core/requirements.txt index 3d86b0eb..f91e76e4 100644 --- a/packages/toolbox-core/requirements.txt +++ b/packages/toolbox-core/requirements.txt @@ -1,2 +1,3 @@ aiohttp==3.11.18 pydantic==2.11.4 +PyJWT==2.10.1 diff --git a/packages/toolbox-core/src/toolbox_core/auth_methods.py b/packages/toolbox-core/src/toolbox_core/auth_methods.py index 83f15774..87ab23e3 100644 --- a/packages/toolbox-core/src/toolbox_core/auth_methods.py +++ b/packages/toolbox-core/src/toolbox_core/auth_methods.py @@ -12,48 +12,113 @@ # See the License for the specific language governing permissions and # limitations under the License. -# The tokens obtained by these functions are formatted as "Bearer" tokens -# and are intended to be passed in the "Authorization" header of HTTP requests. -# -# Example User Experience: -# from toolbox_core import auth_methods -# -# auth_token_provider = auth_methods.aget_google_id_token -# toolbox = ToolboxClient( -# URL, -# client_headers={"Authorization": auth_token_provider}, -# ) -# tools = await toolbox.load_toolset() +""" +This module provides functions to obtain Google ID tokens, formatted as "Bearer" tokens, +for use in the "Authorization" header of HTTP requests. + +Example User Experience: +from toolbox_core import auth_methods +auth_token_provider = auth_methods.aget_google_id_token +toolbox = ToolboxClient( + URL, + client_headers={"Authorization": auth_token_provider}, +) +tools = await toolbox.load_toolset() +""" +import time from functools import partial +from typing import Optional, Dict, Any import google.auth from google.auth._credentials_async import Credentials +import jwt from google.auth._default_async import default_async from google.auth.transport import _aiohttp_requests from google.auth.transport.requests import AuthorizedSession, Request -async def aget_google_id_token(): +# --- Constants and Configuration --- +# Prefix for Authorization header tokens +BEARER_TOKEN_PREFIX = "Bearer " +# Margin in seconds to refresh token before its actual expiry +CACHE_REFRESH_MARGIN_SECONDS = 60 + + +# --- Global Cache Storage --- +# Stores the cached Google ID token and its expiry timestamp +_cached_google_id_token: Dict[str, Any] = {"token": None, "expires_at": 0} + + +# --- Helper Functions --- +def _decode_jwt_and_get_expiry(id_token: str) -> Optional[float]: """ - Asynchronously fetches a Google ID token. + Decodes a JWT and extracts the 'exp' (expiration) claim. - The token is formatted as a 'Bearer' token string and is suitable for use - in an HTTP Authorization header. This function uses Application Default - Credentials. + Args: + id_token: The JWT string to decode. Returns: - A string in the format "Bearer ". + The 'exp' timestamp as a float if present and decoding is successful, + otherwise None. """ - creds, _ = default_async() - await creds.refresh(_aiohttp_requests.Request()) - creds.before_request = partial(Credentials.before_request, creds) - token = creds.id_token - return f"Bearer {token}" + try: + decoded_token = jwt.decode( + id_token, options={"verify_signature": False, "verify_aud": False} + ) + return decoded_token.get("exp") + except jwt.PyJWTError: + return None -def get_google_id_token(): +def _is_cached_token_valid( + cache: Dict[str, Any], margin_seconds: int = CACHE_REFRESH_MARGIN_SECONDS +) -> bool: + """ + Checks if a token in the cache is valid (exists and not expired). + + Args: + cache: The dictionary containing 'token' and 'expires_at'. + margin_seconds: The time in seconds before expiry to consider the token invalid. + + Returns: + True if the token is valid, False otherwise. + """ + if not cache.get("token"): + return False + + expires_at = cache.get("expires_at") + if not isinstance(expires_at, (int, float)) or expires_at <= 0: + return False + + return time.time() < (expires_at - margin_seconds) + + +def _update_token_cache(cache: Dict[str, Any], new_id_token: Optional[str]): + """ + Updates the global token cache with a new token and its expiry. + + Args: + cache: The dictionary containing 'token' and 'expires_at'. + new_id_token: The new ID token string to cache. + """ + if new_id_token: + cache["token"] = new_id_token + expiry_timestamp = _decode_jwt_and_get_expiry(new_id_token) + if expiry_timestamp: + cache["expires_at"] = expiry_timestamp + else: + # If expiry can't be determined, treat as immediately expired to force refresh + cache["expires_at"] = 0 + else: + # Clear cache if no new token is provided + cache["token"] = None + cache["expires_at"] = 0 + + +# --- Public API Functions --- +def get_google_id_token() -> str: """ Synchronously fetches a Google ID token. @@ -63,10 +128,51 @@ def get_google_id_token(): Returns: A string in the format "Bearer ". + + Raises: + Exception: If fetching the Google ID token fails. """ + if _is_cached_token_valid(_cached_google_id_token): + return BEARER_TOKEN_PREFIX + _cached_google_id_token["token"] + credentials, _ = google.auth.default() session = AuthorizedSession(credentials) request = Request(session) credentials.refresh(request) - token = credentials.id_token - return f"Bearer {token}" + new_id_token = getattr(credentials, "id_token", None) + + _update_token_cache(_cached_google_id_token, new_id_token) + if new_id_token: + return BEARER_TOKEN_PREFIX + new_id_token + else: + raise Exception("Failed to fetch Google ID token.") + + +async def aget_google_id_token() -> str: + """ + Asynchronously fetches a Google ID token. + + The token is formatted as a 'Bearer' token string and is suitable for use + in an HTTP Authorization header. This function uses Application Default + Credentials. + + Returns: + A string in the format "Bearer ". + + Raises: + Exception: If fetching the Google ID token fails. + """ + if _is_cached_token_valid(_cached_google_id_token): + return BEARER_TOKEN_PREFIX + _cached_google_id_token["token"] + + credentials, _ = default_async() + await credentials.refresh(_aiohttp_requests.Request()) + credentials.before_request = partial(Credentials.before_request, credentials) + new_id_token = getattr(credentials, "id_token", None) + + _update_token_cache(_cached_google_id_token, new_id_token) + + if new_id_token: + return BEARER_TOKEN_PREFIX + new_id_token + else: + raise Exception("Failed to fetch async Google ID token.") \ No newline at end of file diff --git a/packages/toolbox-core/tests/test_auth_methods.py b/packages/toolbox-core/tests/test_auth_methods.py index e316726b..feccf7f7 100644 --- a/packages/toolbox-core/tests/test_auth_methods.py +++ b/packages/toolbox-core/tests/test_auth_methods.py @@ -12,191 +12,387 @@ # See the License for the specific language governing permissions and # limitations under the License. +import time from unittest.mock import AsyncMock, MagicMock, patch -import google.auth.exceptions import pytest from toolbox_core import auth_methods # Constants for test values -MOCK_ASYNC_ID_TOKEN = "test_async_id_token_123" -MOCK_SYNC_ID_TOKEN = "test_sync_id_token_456" +MOCK_GOOGLE_ID_TOKEN = "test_id_token_123" MOCK_PROJECT_ID = "test-project" - -# Error Messages -ADC_NOT_FOUND_MSG = "ADC not found" -TOKEN_REFRESH_FAILED_MSG = "Token refresh failed" -SYNC_ADC_NOT_FOUND_MSG = "Sync ADC not found" -SYNC_TOKEN_REFRESH_FAILED_MSG = "Sync token refresh failed" - - -@pytest.mark.asyncio -@patch("toolbox_core.auth_methods.partial") -@patch("toolbox_core.auth_methods._aiohttp_requests.Request") -@patch("toolbox_core.auth_methods.Credentials") -@patch("toolbox_core.auth_methods.default_async") -async def test_aget_google_id_token_success( - mock_default_async, - mock_credentials_class, - mock_aiohttp_request_class, - mock_partial, -): - """ - Test aget_google_id_token successfully retrieves and formats a token. - """ - mock_creds_instance = AsyncMock() - mock_creds_instance.id_token = MOCK_ASYNC_ID_TOKEN - mock_default_async.return_value = (mock_creds_instance, MOCK_PROJECT_ID) - - mock_aio_request_instance = MagicMock() - mock_aiohttp_request_class.return_value = mock_aio_request_instance - - mock_unbound_before_request = MagicMock() - mock_credentials_class.before_request = mock_unbound_before_request - - mock_partial_object = MagicMock() - mock_partial.return_value = mock_partial_object - - token = await auth_methods.aget_google_id_token() - - mock_default_async.assert_called_once_with() - mock_aiohttp_request_class.assert_called_once_with() - mock_creds_instance.refresh.assert_called_once_with(mock_aio_request_instance) - - mock_partial.assert_called_once_with( - mock_unbound_before_request, mock_creds_instance - ) - assert mock_creds_instance.before_request == mock_partial_object - assert token == f"Bearer {MOCK_ASYNC_ID_TOKEN}" - - -@pytest.mark.asyncio -@patch("toolbox_core.auth_methods.default_async") -async def test_aget_google_id_token_default_credentials_error(mock_default_async): - """ - Test aget_google_id_token handles DefaultCredentialsError. - """ - mock_default_async.side_effect = google.auth.exceptions.DefaultCredentialsError( - ADC_NOT_FOUND_MSG - ) - - with pytest.raises( - google.auth.exceptions.DefaultCredentialsError, match=ADC_NOT_FOUND_MSG +# A realistic expiry timestamp (e.g., 1 hour from now) +MOCK_EXPIRY_TIMESTAMP = time.time() + 3600 + +# Expected exception messages from auth_methods.py +FETCH_TOKEN_FAILURE_MSG = "Failed to fetch Google ID token." +FETCH_ASYNC_TOKEN_FAILURE_MSG = "Failed to fetch async Google ID token." +# These will now match the actual messages from refresh.side_effect +NETWORK_ERROR_MSG = "Network error" +TIMEOUT_ERROR_MSG = "Timeout error" + + +@pytest.fixture(autouse=True) +def reset_cache_after_each_test(): + """Fixture to reset the cache before each test.""" + # Store initial state + original_cache_state = auth_methods._cached_google_id_token.copy() + auth_methods._cached_google_id_token = {"token": None, "expires_at": 0} + yield + # Restore initial state (optional, but good for isolation) + auth_methods._cached_google_id_token = original_cache_state + + +class TestAsyncAuthMethods: + """Tests for asynchronous Google ID token fetching.""" + + @pytest.mark.asyncio + @patch("toolbox_core.auth_methods._decode_jwt_and_get_expiry") + @patch("toolbox_core.auth_methods._aiohttp_requests.Request") + @patch("toolbox_core.auth_methods.default_async", new_callable=AsyncMock) + async def test_aget_google_id_token_success_first_call( + self, mock_default_async, mock_async_req_class, mock_decode_expiry ): - await auth_methods.aget_google_id_token() - - mock_default_async.assert_called_once_with() - - -@pytest.mark.asyncio -@patch("toolbox_core.auth_methods._aiohttp_requests.Request") -@patch("toolbox_core.auth_methods.default_async") -async def test_aget_google_id_token_refresh_error( - mock_default_async, - mock_aiohttp_request_class, -): - """ - Test aget_google_id_token handles RefreshError. - """ - mock_creds_instance = AsyncMock() - mock_creds_instance.refresh.side_effect = google.auth.exceptions.RefreshError( - TOKEN_REFRESH_FAILED_MSG - ) - mock_default_async.return_value = (mock_creds_instance, MOCK_PROJECT_ID) - - mock_aio_request_instance = MagicMock() - mock_aiohttp_request_class.return_value = mock_aio_request_instance - - with pytest.raises( - google.auth.exceptions.RefreshError, match=TOKEN_REFRESH_FAILED_MSG + """Tests successful fetching of an async token on the first call.""" + mock_creds_instance = AsyncMock() + mock_creds_instance.id_token = MOCK_GOOGLE_ID_TOKEN + mock_default_async.return_value = (mock_creds_instance, MOCK_PROJECT_ID) + mock_decode_expiry.return_value = MOCK_EXPIRY_TIMESTAMP + + mock_async_req_instance = MagicMock() + mock_async_req_class.return_value = mock_async_req_instance + + token = await auth_methods.aget_google_id_token() + + mock_default_async.assert_called_once_with() + mock_async_req_class.assert_called_once_with() + mock_creds_instance.refresh.assert_called_once_with(mock_async_req_instance) + + assert ( + mock_creds_instance.before_request.func + is auth_methods.Credentials.before_request + ) + assert token == f"{auth_methods.BEARER_TOKEN_PREFIX}{MOCK_GOOGLE_ID_TOKEN}" + assert auth_methods._cached_google_id_token["token"] == MOCK_GOOGLE_ID_TOKEN + assert ( + auth_methods._cached_google_id_token["expires_at"] == MOCK_EXPIRY_TIMESTAMP + ) + mock_decode_expiry.assert_called_once_with(MOCK_GOOGLE_ID_TOKEN) + + @pytest.mark.asyncio + @patch("toolbox_core.auth_methods._decode_jwt_and_get_expiry") + @patch("toolbox_core.auth_methods._aiohttp_requests.Request") + @patch("toolbox_core.auth_methods.default_async", new_callable=AsyncMock) + async def test_aget_google_id_token_success_uses_cache( + self, mock_default_async, mock_async_req_class, mock_decode_expiry ): - await auth_methods.aget_google_id_token() - - mock_default_async.assert_called_once_with() - mock_aiohttp_request_class.assert_called_once_with() - mock_creds_instance.refresh.assert_called_once_with(mock_aio_request_instance) - - -# --- Synchronous Tests --- - - -@patch("toolbox_core.auth_methods.Request") -@patch("toolbox_core.auth_methods.AuthorizedSession") -@patch("toolbox_core.auth_methods.google.auth.default") -def test_get_google_id_token_success( - mock_google_auth_default, - mock_authorized_session_class, - mock_request_class, -): - """ - Test get_google_id_token successfully retrieves and formats a token. - """ - mock_creds_instance = MagicMock() - mock_creds_instance.id_token = MOCK_SYNC_ID_TOKEN - mock_google_auth_default.return_value = (mock_creds_instance, MOCK_PROJECT_ID) - - mock_session_instance = MagicMock() - mock_authorized_session_class.return_value = mock_session_instance - - mock_request_instance = MagicMock() - mock_request_class.return_value = mock_request_instance - - token = auth_methods.get_google_id_token() - - mock_google_auth_default.assert_called_once_with() - mock_authorized_session_class.assert_called_once_with(mock_creds_instance) - mock_request_class.assert_called_once_with(mock_session_instance) - mock_creds_instance.refresh.assert_called_once_with(mock_request_instance) - assert token == f"Bearer {MOCK_SYNC_ID_TOKEN}" - - -@patch("toolbox_core.auth_methods.google.auth.default") -def test_get_google_id_token_default_credentials_error(mock_google_auth_default): - """ - Test get_google_id_token handles DefaultCredentialsError. - """ - mock_google_auth_default.side_effect = ( - google.auth.exceptions.DefaultCredentialsError(SYNC_ADC_NOT_FOUND_MSG) - ) - - with pytest.raises( - google.auth.exceptions.DefaultCredentialsError, match=SYNC_ADC_NOT_FOUND_MSG + """Tests that subsequent calls use the cached token if valid.""" + auth_methods._cached_google_id_token["token"] = MOCK_GOOGLE_ID_TOKEN + auth_methods._cached_google_id_token["expires_at"] = ( + time.time() + auth_methods.CACHE_REFRESH_MARGIN_SECONDS + 100 + ) # Ensure it's valid + + token = await auth_methods.aget_google_id_token() + + mock_default_async.assert_not_called() + mock_async_req_class.assert_not_called() + mock_decode_expiry.assert_not_called() + + assert token == f"{auth_methods.BEARER_TOKEN_PREFIX}{MOCK_GOOGLE_ID_TOKEN}" + assert auth_methods._cached_google_id_token["token"] == MOCK_GOOGLE_ID_TOKEN + + @pytest.mark.asyncio + @patch("toolbox_core.auth_methods._decode_jwt_and_get_expiry") + @patch("toolbox_core.auth_methods._aiohttp_requests.Request") + @patch("toolbox_core.auth_methods.default_async", new_callable=AsyncMock) + async def test_aget_google_id_token_refreshes_expired_cache( + self, mock_default_async, mock_async_req_class, mock_decode_expiry ): - auth_methods.get_google_id_token() - - mock_google_auth_default.assert_called_once_with() - - -@patch("toolbox_core.auth_methods.Request") -@patch("toolbox_core.auth_methods.AuthorizedSession") -@patch("toolbox_core.auth_methods.google.auth.default") -def test_get_google_id_token_refresh_error( - mock_google_auth_default, - mock_authorized_session_class, - mock_request_class, -): - """ - Test get_google_id_token handles RefreshError. - """ - mock_creds_instance = MagicMock() - mock_creds_instance.refresh.side_effect = google.auth.exceptions.RefreshError( - SYNC_TOKEN_REFRESH_FAILED_MSG - ) - mock_google_auth_default.return_value = (mock_creds_instance, MOCK_PROJECT_ID) - - mock_session_instance = MagicMock() - mock_authorized_session_class.return_value = mock_session_instance - - mock_request_instance = MagicMock() - mock_request_class.return_value = mock_request_instance - - with pytest.raises( - google.auth.exceptions.RefreshError, match=SYNC_TOKEN_REFRESH_FAILED_MSG + """Tests that an expired cached token triggers a refresh.""" + auth_methods._cached_google_id_token["token"] = "expired_token" + auth_methods._cached_google_id_token["expires_at"] = ( + time.time() - 100 + ) # Expired + + mock_creds_instance = AsyncMock() + mock_creds_instance.id_token = MOCK_GOOGLE_ID_TOKEN # New token after refresh + mock_default_async.return_value = (mock_creds_instance, MOCK_PROJECT_ID) + mock_decode_expiry.return_value = MOCK_EXPIRY_TIMESTAMP + + mock_async_req_instance = MagicMock() + mock_async_req_class.return_value = mock_async_req_instance + + token = await auth_methods.aget_google_id_token() + + mock_default_async.assert_called_once_with() + mock_async_req_class.assert_called_once_with() + mock_creds_instance.refresh.assert_called_once_with(mock_async_req_instance) + assert token == f"{auth_methods.BEARER_TOKEN_PREFIX}{MOCK_GOOGLE_ID_TOKEN}" + assert auth_methods._cached_google_id_token["token"] == MOCK_GOOGLE_ID_TOKEN + assert ( + auth_methods._cached_google_id_token["expires_at"] == MOCK_EXPIRY_TIMESTAMP + ) + mock_decode_expiry.assert_called_once_with(MOCK_GOOGLE_ID_TOKEN) + + @pytest.mark.asyncio + @patch("toolbox_core.auth_methods._aiohttp_requests.Request") + @patch("toolbox_core.auth_methods.default_async", new_callable=AsyncMock) + async def test_aget_google_id_token_fetch_failure( + self, mock_default_async, mock_async_req_class ): - auth_methods.get_google_id_token() - - mock_google_auth_default.assert_called_once_with() - mock_authorized_session_class.assert_called_once_with(mock_creds_instance) - mock_request_class.assert_called_once_with(mock_session_instance) - mock_creds_instance.refresh.assert_called_once_with(mock_request_instance) + """Tests error handling when fetching the token fails (no id_token returned).""" + mock_creds_instance = AsyncMock() + mock_creds_instance.id_token = None # Simulate no ID token after refresh + mock_default_async.return_value = (mock_creds_instance, MOCK_PROJECT_ID) + mock_async_req_class.return_value = MagicMock() + + with pytest.raises(Exception, match=FETCH_ASYNC_TOKEN_FAILURE_MSG): + await auth_methods.aget_google_id_token() + + assert auth_methods._cached_google_id_token["token"] is None + assert auth_methods._cached_google_id_token["expires_at"] == 0 + mock_async_req_class.assert_called_once_with() + mock_creds_instance.refresh.assert_called_once() + + @pytest.mark.asyncio + @patch("toolbox_core.auth_methods._aiohttp_requests.Request") + @patch("toolbox_core.auth_methods.default_async", new_callable=AsyncMock) + async def test_aget_google_id_token_refresh_raises_exception( + self, mock_default_async, mock_async_req_class + ): + """Tests exception handling when credentials refresh fails.""" + mock_creds_instance = AsyncMock() + mock_creds_instance.refresh.side_effect = Exception(NETWORK_ERROR_MSG) + mock_default_async.return_value = (mock_creds_instance, MOCK_PROJECT_ID) + mock_async_req_class.return_value = MagicMock() + + with pytest.raises(Exception, match=NETWORK_ERROR_MSG): + await auth_methods.aget_google_id_token() + + assert auth_methods._cached_google_id_token["token"] is None + assert auth_methods._cached_google_id_token["expires_at"] == 0 + mock_async_req_class.assert_called_once_with() + mock_creds_instance.refresh.assert_called_once() + + @pytest.mark.asyncio + @patch("toolbox_core.auth_methods._decode_jwt_and_get_expiry") + @patch("toolbox_core.auth_methods._aiohttp_requests.Request") + @patch("toolbox_core.auth_methods.default_async", new_callable=AsyncMock) + async def test_aget_google_id_token_no_expiry_info( + self, mock_default_async, mock_async_req_class, mock_decode_expiry + ): + """Tests that a token without expiry info is still cached but effectively expired.""" + mock_creds_instance = AsyncMock() + mock_creds_instance.id_token = MOCK_GOOGLE_ID_TOKEN + mock_default_async.return_value = (mock_creds_instance, MOCK_PROJECT_ID) + mock_decode_expiry.return_value = None # Simulate no expiry info + + mock_async_req_class.return_value = MagicMock() + + token = await auth_methods.aget_google_id_token() + + assert token == f"{auth_methods.BEARER_TOKEN_PREFIX}{MOCK_GOOGLE_ID_TOKEN}" + assert auth_methods._cached_google_id_token["token"] == MOCK_GOOGLE_ID_TOKEN + assert ( + auth_methods._cached_google_id_token["expires_at"] == 0 + ) # Should be 0 if no expiry + mock_async_req_class.assert_called_once_with() + mock_decode_expiry.assert_called_once_with(MOCK_GOOGLE_ID_TOKEN) + + +class TestSyncAuthMethods: + """Tests for synchronous Google ID token fetching.""" + + @patch("toolbox_core.auth_methods._decode_jwt_and_get_expiry") + @patch("toolbox_core.auth_methods.Request") + @patch("toolbox_core.auth_methods.AuthorizedSession") + @patch("toolbox_core.auth_methods.google.auth.default") + def test_get_google_id_token_success_first_call( + self, + mock_sync_default, + mock_auth_session_class, + mock_sync_req_class, + mock_decode_expiry, + ): + """Tests successful fetching of a sync token on the first call.""" + mock_creds_instance = MagicMock() + mock_creds_instance.id_token = MOCK_GOOGLE_ID_TOKEN + mock_sync_default.return_value = (mock_creds_instance, MOCK_PROJECT_ID) + mock_decode_expiry.return_value = MOCK_EXPIRY_TIMESTAMP + + mock_session_instance = MagicMock() + mock_auth_session_class.return_value = mock_session_instance + + mock_sync_request_instance = MagicMock() + mock_sync_req_class.return_value = mock_sync_request_instance + + token = auth_methods.get_google_id_token() + + mock_sync_default.assert_called_once_with() + mock_auth_session_class.assert_called_once_with(mock_creds_instance) + mock_sync_req_class.assert_called_once_with(mock_session_instance) + mock_creds_instance.refresh.assert_called_once_with(mock_sync_request_instance) + + assert token == f"{auth_methods.BEARER_TOKEN_PREFIX}{MOCK_GOOGLE_ID_TOKEN}" + assert auth_methods._cached_google_id_token["token"] == MOCK_GOOGLE_ID_TOKEN + assert ( + auth_methods._cached_google_id_token["expires_at"] == MOCK_EXPIRY_TIMESTAMP + ) + mock_decode_expiry.assert_called_once_with(MOCK_GOOGLE_ID_TOKEN) + + @patch("toolbox_core.auth_methods._decode_jwt_and_get_expiry") + @patch("toolbox_core.auth_methods.Request") + @patch("toolbox_core.auth_methods.AuthorizedSession") + @patch("toolbox_core.auth_methods.google.auth.default") + def test_get_google_id_token_success_uses_cache( + self, + mock_sync_default, + mock_auth_session_class, + mock_sync_req_class, + mock_decode_expiry, + ): + """Tests that subsequent calls use the cached token if valid.""" + auth_methods._cached_google_id_token["token"] = MOCK_GOOGLE_ID_TOKEN + auth_methods._cached_google_id_token["expires_at"] = ( + time.time() + auth_methods.CACHE_REFRESH_MARGIN_SECONDS + 100 + ) # Ensure it's valid + + token = auth_methods.get_google_id_token() + + mock_sync_default.assert_not_called() + mock_auth_session_class.assert_not_called() + mock_sync_req_class.assert_not_called() + mock_decode_expiry.assert_not_called() + + assert token == f"{auth_methods.BEARER_TOKEN_PREFIX}{MOCK_GOOGLE_ID_TOKEN}" + assert auth_methods._cached_google_id_token["token"] == MOCK_GOOGLE_ID_TOKEN + + @patch("toolbox_core.auth_methods._decode_jwt_and_get_expiry") + @patch("toolbox_core.auth_methods.Request") + @patch("toolbox_core.auth_methods.AuthorizedSession") + @patch("toolbox_core.auth_methods.google.auth.default") + def test_get_google_id_token_refreshes_expired_cache( + self, + mock_sync_default, + mock_auth_session_class, + mock_sync_req_class, + mock_decode_expiry, + ): + """Tests that an expired cached token triggers a refresh.""" + # Prime the cache with an expired token + auth_methods._cached_google_id_token["token"] = "expired_token_sync" + auth_methods._cached_google_id_token["expires_at"] = ( + time.time() - 100 + ) # Expired + + mock_creds_instance = MagicMock() + mock_creds_instance.id_token = MOCK_GOOGLE_ID_TOKEN # New token after refresh + mock_sync_default.return_value = (mock_creds_instance, MOCK_PROJECT_ID) + mock_decode_expiry.return_value = MOCK_EXPIRY_TIMESTAMP + + mock_session_instance = MagicMock() + mock_auth_session_class.return_value = mock_session_instance + + mock_sync_request_instance = MagicMock() + mock_sync_req_class.return_value = mock_sync_request_instance + + token = auth_methods.get_google_id_token() + + mock_sync_default.assert_called_once_with() + mock_auth_session_class.assert_called_once_with(mock_creds_instance) + mock_sync_req_class.assert_called_once_with(mock_session_instance) + mock_creds_instance.refresh.assert_called_once_with(mock_sync_request_instance) + assert token == f"{auth_methods.BEARER_TOKEN_PREFIX}{MOCK_GOOGLE_ID_TOKEN}" + assert auth_methods._cached_google_id_token["token"] == MOCK_GOOGLE_ID_TOKEN + assert ( + auth_methods._cached_google_id_token["expires_at"] == MOCK_EXPIRY_TIMESTAMP + ) + mock_decode_expiry.assert_called_once_with(MOCK_GOOGLE_ID_TOKEN) + + @patch("toolbox_core.auth_methods.Request") + @patch("toolbox_core.auth_methods.AuthorizedSession") + @patch("toolbox_core.auth_methods.google.auth.default") + def test_get_google_id_token_fetch_failure( + self, mock_sync_default, mock_auth_session_class, mock_sync_req_class + ): + """Tests error handling when fetching the token fails (no id_token returned).""" + mock_creds_instance = MagicMock() + mock_creds_instance.id_token = None # Simulate no ID token after refresh + mock_sync_default.return_value = (mock_creds_instance, MOCK_PROJECT_ID) + + mock_session_instance = MagicMock() + mock_auth_session_class.return_value = mock_session_instance + + mock_sync_req_class.return_value = MagicMock() + + with pytest.raises(Exception, match=FETCH_TOKEN_FAILURE_MSG): + auth_methods.get_google_id_token() + + assert auth_methods._cached_google_id_token["token"] is None + assert auth_methods._cached_google_id_token["expires_at"] == 0 + mock_sync_default.assert_called_once_with() + mock_auth_session_class.assert_called_once_with(mock_creds_instance) + mock_sync_req_class.assert_called_once_with(mock_session_instance) + mock_creds_instance.refresh.assert_called_once() + + @patch("toolbox_core.auth_methods.Request") + @patch("toolbox_core.auth_methods.AuthorizedSession") + @patch("toolbox_core.auth_methods.google.auth.default") + def test_get_google_id_token_refresh_raises_exception( + self, mock_sync_default, mock_auth_session_class, mock_sync_req_class + ): + """Tests exception handling when credentials refresh fails.""" + mock_creds_instance = MagicMock() + mock_creds_instance.refresh.side_effect = Exception(TIMEOUT_ERROR_MSG) + mock_sync_default.return_value = (mock_creds_instance, MOCK_PROJECT_ID) + + mock_session_instance = MagicMock() + mock_auth_session_class.return_value = mock_session_instance + + mock_sync_req_class.return_value = MagicMock() + + with pytest.raises(Exception, match=TIMEOUT_ERROR_MSG): + auth_methods.get_google_id_token() + + assert auth_methods._cached_google_id_token["token"] is None + assert auth_methods._cached_google_id_token["expires_at"] == 0 + mock_sync_default.assert_called_once_with() + mock_auth_session_class.assert_called_once_with(mock_creds_instance) + mock_sync_req_class.assert_called_once_with(mock_session_instance) + mock_creds_instance.refresh.assert_called_once() + + @patch("toolbox_core.auth_methods._decode_jwt_and_get_expiry") + @patch("toolbox_core.auth_methods.Request") + @patch("toolbox_core.auth_methods.AuthorizedSession") + @patch("toolbox_core.auth_methods.google.auth.default") + def test_get_google_id_token_no_expiry_info( + self, + mock_sync_default, + mock_auth_session_class, + mock_sync_req_class, + mock_decode_expiry, + ): + """Tests that a token without expiry info is still cached but effectively expired.""" + mock_creds_instance = MagicMock() + mock_creds_instance.id_token = MOCK_GOOGLE_ID_TOKEN + mock_sync_default.return_value = (mock_creds_instance, MOCK_PROJECT_ID) + mock_decode_expiry.return_value = None # Simulate no expiry info + + mock_session_instance = MagicMock() + mock_auth_session_class.return_value = mock_session_instance + + mock_sync_request_instance = MagicMock() + mock_sync_req_class.return_value = mock_sync_request_instance + + token = auth_methods.get_google_id_token() + + assert token == f"{auth_methods.BEARER_TOKEN_PREFIX}{MOCK_GOOGLE_ID_TOKEN}" + assert auth_methods._cached_google_id_token["token"] == MOCK_GOOGLE_ID_TOKEN + assert ( + auth_methods._cached_google_id_token["expires_at"] == 0 + ) # Should be 0 if no expiry + mock_sync_default.assert_called_once_with() + mock_auth_session_class.assert_called_once_with(mock_creds_instance) + mock_sync_req_class.assert_called_once_with(mock_session_instance) + mock_decode_expiry.assert_called_once_with(MOCK_GOOGLE_ID_TOKEN) From a4d154c05b6f713b214aff7283d2ab898843f476 Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Thu, 22 May 2025 13:21:43 +0530 Subject: [PATCH 2/8] lint --- packages/toolbox-core/src/toolbox_core/auth_methods.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/packages/toolbox-core/src/toolbox_core/auth_methods.py b/packages/toolbox-core/src/toolbox_core/auth_methods.py index 87ab23e3..525e9ccd 100644 --- a/packages/toolbox-core/src/toolbox_core/auth_methods.py +++ b/packages/toolbox-core/src/toolbox_core/auth_methods.py @@ -29,16 +29,15 @@ import time from functools import partial -from typing import Optional, Dict, Any +from typing import Any, Dict, Optional import google.auth -from google.auth._credentials_async import Credentials import jwt +from google.auth._credentials_async import Credentials from google.auth._default_async import default_async from google.auth.transport import _aiohttp_requests from google.auth.transport.requests import AuthorizedSession, Request - # --- Constants and Configuration --- # Prefix for Authorization header tokens BEARER_TOKEN_PREFIX = "Bearer " @@ -175,4 +174,4 @@ async def aget_google_id_token() -> str: if new_id_token: return BEARER_TOKEN_PREFIX + new_id_token else: - raise Exception("Failed to fetch async Google ID token.") \ No newline at end of file + raise Exception("Failed to fetch async Google ID token.") From 16e9c8770885d13ee0f95a044cbbd1ca526f23eb Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Thu, 22 May 2025 13:27:31 +0530 Subject: [PATCH 3/8] lint --- packages/toolbox-core/src/toolbox_core/auth_methods.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/toolbox-core/src/toolbox_core/auth_methods.py b/packages/toolbox-core/src/toolbox_core/auth_methods.py index 525e9ccd..4bc5f2f3 100644 --- a/packages/toolbox-core/src/toolbox_core/auth_methods.py +++ b/packages/toolbox-core/src/toolbox_core/auth_methods.py @@ -94,7 +94,7 @@ def _is_cached_token_valid( return time.time() < (expires_at - margin_seconds) -def _update_token_cache(cache: Dict[str, Any], new_id_token: Optional[str]): +def _update_token_cache(cache: Dict[str, Any], new_id_token: Optional[str]) -> None: """ Updates the global token cache with a new token and its expiry. From 01bc5938406c78f16d3d02c4f3d2489cdd76c55f Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Thu, 22 May 2025 13:29:25 +0530 Subject: [PATCH 4/8] fix tests --- packages/toolbox-core/tests/test_auth_methods.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/packages/toolbox-core/tests/test_auth_methods.py b/packages/toolbox-core/tests/test_auth_methods.py index feccf7f7..2f4413e1 100644 --- a/packages/toolbox-core/tests/test_auth_methods.py +++ b/packages/toolbox-core/tests/test_auth_methods.py @@ -50,7 +50,7 @@ class TestAsyncAuthMethods: @pytest.mark.asyncio @patch("toolbox_core.auth_methods._decode_jwt_and_get_expiry") @patch("toolbox_core.auth_methods._aiohttp_requests.Request") - @patch("toolbox_core.auth_methods.default_async", new_callable=AsyncMock) + @patch("toolbox_core.auth_methods.default_async", new_callable=MagicMock) async def test_aget_google_id_token_success_first_call( self, mock_default_async, mock_async_req_class, mock_decode_expiry ): @@ -83,7 +83,7 @@ async def test_aget_google_id_token_success_first_call( @pytest.mark.asyncio @patch("toolbox_core.auth_methods._decode_jwt_and_get_expiry") @patch("toolbox_core.auth_methods._aiohttp_requests.Request") - @patch("toolbox_core.auth_methods.default_async", new_callable=AsyncMock) + @patch("toolbox_core.auth_methods.default_async", new_callable=MagicMock) async def test_aget_google_id_token_success_uses_cache( self, mock_default_async, mock_async_req_class, mock_decode_expiry ): @@ -105,7 +105,7 @@ async def test_aget_google_id_token_success_uses_cache( @pytest.mark.asyncio @patch("toolbox_core.auth_methods._decode_jwt_and_get_expiry") @patch("toolbox_core.auth_methods._aiohttp_requests.Request") - @patch("toolbox_core.auth_methods.default_async", new_callable=AsyncMock) + @patch("toolbox_core.auth_methods.default_async", new_callable=MagicMock) async def test_aget_google_id_token_refreshes_expired_cache( self, mock_default_async, mock_async_req_class, mock_decode_expiry ): @@ -137,7 +137,7 @@ async def test_aget_google_id_token_refreshes_expired_cache( @pytest.mark.asyncio @patch("toolbox_core.auth_methods._aiohttp_requests.Request") - @patch("toolbox_core.auth_methods.default_async", new_callable=AsyncMock) + @patch("toolbox_core.auth_methods.default_async", new_callable=MagicMock) async def test_aget_google_id_token_fetch_failure( self, mock_default_async, mock_async_req_class ): @@ -157,7 +157,7 @@ async def test_aget_google_id_token_fetch_failure( @pytest.mark.asyncio @patch("toolbox_core.auth_methods._aiohttp_requests.Request") - @patch("toolbox_core.auth_methods.default_async", new_callable=AsyncMock) + @patch("toolbox_core.auth_methods.default_async", new_callable=MagicMock) async def test_aget_google_id_token_refresh_raises_exception( self, mock_default_async, mock_async_req_class ): @@ -178,7 +178,7 @@ async def test_aget_google_id_token_refresh_raises_exception( @pytest.mark.asyncio @patch("toolbox_core.auth_methods._decode_jwt_and_get_expiry") @patch("toolbox_core.auth_methods._aiohttp_requests.Request") - @patch("toolbox_core.auth_methods.default_async", new_callable=AsyncMock) + @patch("toolbox_core.auth_methods.default_async", new_callable=MagicMock) async def test_aget_google_id_token_no_expiry_info( self, mock_default_async, mock_async_req_class, mock_decode_expiry ): From a578964f3641bccbcca90ffadce0fc9bcf0a3253 Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Tue, 27 May 2025 14:49:20 +0530 Subject: [PATCH 5/8] get expiry directly from creds --- .../src/toolbox_core/auth_methods.py | 49 ++++++++----------- 1 file changed, 20 insertions(+), 29 deletions(-) diff --git a/packages/toolbox-core/src/toolbox_core/auth_methods.py b/packages/toolbox-core/src/toolbox_core/auth_methods.py index 4bc5f2f3..cbc76fb4 100644 --- a/packages/toolbox-core/src/toolbox_core/auth_methods.py +++ b/packages/toolbox-core/src/toolbox_core/auth_methods.py @@ -27,12 +27,11 @@ tools = await toolbox.load_toolset() """ -import time +from datetime import datetime, timedelta, timezone from functools import partial from typing import Any, Dict, Optional import google.auth -import jwt from google.auth._credentials_async import Credentials from google.auth._default_async import default_async from google.auth.transport import _aiohttp_requests @@ -51,26 +50,6 @@ # --- Helper Functions --- -def _decode_jwt_and_get_expiry(id_token: str) -> Optional[float]: - """ - Decodes a JWT and extracts the 'exp' (expiration) claim. - - Args: - id_token: The JWT string to decode. - - Returns: - The 'exp' timestamp as a float if present and decoding is successful, - otherwise None. - """ - try: - decoded_token = jwt.decode( - id_token, options={"verify_signature": False, "verify_aud": False} - ) - return decoded_token.get("exp") - except jwt.PyJWTError: - return None - - def _is_cached_token_valid( cache: Dict[str, Any], margin_seconds: int = CACHE_REFRESH_MARGIN_SECONDS ) -> bool: @@ -87,14 +66,24 @@ def _is_cached_token_valid( if not cache.get("token"): return False - expires_at = cache.get("expires_at") - if not isinstance(expires_at, (int, float)) or expires_at <= 0: + expires_at_value = cache.get("expires_at") + if not isinstance(expires_at_value, datetime): return False - return time.time() < (expires_at - margin_seconds) + # Ensure expires_at_value is timezone-aware (UTC). + if expires_at_value.tzinfo is None or expires_at_value.tzinfo.utcoffset(expires_at_value) is None: + expires_at_value = expires_at_value.replace(tzinfo=timezone.utc) + + current_time_utc = datetime.now(timezone.utc) + if current_time_utc + timedelta(seconds=margin_seconds) < expires_at_value: + return True + + return False -def _update_token_cache(cache: Dict[str, Any], new_id_token: Optional[str]) -> None: +def _update_token_cache( + cache: Dict[str, Any], new_id_token: Optional[str], expiry: Optional[datetime] +) -> None: """ Updates the global token cache with a new token and its expiry. @@ -104,7 +93,7 @@ def _update_token_cache(cache: Dict[str, Any], new_id_token: Optional[str]) -> N """ if new_id_token: cache["token"] = new_id_token - expiry_timestamp = _decode_jwt_and_get_expiry(new_id_token) + expiry_timestamp = expiry if expiry_timestamp: cache["expires_at"] = expiry_timestamp else: @@ -139,8 +128,9 @@ def get_google_id_token() -> str: request = Request(session) credentials.refresh(request) new_id_token = getattr(credentials, "id_token", None) + expiry = getattr(credentials, "expiry") - _update_token_cache(_cached_google_id_token, new_id_token) + _update_token_cache(_cached_google_id_token, new_id_token, expiry) if new_id_token: return BEARER_TOKEN_PREFIX + new_id_token else: @@ -168,8 +158,9 @@ async def aget_google_id_token() -> str: await credentials.refresh(_aiohttp_requests.Request()) credentials.before_request = partial(Credentials.before_request, credentials) new_id_token = getattr(credentials, "id_token", None) + expiry = getattr(credentials, "expiry") - _update_token_cache(_cached_google_id_token, new_id_token) + _update_token_cache(_cached_google_id_token, new_id_token, expiry) if new_id_token: return BEARER_TOKEN_PREFIX + new_id_token From 71fac641c135dfa49795a94725cf95ce6d7c1ae4 Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Tue, 27 May 2025 14:53:01 +0530 Subject: [PATCH 6/8] fix tests --- .../toolbox-core/tests/test_auth_methods.py | 74 +++++++------------ 1 file changed, 27 insertions(+), 47 deletions(-) diff --git a/packages/toolbox-core/tests/test_auth_methods.py b/packages/toolbox-core/tests/test_auth_methods.py index 2f4413e1..373280c0 100644 --- a/packages/toolbox-core/tests/test_auth_methods.py +++ b/packages/toolbox-core/tests/test_auth_methods.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import time -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch, PropertyMock import pytest @@ -23,7 +22,8 @@ MOCK_GOOGLE_ID_TOKEN = "test_id_token_123" MOCK_PROJECT_ID = "test-project" # A realistic expiry timestamp (e.g., 1 hour from now) -MOCK_EXPIRY_TIMESTAMP = time.time() + 3600 +MOCK_EXPIRY_DATETIME = auth_methods.datetime.now(auth_methods.timezone.utc) + auth_methods.timedelta(hours=1) + # Expected exception messages from auth_methods.py FETCH_TOKEN_FAILURE_MSG = "Failed to fetch Google ID token." @@ -48,21 +48,19 @@ class TestAsyncAuthMethods: """Tests for asynchronous Google ID token fetching.""" @pytest.mark.asyncio - @patch("toolbox_core.auth_methods._decode_jwt_and_get_expiry") @patch("toolbox_core.auth_methods._aiohttp_requests.Request") @patch("toolbox_core.auth_methods.default_async", new_callable=MagicMock) async def test_aget_google_id_token_success_first_call( - self, mock_default_async, mock_async_req_class, mock_decode_expiry + self, mock_default_async, mock_async_req_class ): """Tests successful fetching of an async token on the first call.""" mock_creds_instance = AsyncMock() mock_creds_instance.id_token = MOCK_GOOGLE_ID_TOKEN + type(mock_creds_instance).expiry = PropertyMock(return_value=MOCK_EXPIRY_DATETIME) mock_default_async.return_value = (mock_creds_instance, MOCK_PROJECT_ID) - mock_decode_expiry.return_value = MOCK_EXPIRY_TIMESTAMP mock_async_req_instance = MagicMock() mock_async_req_class.return_value = mock_async_req_instance - token = await auth_methods.aget_google_id_token() mock_default_async.assert_called_once_with() @@ -76,49 +74,46 @@ async def test_aget_google_id_token_success_first_call( assert token == f"{auth_methods.BEARER_TOKEN_PREFIX}{MOCK_GOOGLE_ID_TOKEN}" assert auth_methods._cached_google_id_token["token"] == MOCK_GOOGLE_ID_TOKEN assert ( - auth_methods._cached_google_id_token["expires_at"] == MOCK_EXPIRY_TIMESTAMP + auth_methods._cached_google_id_token["expires_at"] == MOCK_EXPIRY_DATETIME ) - mock_decode_expiry.assert_called_once_with(MOCK_GOOGLE_ID_TOKEN) @pytest.mark.asyncio - @patch("toolbox_core.auth_methods._decode_jwt_and_get_expiry") @patch("toolbox_core.auth_methods._aiohttp_requests.Request") @patch("toolbox_core.auth_methods.default_async", new_callable=MagicMock) async def test_aget_google_id_token_success_uses_cache( - self, mock_default_async, mock_async_req_class, mock_decode_expiry + self, mock_default_async, mock_async_req_class ): """Tests that subsequent calls use the cached token if valid.""" auth_methods._cached_google_id_token["token"] = MOCK_GOOGLE_ID_TOKEN auth_methods._cached_google_id_token["expires_at"] = ( - time.time() + auth_methods.CACHE_REFRESH_MARGIN_SECONDS + 100 + auth_methods.datetime.now(auth_methods.timezone.utc) + + auth_methods.timedelta(seconds=auth_methods.CACHE_REFRESH_MARGIN_SECONDS + 100) ) # Ensure it's valid token = await auth_methods.aget_google_id_token() mock_default_async.assert_not_called() mock_async_req_class.assert_not_called() - mock_decode_expiry.assert_not_called() assert token == f"{auth_methods.BEARER_TOKEN_PREFIX}{MOCK_GOOGLE_ID_TOKEN}" assert auth_methods._cached_google_id_token["token"] == MOCK_GOOGLE_ID_TOKEN @pytest.mark.asyncio - @patch("toolbox_core.auth_methods._decode_jwt_and_get_expiry") @patch("toolbox_core.auth_methods._aiohttp_requests.Request") @patch("toolbox_core.auth_methods.default_async", new_callable=MagicMock) async def test_aget_google_id_token_refreshes_expired_cache( - self, mock_default_async, mock_async_req_class, mock_decode_expiry + self, mock_default_async, mock_async_req_class ): """Tests that an expired cached token triggers a refresh.""" auth_methods._cached_google_id_token["token"] = "expired_token" auth_methods._cached_google_id_token["expires_at"] = ( - time.time() - 100 + auth_methods.datetime.now(auth_methods.timezone.utc) - auth_methods.timedelta(seconds=100) ) # Expired mock_creds_instance = AsyncMock() mock_creds_instance.id_token = MOCK_GOOGLE_ID_TOKEN # New token after refresh + type(mock_creds_instance).expiry = PropertyMock(return_value=MOCK_EXPIRY_DATETIME) mock_default_async.return_value = (mock_creds_instance, MOCK_PROJECT_ID) - mock_decode_expiry.return_value = MOCK_EXPIRY_TIMESTAMP mock_async_req_instance = MagicMock() mock_async_req_class.return_value = mock_async_req_instance @@ -130,10 +125,8 @@ async def test_aget_google_id_token_refreshes_expired_cache( mock_creds_instance.refresh.assert_called_once_with(mock_async_req_instance) assert token == f"{auth_methods.BEARER_TOKEN_PREFIX}{MOCK_GOOGLE_ID_TOKEN}" assert auth_methods._cached_google_id_token["token"] == MOCK_GOOGLE_ID_TOKEN - assert ( - auth_methods._cached_google_id_token["expires_at"] == MOCK_EXPIRY_TIMESTAMP - ) - mock_decode_expiry.assert_called_once_with(MOCK_GOOGLE_ID_TOKEN) + assert auth_methods._cached_google_id_token["expires_at"] == MOCK_EXPIRY_DATETIME + @pytest.mark.asyncio @patch("toolbox_core.auth_methods._aiohttp_requests.Request") @@ -144,6 +137,7 @@ async def test_aget_google_id_token_fetch_failure( """Tests error handling when fetching the token fails (no id_token returned).""" mock_creds_instance = AsyncMock() mock_creds_instance.id_token = None # Simulate no ID token after refresh + type(mock_creds_instance).expiry = PropertyMock(return_value=MOCK_EXPIRY_DATETIME) # Still need expiry for update_cache mock_default_async.return_value = (mock_creds_instance, MOCK_PROJECT_ID) mock_async_req_class.return_value = MagicMock() @@ -176,20 +170,19 @@ async def test_aget_google_id_token_refresh_raises_exception( mock_creds_instance.refresh.assert_called_once() @pytest.mark.asyncio - @patch("toolbox_core.auth_methods._decode_jwt_and_get_expiry") @patch("toolbox_core.auth_methods._aiohttp_requests.Request") @patch("toolbox_core.auth_methods.default_async", new_callable=MagicMock) async def test_aget_google_id_token_no_expiry_info( - self, mock_default_async, mock_async_req_class, mock_decode_expiry + self, mock_default_async, mock_async_req_class ): """Tests that a token without expiry info is still cached but effectively expired.""" mock_creds_instance = AsyncMock() mock_creds_instance.id_token = MOCK_GOOGLE_ID_TOKEN + type(mock_creds_instance).expiry = PropertyMock(return_value=None) # Simulate no expiry info mock_default_async.return_value = (mock_creds_instance, MOCK_PROJECT_ID) - mock_decode_expiry.return_value = None # Simulate no expiry info mock_async_req_class.return_value = MagicMock() - + token = await auth_methods.aget_google_id_token() assert token == f"{auth_methods.BEARER_TOKEN_PREFIX}{MOCK_GOOGLE_ID_TOKEN}" @@ -198,13 +191,11 @@ async def test_aget_google_id_token_no_expiry_info( auth_methods._cached_google_id_token["expires_at"] == 0 ) # Should be 0 if no expiry mock_async_req_class.assert_called_once_with() - mock_decode_expiry.assert_called_once_with(MOCK_GOOGLE_ID_TOKEN) class TestSyncAuthMethods: """Tests for synchronous Google ID token fetching.""" - @patch("toolbox_core.auth_methods._decode_jwt_and_get_expiry") @patch("toolbox_core.auth_methods.Request") @patch("toolbox_core.auth_methods.AuthorizedSession") @patch("toolbox_core.auth_methods.google.auth.default") @@ -213,13 +204,12 @@ def test_get_google_id_token_success_first_call( mock_sync_default, mock_auth_session_class, mock_sync_req_class, - mock_decode_expiry, ): """Tests successful fetching of a sync token on the first call.""" mock_creds_instance = MagicMock() mock_creds_instance.id_token = MOCK_GOOGLE_ID_TOKEN + type(mock_creds_instance).expiry = PropertyMock(return_value=MOCK_EXPIRY_DATETIME) mock_sync_default.return_value = (mock_creds_instance, MOCK_PROJECT_ID) - mock_decode_expiry.return_value = MOCK_EXPIRY_TIMESTAMP mock_session_instance = MagicMock() mock_auth_session_class.return_value = mock_session_instance @@ -237,11 +227,9 @@ def test_get_google_id_token_success_first_call( assert token == f"{auth_methods.BEARER_TOKEN_PREFIX}{MOCK_GOOGLE_ID_TOKEN}" assert auth_methods._cached_google_id_token["token"] == MOCK_GOOGLE_ID_TOKEN assert ( - auth_methods._cached_google_id_token["expires_at"] == MOCK_EXPIRY_TIMESTAMP + auth_methods._cached_google_id_token["expires_at"] == MOCK_EXPIRY_DATETIME ) - mock_decode_expiry.assert_called_once_with(MOCK_GOOGLE_ID_TOKEN) - @patch("toolbox_core.auth_methods._decode_jwt_and_get_expiry") @patch("toolbox_core.auth_methods.Request") @patch("toolbox_core.auth_methods.AuthorizedSession") @patch("toolbox_core.auth_methods.google.auth.default") @@ -250,12 +238,12 @@ def test_get_google_id_token_success_uses_cache( mock_sync_default, mock_auth_session_class, mock_sync_req_class, - mock_decode_expiry, ): """Tests that subsequent calls use the cached token if valid.""" auth_methods._cached_google_id_token["token"] = MOCK_GOOGLE_ID_TOKEN auth_methods._cached_google_id_token["expires_at"] = ( - time.time() + auth_methods.CACHE_REFRESH_MARGIN_SECONDS + 100 + auth_methods.datetime.now(auth_methods.timezone.utc) + + auth_methods.timedelta(seconds=auth_methods.CACHE_REFRESH_MARGIN_SECONDS + 100) ) # Ensure it's valid token = auth_methods.get_google_id_token() @@ -263,12 +251,10 @@ def test_get_google_id_token_success_uses_cache( mock_sync_default.assert_not_called() mock_auth_session_class.assert_not_called() mock_sync_req_class.assert_not_called() - mock_decode_expiry.assert_not_called() assert token == f"{auth_methods.BEARER_TOKEN_PREFIX}{MOCK_GOOGLE_ID_TOKEN}" assert auth_methods._cached_google_id_token["token"] == MOCK_GOOGLE_ID_TOKEN - @patch("toolbox_core.auth_methods._decode_jwt_and_get_expiry") @patch("toolbox_core.auth_methods.Request") @patch("toolbox_core.auth_methods.AuthorizedSession") @patch("toolbox_core.auth_methods.google.auth.default") @@ -277,19 +263,18 @@ def test_get_google_id_token_refreshes_expired_cache( mock_sync_default, mock_auth_session_class, mock_sync_req_class, - mock_decode_expiry, ): """Tests that an expired cached token triggers a refresh.""" # Prime the cache with an expired token auth_methods._cached_google_id_token["token"] = "expired_token_sync" auth_methods._cached_google_id_token["expires_at"] = ( - time.time() - 100 + auth_methods.datetime.now(auth_methods.timezone.utc) - auth_methods.timedelta(seconds=100) ) # Expired mock_creds_instance = MagicMock() mock_creds_instance.id_token = MOCK_GOOGLE_ID_TOKEN # New token after refresh + type(mock_creds_instance).expiry = PropertyMock(return_value=MOCK_EXPIRY_DATETIME) mock_sync_default.return_value = (mock_creds_instance, MOCK_PROJECT_ID) - mock_decode_expiry.return_value = MOCK_EXPIRY_TIMESTAMP mock_session_instance = MagicMock() mock_auth_session_class.return_value = mock_session_instance @@ -305,10 +290,7 @@ def test_get_google_id_token_refreshes_expired_cache( mock_creds_instance.refresh.assert_called_once_with(mock_sync_request_instance) assert token == f"{auth_methods.BEARER_TOKEN_PREFIX}{MOCK_GOOGLE_ID_TOKEN}" assert auth_methods._cached_google_id_token["token"] == MOCK_GOOGLE_ID_TOKEN - assert ( - auth_methods._cached_google_id_token["expires_at"] == MOCK_EXPIRY_TIMESTAMP - ) - mock_decode_expiry.assert_called_once_with(MOCK_GOOGLE_ID_TOKEN) + assert auth_methods._cached_google_id_token["expires_at"] == MOCK_EXPIRY_DATETIME @patch("toolbox_core.auth_methods.Request") @patch("toolbox_core.auth_methods.AuthorizedSession") @@ -319,6 +301,7 @@ def test_get_google_id_token_fetch_failure( """Tests error handling when fetching the token fails (no id_token returned).""" mock_creds_instance = MagicMock() mock_creds_instance.id_token = None # Simulate no ID token after refresh + type(mock_creds_instance).expiry = PropertyMock(return_value=MOCK_EXPIRY_DATETIME) # Still need expiry for update_cache mock_sync_default.return_value = (mock_creds_instance, MOCK_PROJECT_ID) mock_session_instance = MagicMock() @@ -362,7 +345,6 @@ def test_get_google_id_token_refresh_raises_exception( mock_sync_req_class.assert_called_once_with(mock_session_instance) mock_creds_instance.refresh.assert_called_once() - @patch("toolbox_core.auth_methods._decode_jwt_and_get_expiry") @patch("toolbox_core.auth_methods.Request") @patch("toolbox_core.auth_methods.AuthorizedSession") @patch("toolbox_core.auth_methods.google.auth.default") @@ -371,13 +353,12 @@ def test_get_google_id_token_no_expiry_info( mock_sync_default, mock_auth_session_class, mock_sync_req_class, - mock_decode_expiry, ): """Tests that a token without expiry info is still cached but effectively expired.""" mock_creds_instance = MagicMock() mock_creds_instance.id_token = MOCK_GOOGLE_ID_TOKEN + type(mock_creds_instance).expiry = PropertyMock(return_value=None) # Simulate no expiry info mock_sync_default.return_value = (mock_creds_instance, MOCK_PROJECT_ID) - mock_decode_expiry.return_value = None # Simulate no expiry info mock_session_instance = MagicMock() mock_auth_session_class.return_value = mock_session_instance @@ -395,4 +376,3 @@ def test_get_google_id_token_no_expiry_info( mock_sync_default.assert_called_once_with() mock_auth_session_class.assert_called_once_with(mock_creds_instance) mock_sync_req_class.assert_called_once_with(mock_session_instance) - mock_decode_expiry.assert_called_once_with(MOCK_GOOGLE_ID_TOKEN) From e0a789b72ac059d1259a4d871113172d152a7b79 Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Tue, 27 May 2025 14:53:44 +0530 Subject: [PATCH 7/8] lint --- .../src/toolbox_core/auth_methods.py | 5 +- .../toolbox-core/tests/test_auth_methods.py | 75 +++++++++++++------ 2 files changed, 55 insertions(+), 25 deletions(-) diff --git a/packages/toolbox-core/src/toolbox_core/auth_methods.py b/packages/toolbox-core/src/toolbox_core/auth_methods.py index cbc76fb4..4fb92ec6 100644 --- a/packages/toolbox-core/src/toolbox_core/auth_methods.py +++ b/packages/toolbox-core/src/toolbox_core/auth_methods.py @@ -71,7 +71,10 @@ def _is_cached_token_valid( return False # Ensure expires_at_value is timezone-aware (UTC). - if expires_at_value.tzinfo is None or expires_at_value.tzinfo.utcoffset(expires_at_value) is None: + if ( + expires_at_value.tzinfo is None + or expires_at_value.tzinfo.utcoffset(expires_at_value) is None + ): expires_at_value = expires_at_value.replace(tzinfo=timezone.utc) current_time_utc = datetime.now(timezone.utc) diff --git a/packages/toolbox-core/tests/test_auth_methods.py b/packages/toolbox-core/tests/test_auth_methods.py index 373280c0..68d0fef2 100644 --- a/packages/toolbox-core/tests/test_auth_methods.py +++ b/packages/toolbox-core/tests/test_auth_methods.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import AsyncMock, MagicMock, patch, PropertyMock +from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch import pytest @@ -22,7 +22,9 @@ MOCK_GOOGLE_ID_TOKEN = "test_id_token_123" MOCK_PROJECT_ID = "test-project" # A realistic expiry timestamp (e.g., 1 hour from now) -MOCK_EXPIRY_DATETIME = auth_methods.datetime.now(auth_methods.timezone.utc) + auth_methods.timedelta(hours=1) +MOCK_EXPIRY_DATETIME = auth_methods.datetime.now( + auth_methods.timezone.utc +) + auth_methods.timedelta(hours=1) # Expected exception messages from auth_methods.py @@ -56,7 +58,9 @@ async def test_aget_google_id_token_success_first_call( """Tests successful fetching of an async token on the first call.""" mock_creds_instance = AsyncMock() mock_creds_instance.id_token = MOCK_GOOGLE_ID_TOKEN - type(mock_creds_instance).expiry = PropertyMock(return_value=MOCK_EXPIRY_DATETIME) + type(mock_creds_instance).expiry = PropertyMock( + return_value=MOCK_EXPIRY_DATETIME + ) mock_default_async.return_value = (mock_creds_instance, MOCK_PROJECT_ID) mock_async_req_instance = MagicMock() @@ -85,9 +89,10 @@ async def test_aget_google_id_token_success_uses_cache( ): """Tests that subsequent calls use the cached token if valid.""" auth_methods._cached_google_id_token["token"] = MOCK_GOOGLE_ID_TOKEN - auth_methods._cached_google_id_token["expires_at"] = ( - auth_methods.datetime.now(auth_methods.timezone.utc) + - auth_methods.timedelta(seconds=auth_methods.CACHE_REFRESH_MARGIN_SECONDS + 100) + auth_methods._cached_google_id_token["expires_at"] = auth_methods.datetime.now( + auth_methods.timezone.utc + ) + auth_methods.timedelta( + seconds=auth_methods.CACHE_REFRESH_MARGIN_SECONDS + 100 ) # Ensure it's valid token = await auth_methods.aget_google_id_token() @@ -106,13 +111,17 @@ async def test_aget_google_id_token_refreshes_expired_cache( ): """Tests that an expired cached token triggers a refresh.""" auth_methods._cached_google_id_token["token"] = "expired_token" - auth_methods._cached_google_id_token["expires_at"] = ( - auth_methods.datetime.now(auth_methods.timezone.utc) - auth_methods.timedelta(seconds=100) + auth_methods._cached_google_id_token["expires_at"] = auth_methods.datetime.now( + auth_methods.timezone.utc + ) - auth_methods.timedelta( + seconds=100 ) # Expired mock_creds_instance = AsyncMock() mock_creds_instance.id_token = MOCK_GOOGLE_ID_TOKEN # New token after refresh - type(mock_creds_instance).expiry = PropertyMock(return_value=MOCK_EXPIRY_DATETIME) + type(mock_creds_instance).expiry = PropertyMock( + return_value=MOCK_EXPIRY_DATETIME + ) mock_default_async.return_value = (mock_creds_instance, MOCK_PROJECT_ID) mock_async_req_instance = MagicMock() @@ -125,8 +134,9 @@ async def test_aget_google_id_token_refreshes_expired_cache( mock_creds_instance.refresh.assert_called_once_with(mock_async_req_instance) assert token == f"{auth_methods.BEARER_TOKEN_PREFIX}{MOCK_GOOGLE_ID_TOKEN}" assert auth_methods._cached_google_id_token["token"] == MOCK_GOOGLE_ID_TOKEN - assert auth_methods._cached_google_id_token["expires_at"] == MOCK_EXPIRY_DATETIME - + assert ( + auth_methods._cached_google_id_token["expires_at"] == MOCK_EXPIRY_DATETIME + ) @pytest.mark.asyncio @patch("toolbox_core.auth_methods._aiohttp_requests.Request") @@ -137,7 +147,9 @@ async def test_aget_google_id_token_fetch_failure( """Tests error handling when fetching the token fails (no id_token returned).""" mock_creds_instance = AsyncMock() mock_creds_instance.id_token = None # Simulate no ID token after refresh - type(mock_creds_instance).expiry = PropertyMock(return_value=MOCK_EXPIRY_DATETIME) # Still need expiry for update_cache + type(mock_creds_instance).expiry = PropertyMock( + return_value=MOCK_EXPIRY_DATETIME + ) # Still need expiry for update_cache mock_default_async.return_value = (mock_creds_instance, MOCK_PROJECT_ID) mock_async_req_class.return_value = MagicMock() @@ -178,11 +190,13 @@ async def test_aget_google_id_token_no_expiry_info( """Tests that a token without expiry info is still cached but effectively expired.""" mock_creds_instance = AsyncMock() mock_creds_instance.id_token = MOCK_GOOGLE_ID_TOKEN - type(mock_creds_instance).expiry = PropertyMock(return_value=None) # Simulate no expiry info + type(mock_creds_instance).expiry = PropertyMock( + return_value=None + ) # Simulate no expiry info mock_default_async.return_value = (mock_creds_instance, MOCK_PROJECT_ID) mock_async_req_class.return_value = MagicMock() - + token = await auth_methods.aget_google_id_token() assert token == f"{auth_methods.BEARER_TOKEN_PREFIX}{MOCK_GOOGLE_ID_TOKEN}" @@ -208,7 +222,9 @@ def test_get_google_id_token_success_first_call( """Tests successful fetching of a sync token on the first call.""" mock_creds_instance = MagicMock() mock_creds_instance.id_token = MOCK_GOOGLE_ID_TOKEN - type(mock_creds_instance).expiry = PropertyMock(return_value=MOCK_EXPIRY_DATETIME) + type(mock_creds_instance).expiry = PropertyMock( + return_value=MOCK_EXPIRY_DATETIME + ) mock_sync_default.return_value = (mock_creds_instance, MOCK_PROJECT_ID) mock_session_instance = MagicMock() @@ -241,9 +257,10 @@ def test_get_google_id_token_success_uses_cache( ): """Tests that subsequent calls use the cached token if valid.""" auth_methods._cached_google_id_token["token"] = MOCK_GOOGLE_ID_TOKEN - auth_methods._cached_google_id_token["expires_at"] = ( - auth_methods.datetime.now(auth_methods.timezone.utc) + - auth_methods.timedelta(seconds=auth_methods.CACHE_REFRESH_MARGIN_SECONDS + 100) + auth_methods._cached_google_id_token["expires_at"] = auth_methods.datetime.now( + auth_methods.timezone.utc + ) + auth_methods.timedelta( + seconds=auth_methods.CACHE_REFRESH_MARGIN_SECONDS + 100 ) # Ensure it's valid token = auth_methods.get_google_id_token() @@ -267,13 +284,17 @@ def test_get_google_id_token_refreshes_expired_cache( """Tests that an expired cached token triggers a refresh.""" # Prime the cache with an expired token auth_methods._cached_google_id_token["token"] = "expired_token_sync" - auth_methods._cached_google_id_token["expires_at"] = ( - auth_methods.datetime.now(auth_methods.timezone.utc) - auth_methods.timedelta(seconds=100) + auth_methods._cached_google_id_token["expires_at"] = auth_methods.datetime.now( + auth_methods.timezone.utc + ) - auth_methods.timedelta( + seconds=100 ) # Expired mock_creds_instance = MagicMock() mock_creds_instance.id_token = MOCK_GOOGLE_ID_TOKEN # New token after refresh - type(mock_creds_instance).expiry = PropertyMock(return_value=MOCK_EXPIRY_DATETIME) + type(mock_creds_instance).expiry = PropertyMock( + return_value=MOCK_EXPIRY_DATETIME + ) mock_sync_default.return_value = (mock_creds_instance, MOCK_PROJECT_ID) mock_session_instance = MagicMock() @@ -290,7 +311,9 @@ def test_get_google_id_token_refreshes_expired_cache( mock_creds_instance.refresh.assert_called_once_with(mock_sync_request_instance) assert token == f"{auth_methods.BEARER_TOKEN_PREFIX}{MOCK_GOOGLE_ID_TOKEN}" assert auth_methods._cached_google_id_token["token"] == MOCK_GOOGLE_ID_TOKEN - assert auth_methods._cached_google_id_token["expires_at"] == MOCK_EXPIRY_DATETIME + assert ( + auth_methods._cached_google_id_token["expires_at"] == MOCK_EXPIRY_DATETIME + ) @patch("toolbox_core.auth_methods.Request") @patch("toolbox_core.auth_methods.AuthorizedSession") @@ -301,7 +324,9 @@ def test_get_google_id_token_fetch_failure( """Tests error handling when fetching the token fails (no id_token returned).""" mock_creds_instance = MagicMock() mock_creds_instance.id_token = None # Simulate no ID token after refresh - type(mock_creds_instance).expiry = PropertyMock(return_value=MOCK_EXPIRY_DATETIME) # Still need expiry for update_cache + type(mock_creds_instance).expiry = PropertyMock( + return_value=MOCK_EXPIRY_DATETIME + ) # Still need expiry for update_cache mock_sync_default.return_value = (mock_creds_instance, MOCK_PROJECT_ID) mock_session_instance = MagicMock() @@ -357,7 +382,9 @@ def test_get_google_id_token_no_expiry_info( """Tests that a token without expiry info is still cached but effectively expired.""" mock_creds_instance = MagicMock() mock_creds_instance.id_token = MOCK_GOOGLE_ID_TOKEN - type(mock_creds_instance).expiry = PropertyMock(return_value=None) # Simulate no expiry info + type(mock_creds_instance).expiry = PropertyMock( + return_value=None + ) # Simulate no expiry info mock_sync_default.return_value = (mock_creds_instance, MOCK_PROJECT_ID) mock_session_instance = MagicMock() From 9602bf68ff792dd95a3c54a0baa2c59dd34bd133 Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Tue, 27 May 2025 14:57:08 +0530 Subject: [PATCH 8/8] remove pyjwt from requirements --- packages/toolbox-core/pyproject.toml | 1 - packages/toolbox-core/requirements.txt | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/packages/toolbox-core/pyproject.toml b/packages/toolbox-core/pyproject.toml index 90e593dd..6a918b4e 100644 --- a/packages/toolbox-core/pyproject.toml +++ b/packages/toolbox-core/pyproject.toml @@ -12,7 +12,6 @@ authors = [ dependencies = [ "pydantic>=2.7.0,<3.0.0", "aiohttp>=3.8.6,<4.0.0", - "PyJWT>=2.0.0,<3.0.0", ] classifiers = [ diff --git a/packages/toolbox-core/requirements.txt b/packages/toolbox-core/requirements.txt index f91e76e4..d6dca066 100644 --- a/packages/toolbox-core/requirements.txt +++ b/packages/toolbox-core/requirements.txt @@ -1,3 +1,2 @@ aiohttp==3.11.18 -pydantic==2.11.4 -PyJWT==2.10.1 +pydantic==2.11.4 \ No newline at end of file