Skip to content

Commit b977d12

Browse files
seanzhougooglecopybara-github
authored andcommitted
refactor: Add save_credential and load_credential in callback context for developer to access credential service
developer may want to save/load credentials themselves to/from credential service. see (#1816) PiperOrigin-RevId: 783487628
1 parent b1fa383 commit b977d12

File tree

4 files changed

+198
-20
lines changed

4 files changed

+198
-20
lines changed

src/google/adk/agents/callback_context.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
if TYPE_CHECKING:
2525
from google.genai import types
2626

27+
from ..auth.auth_credential import AuthCredential
28+
from ..auth.auth_tool import AuthConfig
2729
from ..events.event_actions import EventActions
2830
from ..sessions.state import State
2931
from .invocation_context import InvocationContext
@@ -115,3 +117,32 @@ async def list_artifacts(self) -> list[str]:
115117
user_id=self._invocation_context.user_id,
116118
session_id=self._invocation_context.session.id,
117119
)
120+
121+
async def save_credential(self, auth_config: AuthConfig) -> None:
122+
"""Saves a credential to the credential service.
123+
124+
Args:
125+
auth_config: The authentication configuration containing the credential.
126+
"""
127+
if self._invocation_context.credential_service is None:
128+
raise ValueError("Credential service is not initialized.")
129+
await self._invocation_context.credential_service.save_credential(
130+
auth_config, self
131+
)
132+
133+
async def load_credential(
134+
self, auth_config: AuthConfig
135+
) -> Optional[AuthCredential]:
136+
"""Loads a credential from the credential service.
137+
138+
Args:
139+
auth_config: The authentication configuration for the credential.
140+
141+
Returns:
142+
The loaded credential, or None if not found.
143+
"""
144+
if self._invocation_context.credential_service is None:
145+
raise ValueError("Credential service is not initialized.")
146+
return await self._invocation_context.credential_service.load_credential(
147+
auth_config, self
148+
)

