Skip to content

feat: Cache google id tokens #254

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jun 2, 2025
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/toolbox-core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ authors = [
dependencies = [
"pydantic>=2.7.0,<3.0.0",
"aiohttp>=3.8.6,<4.0.0",
"PyJWT>=2.0.0,<3.0.0",
]

classifiers = [
Expand Down
1 change: 1 addition & 0 deletions packages/toolbox-core/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
aiohttp==3.11.18
pydantic==2.11.4
PyJWT==2.10.1
157 changes: 131 additions & 26 deletions packages/toolbox-core/src/toolbox_core/auth_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,48 +12,112 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# The tokens obtained by these functions are formatted as "Bearer" tokens
# and are intended to be passed in the "Authorization" header of HTTP requests.
#
# Example User Experience:
# from toolbox_core import auth_methods
#
# auth_token_provider = auth_methods.aget_google_id_token
# toolbox = ToolboxClient(
# URL,
# client_headers={"Authorization": auth_token_provider},
# )
# tools = await toolbox.load_toolset()
"""
This module provides functions to obtain Google ID tokens, formatted as "Bearer" tokens,
for use in the "Authorization" header of HTTP requests.

Example User Experience:
from toolbox_core import auth_methods

auth_token_provider = auth_methods.aget_google_id_token
toolbox = ToolboxClient(
URL,
client_headers={"Authorization": auth_token_provider},
)
tools = await toolbox.load_toolset()
"""

import time
from functools import partial
from typing import Any, Dict, Optional

import google.auth
import jwt
from google.auth._credentials_async import Credentials
from google.auth._default_async import default_async
from google.auth.transport import _aiohttp_requests
from google.auth.transport.requests import AuthorizedSession, Request

# --- Constants and Configuration ---
# Prefix for Authorization header tokens
BEARER_TOKEN_PREFIX = "Bearer "
# Margin in seconds to refresh token before its actual expiry
CACHE_REFRESH_MARGIN_SECONDS = 60


# --- Global Cache Storage ---
# Stores the cached Google ID token and its expiry timestamp
_cached_google_id_token: Dict[str, Any] = {"token": None, "expires_at": 0}


async def aget_google_id_token():
# --- Helper Functions ---
def _decode_jwt_and_get_expiry(id_token: str) -> Optional[float]:
"""
Asynchronously fetches a Google ID token.
Decodes a JWT and extracts the 'exp' (expiration) claim.

The token is formatted as a 'Bearer' token string and is suitable for use
in an HTTP Authorization header. This function uses Application Default
Credentials.
Args:
id_token: The JWT string to decode.

Returns:
A string in the format "Bearer <google_id_token>".
The 'exp' timestamp as a float if present and decoding is successful,
otherwise None.
"""
try:
decoded_token = jwt.decode(
id_token, options={"verify_signature": False, "verify_aud": False}
)
return decoded_token.get("exp")
except jwt.PyJWTError:
return None


def _is_cached_token_valid(
cache: Dict[str, Any], margin_seconds: int = CACHE_REFRESH_MARGIN_SECONDS
) -> bool:
"""
Checks if a token in the cache is valid (exists and not expired).

Args:
cache: The dictionary containing 'token' and 'expires_at'.
margin_seconds: The time in seconds before expiry to consider the token invalid.

Returns:
True if the token is valid, False otherwise.
"""
if not cache.get("token"):
return False

expires_at = cache.get("expires_at")
if not isinstance(expires_at, (int, float)) or expires_at <= 0:
return False

return time.time() < (expires_at - margin_seconds)


def _update_token_cache(cache: Dict[str, Any], new_id_token: Optional[str]) -> None:
"""
Updates the global token cache with a new token and its expiry.

Args:
cache: The dictionary containing 'token' and 'expires_at'.
new_id_token: The new ID token string to cache.
"""
creds, _ = default_async()
await creds.refresh(_aiohttp_requests.Request())
creds.before_request = partial(Credentials.before_request, creds)
token = creds.id_token
return f"Bearer {token}"
if new_id_token:
cache["token"] = new_id_token
expiry_timestamp = _decode_jwt_and_get_expiry(new_id_token)
if expiry_timestamp:
cache["expires_at"] = expiry_timestamp
else:
# If expiry can't be determined, treat as immediately expired to force refresh
cache["expires_at"] = 0
else:
# Clear cache if no new token is provided
cache["token"] = None
cache["expires_at"] = 0


def get_google_id_token():
# --- Public API Functions ---
def get_google_id_token() -> str:
"""
Synchronously fetches a Google ID token.

Expand All @@ -63,10 +127,51 @@ def get_google_id_token():

Returns:
A string in the format "Bearer <google_id_token>".

Raises:
Exception: If fetching the Google ID token fails.
"""
if _is_cached_token_valid(_cached_google_id_token):
return BEARER_TOKEN_PREFIX + _cached_google_id_token["token"]

credentials, _ = google.auth.default()
session = AuthorizedSession(credentials)
request = Request(session)
credentials.refresh(request)
token = credentials.id_token
return f"Bearer {token}"
new_id_token = getattr(credentials, "id_token", None)

_update_token_cache(_cached_google_id_token, new_id_token)
if new_id_token:
return BEARER_TOKEN_PREFIX + new_id_token
else:
raise Exception("Failed to fetch Google ID token.")


async def aget_google_id_token() -> str:
"""
Asynchronously fetches a Google ID token.

The token is formatted as a 'Bearer' token string and is suitable for use
in an HTTP Authorization header. This function uses Application Default
Credentials.

Returns:
A string in the format "Bearer <google_id_token>".

Raises:
Exception: If fetching the Google ID token fails.
"""
if _is_cached_token_valid(_cached_google_id_token):
return BEARER_TOKEN_PREFIX + _cached_google_id_token["token"]

credentials, _ = default_async()
await credentials.refresh(_aiohttp_requests.Request())
credentials.before_request = partial(Credentials.before_request, credentials)
new_id_token = getattr(credentials, "id_token", None)

_update_token_cache(_cached_google_id_token, new_id_token)

if new_id_token:
return BEARER_TOKEN_PREFIX + new_id_token
else:
raise Exception("Failed to fetch async Google ID token.")
Loading