Skip to content

Commit 1a75848

Browse files
seanzhougooglecopybara-github
authored andcommitted
refactor: Use CallbackContext in credential service for saving/loading credential
1. credential service may be accessed by callbacks 2. plan to add load_credential and save_credential method in CallbackContext (see cl/782158513) given customer has requirement to access credential service themselves. (see #1816) It's backward compatible given CallbackContext is parent class of ToolContext PiperOrigin-RevId: 783480378
1 parent bf7745f commit 1a75848

File tree

7 files changed

+354
-304
lines changed

7 files changed

+354
-304
lines changed

src/google/adk/auth/credential_manager.py

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from typing import Optional
1818

19-
from ..tools.tool_context import ToolContext
19+
from ..agents.callback_context import CallbackContext
2020
from ..utils.feature_decorator import experimental
2121
from .auth_credential import AuthCredential
2222
from .auth_credential import AuthCredentialTypes
@@ -63,7 +63,7 @@ class CredentialManager:
6363
)
6464
6565
# Load and prepare credential
66-
credential = await manager.load_auth_credential(tool_context)
66+
credential = await manager.load_auth_credential(callback_context)
6767
```
6868
"""
6969

@@ -100,11 +100,11 @@ def register_credential_exchanger(
100100
"""
101101
self._exchanger_registry.register(credential_type, exchanger_instance)
102102

103-
async def request_credential(self, tool_context: ToolContext) -> None:
104-
tool_context.request_credential(self._auth_config)
103+
async def request_credential(self, callback_context: CallbackContext) -> None:
104+
callback_context.request_credential(self._auth_config)
105105

106106
async def get_auth_credential(
107-
self, tool_context: ToolContext
107+
self, callback_context: CallbackContext
108108
) -> Optional[AuthCredential]:
109109
"""Load and prepare authentication credential through a structured workflow."""
110110

@@ -116,14 +116,14 @@ async def get_auth_credential(
116116
return self._auth_config.raw_auth_credential
117117

118118
# Step 3: Try to load existing processed credential
119-
credential = await self._load_existing_credential(tool_context)
119+
credential = await self._load_existing_credential(callback_context)
120120

121121
# Step 4: If no existing credential, load from auth response
122122
# TODO instead of load from auth response, we can store auth response in
123123
# credential service.
124124
was_from_auth_response = False
125125
if not credential:
126-
credential = await self._load_from_auth_response(tool_context)
126+
credential = await self._load_from_auth_response(callback_context)
127127
was_from_auth_response = True
128128

129129
# Step 5: If still no credential available, return None
@@ -134,22 +134,23 @@ async def get_auth_credential(
134134
credential, was_exchanged = await self._exchange_credential(credential)
135135

136136
# Step 7: Refresh credential if expired
137+
was_refreshed = False
137138
if not was_exchanged:
138139
credential, was_refreshed = await self._refresh_credential(credential)
139140

140141
# Step 8: Save credential if it was modified
141142
if was_from_auth_response or was_exchanged or was_refreshed:
142-
await self._save_credential(tool_context, credential)
143+
await self._save_credential(callback_context, credential)
143144

144145
return credential
145146

146147
async def _load_existing_credential(
147-
self, tool_context: ToolContext
148+
self, callback_context: CallbackContext
148149
) -> Optional[AuthCredential]:
149150
"""Load existing credential from credential service or cached exchanged credential."""
150151

151152
# Try loading from credential service first
152-
credential = await self._load_from_credential_service(tool_context)
153+
credential = await self._load_from_credential_service(callback_context)
153154
if credential:
154155
return credential
155156

@@ -160,23 +161,23 @@ async def _load_existing_credential(
160161
return None
161162

162163
async def _load_from_credential_service(
163-
self, tool_context: ToolContext
164+
self, callback_context: CallbackContext
164165
) -> Optional[AuthCredential]:
165166
"""Load credential from credential service if available."""
166-
credential_service = tool_context._invocation_context.credential_service
167+
credential_service = callback_context._invocation_context.credential_service
167168
if credential_service:
168169
# Note: This should be made async in a future refactor
169170
# For now, assuming synchronous operation
170171
return await credential_service.load_credential(
171-
self._auth_config, tool_context
172+
self._auth_config, callback_context
172173
)
173174
return None
174175

175176
async def _load_from_auth_response(
176-
self, tool_context: ToolContext
177+
self, callback_context: CallbackContext
177178
) -> Optional[AuthCredential]:
178-
"""Load credential from auth response in tool context."""
179-
return tool_context.get_auth_response(self._auth_config)
179+
"""Load credential from auth response in callback context."""
180+
return callback_context.get_auth_response(self._auth_config)
180181

181182
async def _exchange_credential(
182183
self, credential: AuthCredential
@@ -251,11 +252,13 @@ async def _validate_credential(self) -> None:
251252
# Additional validation can be added here
252253

253254
async def _save_credential(
254-
self, tool_context: ToolContext, credential: AuthCredential
255+
self, callback_context: CallbackContext, credential: AuthCredential
255256
) -> None:
256257
"""Save credential to credential service if available."""
257-
credential_service = tool_context._invocation_context.credential_service
258+
credential_service = callback_context._invocation_context.credential_service
258259
if credential_service:
259260
# Update the exchanged credential in config
260261
self._auth_config.exchanged_auth_credential = credential
261-
await credential_service.save_credential(self._auth_config, tool_context)
262+
await credential_service.save_credential(
263+
self._auth_config, callback_context
264+
)

src/google/adk/auth/credential_service/base_credential_service.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from abc import abstractmethod
1919
from typing import Optional
2020

21-
from ...tools.tool_context import ToolContext
21+
from ...agents.callback_context import CallbackContext
2222
from ...utils.feature_decorator import experimental
2323
from ..auth_credential import AuthCredential
2424
from ..auth_tool import AuthConfig
@@ -33,18 +33,18 @@ class BaseCredentialService(ABC):
3333
async def load_credential(
3434
self,
3535
auth_config: AuthConfig,
36-
tool_context: ToolContext,
36+
callback_context: CallbackContext,
3737
) -> Optional[AuthCredential]:
3838
"""
39-
Loads the credential by auth config and current tool context from the
39+
Loads the credential by auth config and current callback context from the
4040
backend credential store.
4141
4242
Args:
4343
auth_config: The auth config which contains the auth scheme and auth
4444
credential information. auth_config.get_credential_key will be used to
4545
build the key to load the credential.
4646
47-
tool_context: The context of the current invocation when the tool is
47+
callback_context: The context of the current invocation when the tool is
4848
trying to load the credential.
4949
5050
Returns:
@@ -56,7 +56,7 @@ async def load_credential(
5656
async def save_credential(
5757
self,
5858
auth_config: AuthConfig,
59-
tool_context: ToolContext,
59+
callback_context: CallbackContext,
6060
) -> None:
6161
"""
6262
Saves the exchanged_auth_credential in auth config to the backend credential
@@ -67,7 +67,7 @@ async def save_credential(
6767
credential information. auth_config.get_credential_key will be used to
6868
build the key to save the credential.
6969
70-
tool_context: The context of the current invocation when the tool is
70+
callback_context: The context of the current invocation when the tool is
7171
trying to save the credential.
7272
7373
Returns:

src/google/adk/auth/credential_service/in_memory_credential_service.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from typing_extensions import override
2020

21-
from ...tools.tool_context import ToolContext
21+
from ...agents.callback_context import CallbackContext
2222
from ...utils.feature_decorator import experimental
2323
from ..auth_credential import AuthCredential
2424
from ..auth_tool import AuthConfig
@@ -37,25 +37,27 @@ def __init__(self):
3737
async def load_credential(
3838
self,
3939
auth_config: AuthConfig,
40-
tool_context: ToolContext,
40+
callback_context: CallbackContext,
4141
) -> Optional[AuthCredential]:
42-
credential_bucket = self._get_bucket_for_current_context(tool_context)
42+
credential_bucket = self._get_bucket_for_current_context(callback_context)
4343
return credential_bucket.get(auth_config.credential_key)
4444

4545
@override
4646
async def save_credential(
4747
self,
4848
auth_config: AuthConfig,
49-
tool_context: ToolContext,
49+
callback_context: CallbackContext,
5050
) -> None:
51-
credential_bucket = self._get_bucket_for_current_context(tool_context)
51+
credential_bucket = self._get_bucket_for_current_context(callback_context)
5252
credential_bucket[auth_config.credential_key] = (
5353
auth_config.exchanged_auth_credential
5454
)
5555

56-
def _get_bucket_for_current_context(self, tool_context: ToolContext) -> str:
57-
app_name = tool_context._invocation_context.app_name
58-
user_id = tool_context._invocation_context.user_id
56+
def _get_bucket_for_current_context(
57+
self, callback_context: CallbackContext
58+
) -> str:
59+
app_name = callback_context._invocation_context.app_name
60+
user_id = callback_context._invocation_context.user_id
5961

6062
if app_name not in self._credentials:
6163
self._credentials[app_name] = {}

src/google/adk/auth/credential_service/session_state_credential_service.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from typing_extensions import override
2020

21-
from ...tools.tool_context import ToolContext
21+
from ...agents.callback_context import CallbackContext
2222
from ...utils.feature_decorator import experimental
2323
from ..auth_credential import AuthCredential
2424
from ..auth_tool import AuthConfig
@@ -36,31 +36,31 @@ class SessionStateCredentialService(BaseCredentialService):
3636
async def load_credential(
3737
self,
3838
auth_config: AuthConfig,
39-
tool_context: ToolContext,
39+
callback_context: CallbackContext,
4040
) -> Optional[AuthCredential]:
4141
"""
42-
Loads the credential by auth config and current tool context from the
42+
Loads the credential by auth config and current callback context from the
4343
backend credential store.
4444
4545
Args:
4646
auth_config: The auth config which contains the auth scheme and auth
4747
credential information. auth_config.get_credential_key will be used to
4848
build the key to load the credential.
4949
50-
tool_context: The context of the current invocation when the tool is
50+
callback_context: The context of the current invocation when the tool is
5151
trying to load the credential.
5252
5353
Returns:
5454
Optional[AuthCredential]: the credential saved in the store.
5555
5656
"""
57-
return tool_context.state.get(auth_config.credential_key)
57+
return callback_context.state.get(auth_config.credential_key)
5858

5959
@override
6060
async def save_credential(
6161
self,
6262
auth_config: AuthConfig,
63-
tool_context: ToolContext,
63+
callback_context: CallbackContext,
6464
) -> None:
6565
"""
6666
Saves the exchanged_auth_credential in auth config to the backend credential
@@ -71,13 +71,13 @@ async def save_credential(
7171
credential information. auth_config.get_credential_key will be used to
7272
build the key to save the credential.
7373
74-
tool_context: The context of the current invocation when the tool is
74+
callback_context: The context of the current invocation when the tool is
7575
trying to save the credential.
7676
7777
Returns:
7878
None
7979
"""
8080

81-
tool_context.state[auth_config.credential_key] = (
81+
callback_context.state[auth_config.credential_key] = (
8282
auth_config.exchanged_auth_credential
8383
)

0 commit comments

Comments
 (0)