src/google/adk/auth/credential_manager.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -168,9 +168,7 @@ async def _load_from_credential_service(
168168
if credential_service:
169169
# Note: This should be made async in a future refactor
170170
# For now, assuming synchronous operation
171-
return await credential_service.load_credential(
172-
self._auth_config, callback_context
173-
)
171+
return await callback_context.load_credential(self._auth_config)
174172
return None
175173

176174
async def _load_from_auth_response(
@@ -255,10 +253,9 @@ async def _save_credential(
255253
self, callback_context: CallbackContext, credential: AuthCredential
256254
) -> None:
257255
"""Save credential to credential service if available."""
256+
# Update the exchanged credential in config
257+
self._auth_config.exchanged_auth_credential = credential
258+
258259
credential_service = callback_context._invocation_context.credential_service
259260
if credential_service:
260-
# Update the exchanged credential in config
261-
self._auth_config.exchanged_auth_credential = credential
262-
await credential_service.save_credential(
263-
self._auth_config, callback_context
264-
)
261+
await callback_context.save_credential(self._auth_config)

tests/unittests/agents/test_callback_context.py

Lines changed: 155 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,14 @@
1616

1717
from unittest.mock import AsyncMock
1818
from unittest.mock import MagicMock
19+
from unittest.mock import Mock
1920

2021
from google.adk.agents.callback_context import CallbackContext
22+
from google.adk.auth.auth_credential import AuthCredential
23+
from google.adk.auth.auth_credential import AuthCredentialTypes
24+
from google.adk.auth.auth_tool import AuthConfig
2125
from google.adk.tools.tool_context import ToolContext
26+
from google.genai.types import Part
2227
import pytest
2328

2429

@@ -32,6 +37,8 @@ def mock_invocation_context():
3237
mock_context.session.id = "test-session-id"
3338
mock_context.app_name = "test-app"
3439
mock_context.user_id = "test-user"
40+
mock_context.artifact_service = None
41+
mock_context.credential_service = None
3542
return mock_context
3643

3744

@@ -63,6 +70,21 @@ def callback_context_without_artifact_service(mock_invocation_context):
6370
return CallbackContext(mock_invocation_context)
6471

6572

73+
@pytest.fixture
74+
def mock_auth_config():
75+
"""Create a mock auth config for testing."""
76+
mock_config = Mock(spec=AuthConfig)
77+
return mock_config
78+
79+
80+
@pytest.fixture
81+
def mock_auth_credential():
82+
"""Create a mock auth credential for testing."""
83+
mock_credential = Mock(spec=AuthCredential)
84+
mock_credential.auth_type = AuthCredentialTypes.OAUTH2
85+
return mock_credential
86+
87+
6688
class TestCallbackContextListArtifacts:
6789
"""Test the list_artifacts method in CallbackContext."""
6890

@@ -119,8 +141,8 @@ async def test_list_artifacts_passes_through_service_exceptions(
119141
await callback_context_with_artifact_service.list_artifacts()
120142

121143

122-
class TestToolContextListArtifacts:
123-
"""Test that list_artifacts is available in ToolContext through inheritance."""
144+
class TestCallbackContext:
145+
"""Test suite for CallbackContext."""
124146

125147
@pytest.mark.asyncio
126148
async def test_tool_context_inherits_list_artifacts(
@@ -167,3 +189,134 @@ def test_tool_context_shares_same_list_artifacts_method_with_callback_context(
167189
):
168190
"""Test that ToolContext and CallbackContext share the same list_artifacts method."""
169191
assert ToolContext.list_artifacts is CallbackContext.list_artifacts
192+
193+
def test_initialization(self, mock_invocation_context):
194+
"""Test CallbackContext initialization."""
195+
context = CallbackContext(mock_invocation_context)
196+
assert context._invocation_context == mock_invocation_context
197+
assert context._event_actions is not None
198+
assert context._state is not None
199+
200+
@pytest.mark.asyncio
201+
async def test_save_credential_with_service(
202+
self, mock_invocation_context, mock_auth_config
203+
):
204+
"""Test save_credential when credential service is available."""
205+
# Mock credential service
206+
credential_service = AsyncMock()
207+
mock_invocation_context.credential_service = credential_service
208+
209+
context = CallbackContext(mock_invocation_context)
210+
await context.save_credential(mock_auth_config)
211+
212+
credential_service.save_credential.assert_called_once_with(
213+
mock_auth_config, context
214+
)
215+
216+
@pytest.mark.asyncio
217+
async def test_save_credential_no_service(
218+
self, mock_invocation_context, mock_auth_config
219+
):
220+
"""Test save_credential when credential service is not available."""
221+
mock_invocation_context.credential_service = None
222+
223+
context = CallbackContext(mock_invocation_context)
224+
225+
with pytest.raises(
226+
ValueError, match="Credential service is not initialized"
227+
):
228+
await context.save_credential(mock_auth_config)
229+
230+
@pytest.mark.asyncio
231+
async def test_load_credential_with_service(
232+
self, mock_invocation_context, mock_auth_config, mock_auth_credential
233+
):
234+
"""Test load_credential when credential service is available."""
235+
# Mock credential service
236+
credential_service = AsyncMock()
237+
credential_service.load_credential.return_value = mock_auth_credential
238+
mock_invocation_context.credential_service = credential_service
239+
240+
context = CallbackContext(mock_invocation_context)
241+
result = await context.load_credential(mock_auth_config)
242+
243+
credential_service.load_credential.assert_called_once_with(
244+
mock_auth_config, context
245+
)
246+
assert result == mock_auth_credential
247+
248+
@pytest.mark.asyncio
249+
async def test_load_credential_no_service(
250+
self, mock_invocation_context, mock_auth_config
251+
):
252+
"""Test load_credential when credential service is not available."""
253+
mock_invocation_context.credential_service = None
254+
255+
context = CallbackContext(mock_invocation_context)
256+
257+
with pytest.raises(
258+
ValueError, match="Credential service is not initialized"
259+
):
260+
await context.load_credential(mock_auth_config)
261+
262+
@pytest.mark.asyncio
263+
async def test_load_credential_returns_none(
264+
self, mock_invocation_context, mock_auth_config
265+
):
266+
"""Test load_credential returns None when credential not found."""
267+
# Mock credential service
268+
credential_service = AsyncMock()
269+
credential_service.load_credential.return_value = None
270+
mock_invocation_context.credential_service = credential_service
271+
272+
context = CallbackContext(mock_invocation_context)
273+
result = await context.load_credential(mock_auth_config)
274+
275+
credential_service.load_credential.assert_called_once_with(
276+
mock_auth_config, context
277+
)
278+
assert result is None
279+
280+
@pytest.mark.asyncio
281+
async def test_save_artifact_integration(self, mock_invocation_context):
282+
"""Test save_artifact to ensure credential methods follow same pattern."""
283+
# Mock artifact service
284+
artifact_service = AsyncMock()
285+
artifact_service.save_artifact.return_value = 1
286+
mock_invocation_context.artifact_service = artifact_service
287+
288+
context = CallbackContext(mock_invocation_context)
289+
test_artifact = Part.from_text(text="test content")
290+
291+
version = await context.save_artifact("test_file.txt", test_artifact)
292+
293+
artifact_service.save_artifact.assert_called_once_with(
294+
app_name="test-app",
295+
user_id="test-user",
296+
session_id="test-session-id",
297+
filename="test_file.txt",
298+
artifact=test_artifact,
299+
)
300+
assert version == 1
301+
302+
@pytest.mark.asyncio
303+
async def test_load_artifact_integration(self, mock_invocation_context):
304+
"""Test load_artifact to ensure credential methods follow same pattern."""
305+
# Mock artifact service
306+
artifact_service = AsyncMock()
307+
test_artifact = Part.from_text(text="test content")
308+
artifact_service.load_artifact.return_value = test_artifact
309+
mock_invocation_context.artifact_service = artifact_service
310+
311+
context = CallbackContext(mock_invocation_context)
312+
313+
result = await context.load_artifact("test_file.txt")
314+
315+
artifact_service.load_artifact.assert_called_once_with(
316+
app_name="test-app",
317+
user_id="test-user",
318+
session_id="test-session-id",
319+
filename="test_file.txt",
320+
version=None,
321+
)
322+
assert result == test_artifact

tests/unittests/auth/test_credential_manager.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -167,21 +167,19 @@ async def test_load_from_credential_service_with_service(self):
167167

168168
# Mock credential service
169169
credential_service = Mock()
170-
credential_service.load_credential = AsyncMock(return_value=mock_credential)
171170

172171
# Mock invocation context
173172
invocation_context = Mock()
174173
invocation_context.credential_service = credential_service
175174

176175
callback_context = Mock()
177176
callback_context._invocation_context = invocation_context
177+
callback_context.load_credential = AsyncMock(return_value=mock_credential)
178178

179179
manager = CredentialManager(auth_config)
180180
result = await manager._load_from_credential_service(callback_context)
181181

182-
credential_service.load_credential.assert_called_once_with(
183-
auth_config, callback_context
184-
)
182+
callback_context.load_credential.assert_called_once_with(auth_config)
185183
assert result == mock_credential
186184

187185
@pytest.mark.asyncio
@@ -216,13 +214,12 @@ async def test_save_credential_with_service(self):
216214

217215
callback_context = Mock()
218216
callback_context._invocation_context = invocation_context
217+
callback_context.save_credential = AsyncMock()
219218

220219
manager = CredentialManager(auth_config)
221220
await manager._save_credential(callback_context, mock_credential)
222221

223-
credential_service.save_credential.assert_called_once_with(
224-
auth_config, callback_context
225-
)
222+
callback_context.save_credential.assert_called_once_with(auth_config)
226223
assert auth_config.exchanged_auth_credential == mock_credential
227224

228225
@pytest.mark.asyncio
@@ -242,9 +239,9 @@ async def test_save_credential_no_service(self):
242239
manager = CredentialManager(auth_config)
243240
await manager._save_credential(callback_context, mock_credential)
244241

245-
# Should not raise an error, and credential should not be set in auth_config
246-
# when there's no credential service (according to implementation)
247-
assert auth_config.exchanged_auth_credential is None
242+
# Should not raise an error, and credential should be set in auth_config
243+
# even when there's no credential service (config is updated regardless)
244+
assert auth_config.exchanged_auth_credential == mock_credential
248245

249246
@pytest.mark.asyncio
250247
async def test_refresh_credential_oauth2(self):

0 commit comments

Comments
 (0)