diff --git a/packages/toolbox-core/src/toolbox_core/auth_methods.py b/packages/toolbox-core/src/toolbox_core/auth_methods.py index 83f15774..4fb92ec6 100644 --- a/packages/toolbox-core/src/toolbox_core/auth_methods.py +++ b/packages/toolbox-core/src/toolbox_core/auth_methods.py @@ -12,21 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. -# The tokens obtained by these functions are formatted as "Bearer" tokens -# and are intended to be passed in the "Authorization" header of HTTP requests. -# -# Example User Experience: -# from toolbox_core import auth_methods -# -# auth_token_provider = auth_methods.aget_google_id_token -# toolbox = ToolboxClient( -# URL, -# client_headers={"Authorization": auth_token_provider}, -# ) -# tools = await toolbox.load_toolset() +""" +This module provides functions to obtain Google ID tokens, formatted as "Bearer" tokens, +for use in the "Authorization" header of HTTP requests. + +Example User Experience: +from toolbox_core import auth_methods +auth_token_provider = auth_methods.aget_google_id_token +toolbox = ToolboxClient( + URL, + client_headers={"Authorization": auth_token_provider}, +) +tools = await toolbox.load_toolset() +""" +from datetime import datetime, timedelta, timezone from functools import partial +from typing import Any, Dict, Optional import google.auth from google.auth._credentials_async import Credentials @@ -34,26 +37,79 @@ from google.auth.transport import _aiohttp_requests from google.auth.transport.requests import AuthorizedSession, Request +# --- Constants and Configuration --- +# Prefix for Authorization header tokens +BEARER_TOKEN_PREFIX = "Bearer " +# Margin in seconds to refresh token before its actual expiry +CACHE_REFRESH_MARGIN_SECONDS = 60 + + +# --- Global Cache Storage --- +# Stores the cached Google ID token and its expiry timestamp +_cached_google_id_token: Dict[str, Any] = {"token": None, "expires_at": 0} -async def aget_google_id_token(): + +# --- Helper Functions --- +def _is_cached_token_valid( + cache: Dict[str, Any], margin_seconds: int = CACHE_REFRESH_MARGIN_SECONDS +) -> bool: """ - Asynchronously fetches a Google ID token. + Checks if a token in the cache is valid (exists and not expired). - The token is formatted as a 'Bearer' token string and is suitable for use - in an HTTP Authorization header. This function uses Application Default - Credentials. + Args: + cache: The dictionary containing 'token' and 'expires_at'. + margin_seconds: The time in seconds before expiry to consider the token invalid. Returns: - A string in the format "Bearer ". + True if the token is valid, False otherwise. + """ + if not cache.get("token"): + return False + + expires_at_value = cache.get("expires_at") + if not isinstance(expires_at_value, datetime): + return False + + # Ensure expires_at_value is timezone-aware (UTC). + if ( + expires_at_value.tzinfo is None + or expires_at_value.tzinfo.utcoffset(expires_at_value) is None + ): + expires_at_value = expires_at_value.replace(tzinfo=timezone.utc) + + current_time_utc = datetime.now(timezone.utc) + if current_time_utc + timedelta(seconds=margin_seconds) < expires_at_value: + return True + + return False + + +def _update_token_cache( + cache: Dict[str, Any], new_id_token: Optional[str], expiry: Optional[datetime] +) -> None: + """ + Updates the global token cache with a new token and its expiry. + + Args: + cache: The dictionary containing 'token' and 'expires_at'. + new_id_token: The new ID token string to cache. """ - creds, _ = default_async() - await creds.refresh(_aiohttp_requests.Request()) - creds.before_request = partial(Credentials.before_request, creds) - token = creds.id_token - return f"Bearer {token}" + if new_id_token: + cache["token"] = new_id_token + expiry_timestamp = expiry + if expiry_timestamp: + cache["expires_at"] = expiry_timestamp + else: + # If expiry can't be determined, treat as immediately expired to force refresh + cache["expires_at"] = 0 + else: + # Clear cache if no new token is provided + cache["token"] = None + cache["expires_at"] = 0 -def get_google_id_token(): +# --- Public API Functions --- +def get_google_id_token() -> str: """ Synchronously fetches a Google ID token. @@ -63,10 +119,53 @@ def get_google_id_token(): Returns: A string in the format "Bearer ". + + Raises: + Exception: If fetching the Google ID token fails. """ + if _is_cached_token_valid(_cached_google_id_token): + return BEARER_TOKEN_PREFIX + _cached_google_id_token["token"] + credentials, _ = google.auth.default() session = AuthorizedSession(credentials) request = Request(session) credentials.refresh(request) - token = credentials.id_token - return f"Bearer {token}" + new_id_token = getattr(credentials, "id_token", None) + expiry = getattr(credentials, "expiry") + + _update_token_cache(_cached_google_id_token, new_id_token, expiry) + if new_id_token: + return BEARER_TOKEN_PREFIX + new_id_token + else: + raise Exception("Failed to fetch Google ID token.") + + +async def aget_google_id_token() -> str: + """ + Asynchronously fetches a Google ID token. + + The token is formatted as a 'Bearer' token string and is suitable for use + in an HTTP Authorization header. This function uses Application Default + Credentials. + + Returns: + A string in the format "Bearer ". + + Raises: + Exception: If fetching the Google ID token fails. + """ + if _is_cached_token_valid(_cached_google_id_token): + return BEARER_TOKEN_PREFIX + _cached_google_id_token["token"] + + credentials, _ = default_async() + await credentials.refresh(_aiohttp_requests.Request()) + credentials.before_request = partial(Credentials.before_request, credentials) + new_id_token = getattr(credentials, "id_token", None) + expiry = getattr(credentials, "expiry") + + _update_token_cache(_cached_google_id_token, new_id_token, expiry) + + if new_id_token: + return BEARER_TOKEN_PREFIX + new_id_token + else: + raise Exception("Failed to fetch async Google ID token.") diff --git a/packages/toolbox-core/tests/test_auth_methods.py b/packages/toolbox-core/tests/test_auth_methods.py index e316726b..68d0fef2 100644 --- a/packages/toolbox-core/tests/test_auth_methods.py +++ b/packages/toolbox-core/tests/test_auth_methods.py @@ -12,191 +12,394 @@ # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch -import google.auth.exceptions import pytest from toolbox_core import auth_methods # Constants for test values -MOCK_ASYNC_ID_TOKEN = "test_async_id_token_123" -MOCK_SYNC_ID_TOKEN = "test_sync_id_token_456" +MOCK_GOOGLE_ID_TOKEN = "test_id_token_123" MOCK_PROJECT_ID = "test-project" - -# Error Messages -ADC_NOT_FOUND_MSG = "ADC not found" -TOKEN_REFRESH_FAILED_MSG = "Token refresh failed" -SYNC_ADC_NOT_FOUND_MSG = "Sync ADC not found" -SYNC_TOKEN_REFRESH_FAILED_MSG = "Sync token refresh failed" - - -@pytest.mark.asyncio -@patch("toolbox_core.auth_methods.partial") -@patch("toolbox_core.auth_methods._aiohttp_requests.Request") -@patch("toolbox_core.auth_methods.Credentials") -@patch("toolbox_core.auth_methods.default_async") -async def test_aget_google_id_token_success( - mock_default_async, - mock_credentials_class, - mock_aiohttp_request_class, - mock_partial, -): - """ - Test aget_google_id_token successfully retrieves and formats a token. - """ - mock_creds_instance = AsyncMock() - mock_creds_instance.id_token = MOCK_ASYNC_ID_TOKEN - mock_default_async.return_value = (mock_creds_instance, MOCK_PROJECT_ID) - - mock_aio_request_instance = MagicMock() - mock_aiohttp_request_class.return_value = mock_aio_request_instance - - mock_unbound_before_request = MagicMock() - mock_credentials_class.before_request = mock_unbound_before_request - - mock_partial_object = MagicMock() - mock_partial.return_value = mock_partial_object - - token = await auth_methods.aget_google_id_token() - - mock_default_async.assert_called_once_with() - mock_aiohttp_request_class.assert_called_once_with() - mock_creds_instance.refresh.assert_called_once_with(mock_aio_request_instance) - - mock_partial.assert_called_once_with( - mock_unbound_before_request, mock_creds_instance - ) - assert mock_creds_instance.before_request == mock_partial_object - assert token == f"Bearer {MOCK_ASYNC_ID_TOKEN}" - - -@pytest.mark.asyncio -@patch("toolbox_core.auth_methods.default_async") -async def test_aget_google_id_token_default_credentials_error(mock_default_async): - """ - Test aget_google_id_token handles DefaultCredentialsError. - """ - mock_default_async.side_effect = google.auth.exceptions.DefaultCredentialsError( - ADC_NOT_FOUND_MSG - ) - - with pytest.raises( - google.auth.exceptions.DefaultCredentialsError, match=ADC_NOT_FOUND_MSG +# A realistic expiry timestamp (e.g., 1 hour from now) +MOCK_EXPIRY_DATETIME = auth_methods.datetime.now( + auth_methods.timezone.utc +) + auth_methods.timedelta(hours=1) + + +# Expected exception messages from auth_methods.py +FETCH_TOKEN_FAILURE_MSG = "Failed to fetch Google ID token." +FETCH_ASYNC_TOKEN_FAILURE_MSG = "Failed to fetch async Google ID token." +# These will now match the actual messages from refresh.side_effect +NETWORK_ERROR_MSG = "Network error" +TIMEOUT_ERROR_MSG = "Timeout error" + + +@pytest.fixture(autouse=True) +def reset_cache_after_each_test(): + """Fixture to reset the cache before each test.""" + # Store initial state + original_cache_state = auth_methods._cached_google_id_token.copy() + auth_methods._cached_google_id_token = {"token": None, "expires_at": 0} + yield + # Restore initial state (optional, but good for isolation) + auth_methods._cached_google_id_token = original_cache_state + + +class TestAsyncAuthMethods: + """Tests for asynchronous Google ID token fetching.""" + + @pytest.mark.asyncio + @patch("toolbox_core.auth_methods._aiohttp_requests.Request") + @patch("toolbox_core.auth_methods.default_async", new_callable=MagicMock) + async def test_aget_google_id_token_success_first_call( + self, mock_default_async, mock_async_req_class ): - await auth_methods.aget_google_id_token() - - mock_default_async.assert_called_once_with() - - -@pytest.mark.asyncio -@patch("toolbox_core.auth_methods._aiohttp_requests.Request") -@patch("toolbox_core.auth_methods.default_async") -async def test_aget_google_id_token_refresh_error( - mock_default_async, - mock_aiohttp_request_class, -): - """ - Test aget_google_id_token handles RefreshError. - """ - mock_creds_instance = AsyncMock() - mock_creds_instance.refresh.side_effect = google.auth.exceptions.RefreshError( - TOKEN_REFRESH_FAILED_MSG - ) - mock_default_async.return_value = (mock_creds_instance, MOCK_PROJECT_ID) - - mock_aio_request_instance = MagicMock() - mock_aiohttp_request_class.return_value = mock_aio_request_instance - - with pytest.raises( - google.auth.exceptions.RefreshError, match=TOKEN_REFRESH_FAILED_MSG + """Tests successful fetching of an async token on the first call.""" + mock_creds_instance = AsyncMock() + mock_creds_instance.id_token = MOCK_GOOGLE_ID_TOKEN + type(mock_creds_instance).expiry = PropertyMock( + return_value=MOCK_EXPIRY_DATETIME + ) + mock_default_async.return_value = (mock_creds_instance, MOCK_PROJECT_ID) + + mock_async_req_instance = MagicMock() + mock_async_req_class.return_value = mock_async_req_instance + token = await auth_methods.aget_google_id_token() + + mock_default_async.assert_called_once_with() + mock_async_req_class.assert_called_once_with() + mock_creds_instance.refresh.assert_called_once_with(mock_async_req_instance) + + assert ( + mock_creds_instance.before_request.func + is auth_methods.Credentials.before_request + ) + assert token == f"{auth_methods.BEARER_TOKEN_PREFIX}{MOCK_GOOGLE_ID_TOKEN}" + assert auth_methods._cached_google_id_token["token"] == MOCK_GOOGLE_ID_TOKEN + assert ( + auth_methods._cached_google_id_token["expires_at"] == MOCK_EXPIRY_DATETIME + ) + + @pytest.mark.asyncio + @patch("toolbox_core.auth_methods._aiohttp_requests.Request") + @patch("toolbox_core.auth_methods.default_async", new_callable=MagicMock) + async def test_aget_google_id_token_success_uses_cache( + self, mock_default_async, mock_async_req_class ): - await auth_methods.aget_google_id_token() - - mock_default_async.assert_called_once_with() - mock_aiohttp_request_class.assert_called_once_with() - mock_creds_instance.refresh.assert_called_once_with(mock_aio_request_instance) - - -# --- Synchronous Tests --- - - -@patch("toolbox_core.auth_methods.Request") -@patch("toolbox_core.auth_methods.AuthorizedSession") -@patch("toolbox_core.auth_methods.google.auth.default") -def test_get_google_id_token_success( - mock_google_auth_default, - mock_authorized_session_class, - mock_request_class, -): - """ - Test get_google_id_token successfully retrieves and formats a token. - """ - mock_creds_instance = MagicMock() - mock_creds_instance.id_token = MOCK_SYNC_ID_TOKEN - mock_google_auth_default.return_value = (mock_creds_instance, MOCK_PROJECT_ID) - - mock_session_instance = MagicMock() - mock_authorized_session_class.return_value = mock_session_instance - - mock_request_instance = MagicMock() - mock_request_class.return_value = mock_request_instance - - token = auth_methods.get_google_id_token() - - mock_google_auth_default.assert_called_once_with() - mock_authorized_session_class.assert_called_once_with(mock_creds_instance) - mock_request_class.assert_called_once_with(mock_session_instance) - mock_creds_instance.refresh.assert_called_once_with(mock_request_instance) - assert token == f"Bearer {MOCK_SYNC_ID_TOKEN}" - - -@patch("toolbox_core.auth_methods.google.auth.default") -def test_get_google_id_token_default_credentials_error(mock_google_auth_default): - """ - Test get_google_id_token handles DefaultCredentialsError. - """ - mock_google_auth_default.side_effect = ( - google.auth.exceptions.DefaultCredentialsError(SYNC_ADC_NOT_FOUND_MSG) - ) - - with pytest.raises( - google.auth.exceptions.DefaultCredentialsError, match=SYNC_ADC_NOT_FOUND_MSG + """Tests that subsequent calls use the cached token if valid.""" + auth_methods._cached_google_id_token["token"] = MOCK_GOOGLE_ID_TOKEN + auth_methods._cached_google_id_token["expires_at"] = auth_methods.datetime.now( + auth_methods.timezone.utc + ) + auth_methods.timedelta( + seconds=auth_methods.CACHE_REFRESH_MARGIN_SECONDS + 100 + ) # Ensure it's valid + + token = await auth_methods.aget_google_id_token() + + mock_default_async.assert_not_called() + mock_async_req_class.assert_not_called() + + assert token == f"{auth_methods.BEARER_TOKEN_PREFIX}{MOCK_GOOGLE_ID_TOKEN}" + assert auth_methods._cached_google_id_token["token"] == MOCK_GOOGLE_ID_TOKEN + + @pytest.mark.asyncio + @patch("toolbox_core.auth_methods._aiohttp_requests.Request") + @patch("toolbox_core.auth_methods.default_async", new_callable=MagicMock) + async def test_aget_google_id_token_refreshes_expired_cache( + self, mock_default_async, mock_async_req_class ): - auth_methods.get_google_id_token() - - mock_google_auth_default.assert_called_once_with() - - -@patch("toolbox_core.auth_methods.Request") -@patch("toolbox_core.auth_methods.AuthorizedSession") -@patch("toolbox_core.auth_methods.google.auth.default") -def test_get_google_id_token_refresh_error( - mock_google_auth_default, - mock_authorized_session_class, - mock_request_class, -): - """ - Test get_google_id_token handles RefreshError. - """ - mock_creds_instance = MagicMock() - mock_creds_instance.refresh.side_effect = google.auth.exceptions.RefreshError( - SYNC_TOKEN_REFRESH_FAILED_MSG - ) - mock_google_auth_default.return_value = (mock_creds_instance, MOCK_PROJECT_ID) - - mock_session_instance = MagicMock() - mock_authorized_session_class.return_value = mock_session_instance - - mock_request_instance = MagicMock() - mock_request_class.return_value = mock_request_instance - - with pytest.raises( - google.auth.exceptions.RefreshError, match=SYNC_TOKEN_REFRESH_FAILED_MSG + """Tests that an expired cached token triggers a refresh.""" + auth_methods._cached_google_id_token["token"] = "expired_token" + auth_methods._cached_google_id_token["expires_at"] = auth_methods.datetime.now( + auth_methods.timezone.utc + ) - auth_methods.timedelta( + seconds=100 + ) # Expired + + mock_creds_instance = AsyncMock() + mock_creds_instance.id_token = MOCK_GOOGLE_ID_TOKEN # New token after refresh + type(mock_creds_instance).expiry = PropertyMock( + return_value=MOCK_EXPIRY_DATETIME + ) + mock_default_async.return_value = (mock_creds_instance, MOCK_PROJECT_ID) + + mock_async_req_instance = MagicMock() + mock_async_req_class.return_value = mock_async_req_instance + + token = await auth_methods.aget_google_id_token() + + mock_default_async.assert_called_once_with() + mock_async_req_class.assert_called_once_with() + mock_creds_instance.refresh.assert_called_once_with(mock_async_req_instance) + assert token == f"{auth_methods.BEARER_TOKEN_PREFIX}{MOCK_GOOGLE_ID_TOKEN}" + assert auth_methods._cached_google_id_token["token"] == MOCK_GOOGLE_ID_TOKEN + assert ( + auth_methods._cached_google_id_token["expires_at"] == MOCK_EXPIRY_DATETIME + ) + + @pytest.mark.asyncio + @patch("toolbox_core.auth_methods._aiohttp_requests.Request") + @patch("toolbox_core.auth_methods.default_async", new_callable=MagicMock) + async def test_aget_google_id_token_fetch_failure( + self, mock_default_async, mock_async_req_class ): - auth_methods.get_google_id_token() - - mock_google_auth_default.assert_called_once_with() - mock_authorized_session_class.assert_called_once_with(mock_creds_instance) - mock_request_class.assert_called_once_with(mock_session_instance) - mock_creds_instance.refresh.assert_called_once_with(mock_request_instance) + """Tests error handling when fetching the token fails (no id_token returned).""" + mock_creds_instance = AsyncMock() + mock_creds_instance.id_token = None # Simulate no ID token after refresh + type(mock_creds_instance).expiry = PropertyMock( + return_value=MOCK_EXPIRY_DATETIME + ) # Still need expiry for update_cache + mock_default_async.return_value = (mock_creds_instance, MOCK_PROJECT_ID) + mock_async_req_class.return_value = MagicMock() + + with pytest.raises(Exception, match=FETCH_ASYNC_TOKEN_FAILURE_MSG): + await auth_methods.aget_google_id_token() + + assert auth_methods._cached_google_id_token["token"] is None + assert auth_methods._cached_google_id_token["expires_at"] == 0 + mock_async_req_class.assert_called_once_with() + mock_creds_instance.refresh.assert_called_once() + + @pytest.mark.asyncio + @patch("toolbox_core.auth_methods._aiohttp_requests.Request") + @patch("toolbox_core.auth_methods.default_async", new_callable=MagicMock) + async def test_aget_google_id_token_refresh_raises_exception( + self, mock_default_async, mock_async_req_class + ): + """Tests exception handling when credentials refresh fails.""" + mock_creds_instance = AsyncMock() + mock_creds_instance.refresh.side_effect = Exception(NETWORK_ERROR_MSG) + mock_default_async.return_value = (mock_creds_instance, MOCK_PROJECT_ID) + mock_async_req_class.return_value = MagicMock() + + with pytest.raises(Exception, match=NETWORK_ERROR_MSG): + await auth_methods.aget_google_id_token() + + assert auth_methods._cached_google_id_token["token"] is None + assert auth_methods._cached_google_id_token["expires_at"] == 0 + mock_async_req_class.assert_called_once_with() + mock_creds_instance.refresh.assert_called_once() + + @pytest.mark.asyncio + @patch("toolbox_core.auth_methods._aiohttp_requests.Request") + @patch("toolbox_core.auth_methods.default_async", new_callable=MagicMock) + async def test_aget_google_id_token_no_expiry_info( + self, mock_default_async, mock_async_req_class + ): + """Tests that a token without expiry info is still cached but effectively expired.""" + mock_creds_instance = AsyncMock() + mock_creds_instance.id_token = MOCK_GOOGLE_ID_TOKEN + type(mock_creds_instance).expiry = PropertyMock( + return_value=None + ) # Simulate no expiry info + mock_default_async.return_value = (mock_creds_instance, MOCK_PROJECT_ID) + + mock_async_req_class.return_value = MagicMock() + + token = await auth_methods.aget_google_id_token() + + assert token == f"{auth_methods.BEARER_TOKEN_PREFIX}{MOCK_GOOGLE_ID_TOKEN}" + assert auth_methods._cached_google_id_token["token"] == MOCK_GOOGLE_ID_TOKEN + assert ( + auth_methods._cached_google_id_token["expires_at"] == 0 + ) # Should be 0 if no expiry + mock_async_req_class.assert_called_once_with() + + +class TestSyncAuthMethods: + """Tests for synchronous Google ID token fetching.""" + + @patch("toolbox_core.auth_methods.Request") + @patch("toolbox_core.auth_methods.AuthorizedSession") + @patch("toolbox_core.auth_methods.google.auth.default") + def test_get_google_id_token_success_first_call( + self, + mock_sync_default, + mock_auth_session_class, + mock_sync_req_class, + ): + """Tests successful fetching of a sync token on the first call.""" + mock_creds_instance = MagicMock() + mock_creds_instance.id_token = MOCK_GOOGLE_ID_TOKEN + type(mock_creds_instance).expiry = PropertyMock( + return_value=MOCK_EXPIRY_DATETIME + ) + mock_sync_default.return_value = (mock_creds_instance, MOCK_PROJECT_ID) + + mock_session_instance = MagicMock() + mock_auth_session_class.return_value = mock_session_instance + + mock_sync_request_instance = MagicMock() + mock_sync_req_class.return_value = mock_sync_request_instance + + token = auth_methods.get_google_id_token() + + mock_sync_default.assert_called_once_with() + mock_auth_session_class.assert_called_once_with(mock_creds_instance) + mock_sync_req_class.assert_called_once_with(mock_session_instance) + mock_creds_instance.refresh.assert_called_once_with(mock_sync_request_instance) + + assert token == f"{auth_methods.BEARER_TOKEN_PREFIX}{MOCK_GOOGLE_ID_TOKEN}" + assert auth_methods._cached_google_id_token["token"] == MOCK_GOOGLE_ID_TOKEN + assert ( + auth_methods._cached_google_id_token["expires_at"] == MOCK_EXPIRY_DATETIME + ) + + @patch("toolbox_core.auth_methods.Request") + @patch("toolbox_core.auth_methods.AuthorizedSession") + @patch("toolbox_core.auth_methods.google.auth.default") + def test_get_google_id_token_success_uses_cache( + self, + mock_sync_default, + mock_auth_session_class, + mock_sync_req_class, + ): + """Tests that subsequent calls use the cached token if valid.""" + auth_methods._cached_google_id_token["token"] = MOCK_GOOGLE_ID_TOKEN + auth_methods._cached_google_id_token["expires_at"] = auth_methods.datetime.now( + auth_methods.timezone.utc + ) + auth_methods.timedelta( + seconds=auth_methods.CACHE_REFRESH_MARGIN_SECONDS + 100 + ) # Ensure it's valid + + token = auth_methods.get_google_id_token() + + mock_sync_default.assert_not_called() + mock_auth_session_class.assert_not_called() + mock_sync_req_class.assert_not_called() + + assert token == f"{auth_methods.BEARER_TOKEN_PREFIX}{MOCK_GOOGLE_ID_TOKEN}" + assert auth_methods._cached_google_id_token["token"] == MOCK_GOOGLE_ID_TOKEN + + @patch("toolbox_core.auth_methods.Request") + @patch("toolbox_core.auth_methods.AuthorizedSession") + @patch("toolbox_core.auth_methods.google.auth.default") + def test_get_google_id_token_refreshes_expired_cache( + self, + mock_sync_default, + mock_auth_session_class, + mock_sync_req_class, + ): + """Tests that an expired cached token triggers a refresh.""" + # Prime the cache with an expired token + auth_methods._cached_google_id_token["token"] = "expired_token_sync" + auth_methods._cached_google_id_token["expires_at"] = auth_methods.datetime.now( + auth_methods.timezone.utc + ) - auth_methods.timedelta( + seconds=100 + ) # Expired + + mock_creds_instance = MagicMock() + mock_creds_instance.id_token = MOCK_GOOGLE_ID_TOKEN # New token after refresh + type(mock_creds_instance).expiry = PropertyMock( + return_value=MOCK_EXPIRY_DATETIME + ) + mock_sync_default.return_value = (mock_creds_instance, MOCK_PROJECT_ID) + + mock_session_instance = MagicMock() + mock_auth_session_class.return_value = mock_session_instance + + mock_sync_request_instance = MagicMock() + mock_sync_req_class.return_value = mock_sync_request_instance + + token = auth_methods.get_google_id_token() + + mock_sync_default.assert_called_once_with() + mock_auth_session_class.assert_called_once_with(mock_creds_instance) + mock_sync_req_class.assert_called_once_with(mock_session_instance) + mock_creds_instance.refresh.assert_called_once_with(mock_sync_request_instance) + assert token == f"{auth_methods.BEARER_TOKEN_PREFIX}{MOCK_GOOGLE_ID_TOKEN}" + assert auth_methods._cached_google_id_token["token"] == MOCK_GOOGLE_ID_TOKEN + assert ( + auth_methods._cached_google_id_token["expires_at"] == MOCK_EXPIRY_DATETIME + ) + + @patch("toolbox_core.auth_methods.Request") + @patch("toolbox_core.auth_methods.AuthorizedSession") + @patch("toolbox_core.auth_methods.google.auth.default") + def test_get_google_id_token_fetch_failure( + self, mock_sync_default, mock_auth_session_class, mock_sync_req_class + ): + """Tests error handling when fetching the token fails (no id_token returned).""" + mock_creds_instance = MagicMock() + mock_creds_instance.id_token = None # Simulate no ID token after refresh + type(mock_creds_instance).expiry = PropertyMock( + return_value=MOCK_EXPIRY_DATETIME + ) # Still need expiry for update_cache + mock_sync_default.return_value = (mock_creds_instance, MOCK_PROJECT_ID) + + mock_session_instance = MagicMock() + mock_auth_session_class.return_value = mock_session_instance + + mock_sync_req_class.return_value = MagicMock() + + with pytest.raises(Exception, match=FETCH_TOKEN_FAILURE_MSG): + auth_methods.get_google_id_token() + + assert auth_methods._cached_google_id_token["token"] is None + assert auth_methods._cached_google_id_token["expires_at"] == 0 + mock_sync_default.assert_called_once_with() + mock_auth_session_class.assert_called_once_with(mock_creds_instance) + mock_sync_req_class.assert_called_once_with(mock_session_instance) + mock_creds_instance.refresh.assert_called_once() + + @patch("toolbox_core.auth_methods.Request") + @patch("toolbox_core.auth_methods.AuthorizedSession") + @patch("toolbox_core.auth_methods.google.auth.default") + def test_get_google_id_token_refresh_raises_exception( + self, mock_sync_default, mock_auth_session_class, mock_sync_req_class + ): + """Tests exception handling when credentials refresh fails.""" + mock_creds_instance = MagicMock() + mock_creds_instance.refresh.side_effect = Exception(TIMEOUT_ERROR_MSG) + mock_sync_default.return_value = (mock_creds_instance, MOCK_PROJECT_ID) + + mock_session_instance = MagicMock() + mock_auth_session_class.return_value = mock_session_instance + + mock_sync_req_class.return_value = MagicMock() + + with pytest.raises(Exception, match=TIMEOUT_ERROR_MSG): + auth_methods.get_google_id_token() + + assert auth_methods._cached_google_id_token["token"] is None + assert auth_methods._cached_google_id_token["expires_at"] == 0 + mock_sync_default.assert_called_once_with() + mock_auth_session_class.assert_called_once_with(mock_creds_instance) + mock_sync_req_class.assert_called_once_with(mock_session_instance) + mock_creds_instance.refresh.assert_called_once() + + @patch("toolbox_core.auth_methods.Request") + @patch("toolbox_core.auth_methods.AuthorizedSession") + @patch("toolbox_core.auth_methods.google.auth.default") + def test_get_google_id_token_no_expiry_info( + self, + mock_sync_default, + mock_auth_session_class, + mock_sync_req_class, + ): + """Tests that a token without expiry info is still cached but effectively expired.""" + mock_creds_instance = MagicMock() + mock_creds_instance.id_token = MOCK_GOOGLE_ID_TOKEN + type(mock_creds_instance).expiry = PropertyMock( + return_value=None + ) # Simulate no expiry info + mock_sync_default.return_value = (mock_creds_instance, MOCK_PROJECT_ID) + + mock_session_instance = MagicMock() + mock_auth_session_class.return_value = mock_session_instance + + mock_sync_request_instance = MagicMock() + mock_sync_req_class.return_value = mock_sync_request_instance + + token = auth_methods.get_google_id_token() + + assert token == f"{auth_methods.BEARER_TOKEN_PREFIX}{MOCK_GOOGLE_ID_TOKEN}" + assert auth_methods._cached_google_id_token["token"] == MOCK_GOOGLE_ID_TOKEN + assert ( + auth_methods._cached_google_id_token["expires_at"] == 0 + ) # Should be 0 if no expiry + mock_sync_default.assert_called_once_with() + mock_auth_session_class.assert_called_once_with(mock_creds_instance) + mock_sync_req_class.assert_called_once_with(mock_session_instance)