16
16
17
17
from unittest .mock import AsyncMock
18
18
from unittest .mock import MagicMock
19
+ from unittest .mock import Mock
19
20
20
21
from google .adk .agents .callback_context import CallbackContext
22
+ from google .adk .auth .auth_credential import AuthCredential
23
+ from google .adk .auth .auth_credential import AuthCredentialTypes
24
+ from google .adk .auth .auth_tool import AuthConfig
21
25
from google .adk .tools .tool_context import ToolContext
26
+ from google .genai .types import Part
22
27
import pytest
23
28
24
29
@@ -32,6 +37,8 @@ def mock_invocation_context():
32
37
mock_context .session .id = "test-session-id"
33
38
mock_context .app_name = "test-app"
34
39
mock_context .user_id = "test-user"
40
+ mock_context .artifact_service = None
41
+ mock_context .credential_service = None
35
42
return mock_context
36
43
37
44
@@ -63,6 +70,21 @@ def callback_context_without_artifact_service(mock_invocation_context):
63
70
return CallbackContext (mock_invocation_context )
64
71
65
72
73
+ @pytest .fixture
74
+ def mock_auth_config ():
75
+ """Create a mock auth config for testing."""
76
+ mock_config = Mock (spec = AuthConfig )
77
+ return mock_config
78
+
79
+
80
+ @pytest .fixture
81
+ def mock_auth_credential ():
82
+ """Create a mock auth credential for testing."""
83
+ mock_credential = Mock (spec = AuthCredential )
84
+ mock_credential .auth_type = AuthCredentialTypes .OAUTH2
85
+ return mock_credential
86
+
87
+
66
88
class TestCallbackContextListArtifacts :
67
89
"""Test the list_artifacts method in CallbackContext."""
68
90
@@ -119,8 +141,8 @@ async def test_list_artifacts_passes_through_service_exceptions(
119
141
await callback_context_with_artifact_service .list_artifacts ()
120
142
121
143
122
- class TestToolContextListArtifacts :
123
- """Test that list_artifacts is available in ToolContext through inheritance ."""
144
+ class TestCallbackContext :
145
+ """Test suite for CallbackContext ."""
124
146
125
147
@pytest .mark .asyncio
126
148
async def test_tool_context_inherits_list_artifacts (
@@ -167,3 +189,134 @@ def test_tool_context_shares_same_list_artifacts_method_with_callback_context(
167
189
):
168
190
"""Test that ToolContext and CallbackContext share the same list_artifacts method."""
169
191
assert ToolContext .list_artifacts is CallbackContext .list_artifacts
192
+
193
+ def test_initialization (self , mock_invocation_context ):
194
+ """Test CallbackContext initialization."""
195
+ context = CallbackContext (mock_invocation_context )
196
+ assert context ._invocation_context == mock_invocation_context
197
+ assert context ._event_actions is not None
198
+ assert context ._state is not None
199
+
200
+ @pytest .mark .asyncio
201
+ async def test_save_credential_with_service (
202
+ self , mock_invocation_context , mock_auth_config
203
+ ):
204
+ """Test save_credential when credential service is available."""
205
+ # Mock credential service
206
+ credential_service = AsyncMock ()
207
+ mock_invocation_context .credential_service = credential_service
208
+
209
+ context = CallbackContext (mock_invocation_context )
210
+ await context .save_credential (mock_auth_config )
211
+
212
+ credential_service .save_credential .assert_called_once_with (
213
+ mock_auth_config , context
214
+ )
215
+
216
+ @pytest .mark .asyncio
217
+ async def test_save_credential_no_service (
218
+ self , mock_invocation_context , mock_auth_config
219
+ ):
220
+ """Test save_credential when credential service is not available."""
221
+ mock_invocation_context .credential_service = None
222
+
223
+ context = CallbackContext (mock_invocation_context )
224
+
225
+ with pytest .raises (
226
+ ValueError , match = "Credential service is not initialized"
227
+ ):
228
+ await context .save_credential (mock_auth_config )
229
+
230
+ @pytest .mark .asyncio
231
+ async def test_load_credential_with_service (
232
+ self , mock_invocation_context , mock_auth_config , mock_auth_credential
233
+ ):
234
+ """Test load_credential when credential service is available."""
235
+ # Mock credential service
236
+ credential_service = AsyncMock ()
237
+ credential_service .load_credential .return_value = mock_auth_credential
238
+ mock_invocation_context .credential_service = credential_service
239
+
240
+ context = CallbackContext (mock_invocation_context )
241
+ result = await context .load_credential (mock_auth_config )
242
+
243
+ credential_service .load_credential .assert_called_once_with (
244
+ mock_auth_config , context
245
+ )
246
+ assert result == mock_auth_credential
247
+
248
+ @pytest .mark .asyncio
249
+ async def test_load_credential_no_service (
250
+ self , mock_invocation_context , mock_auth_config
251
+ ):
252
+ """Test load_credential when credential service is not available."""
253
+ mock_invocation_context .credential_service = None
254
+
255
+ context = CallbackContext (mock_invocation_context )
256
+
257
+ with pytest .raises (
258
+ ValueError , match = "Credential service is not initialized"
259
+ ):
260
+ await context .load_credential (mock_auth_config )
261
+
262
+ @pytest .mark .asyncio
263
+ async def test_load_credential_returns_none (
264
+ self , mock_invocation_context , mock_auth_config
265
+ ):
266
+ """Test load_credential returns None when credential not found."""
267
+ # Mock credential service
268
+ credential_service = AsyncMock ()
269
+ credential_service .load_credential .return_value = None
270
+ mock_invocation_context .credential_service = credential_service
271
+
272
+ context = CallbackContext (mock_invocation_context )
273
+ result = await context .load_credential (mock_auth_config )
274
+
275
+ credential_service .load_credential .assert_called_once_with (
276
+ mock_auth_config , context
277
+ )
278
+ assert result is None
279
+
280
+ @pytest .mark .asyncio
281
+ async def test_save_artifact_integration (self , mock_invocation_context ):
282
+ """Test save_artifact to ensure credential methods follow same pattern."""
283
+ # Mock artifact service
284
+ artifact_service = AsyncMock ()
285
+ artifact_service .save_artifact .return_value = 1
286
+ mock_invocation_context .artifact_service = artifact_service
287
+
288
+ context = CallbackContext (mock_invocation_context )
289
+ test_artifact = Part .from_text (text = "test content" )
290
+
291
+ version = await context .save_artifact ("test_file.txt" , test_artifact )
292
+
293
+ artifact_service .save_artifact .assert_called_once_with (
294
+ app_name = "test-app" ,
295
+ user_id = "test-user" ,
296
+ session_id = "test-session-id" ,
297
+ filename = "test_file.txt" ,
298
+ artifact = test_artifact ,
299
+ )
300
+ assert version == 1
301
+
302
+ @pytest .mark .asyncio
303
+ async def test_load_artifact_integration (self , mock_invocation_context ):
304
+ """Test load_artifact to ensure credential methods follow same pattern."""
305
+ # Mock artifact service
306
+ artifact_service = AsyncMock ()
307
+ test_artifact = Part .from_text (text = "test content" )
308
+ artifact_service .load_artifact .return_value = test_artifact
309
+ mock_invocation_context .artifact_service = artifact_service
310
+
311
+ context = CallbackContext (mock_invocation_context )
312
+
313
+ result = await context .load_artifact ("test_file.txt" )
314
+
315
+ artifact_service .load_artifact .assert_called_once_with (
316
+ app_name = "test-app" ,
317
+ user_id = "test-user" ,
318
+ session_id = "test-session-id" ,
319
+ filename = "test_file.txt" ,
320
+ version = None ,
321
+ )
322
+ assert result == test_artifact
0 commit comments