1
- from unittest .mock import AsyncMock
1
+ from unittest .mock import AsyncMock , MagicMock
2
2
3
+ import grpc
3
4
import pytest
4
5
5
6
from a2a .client import A2AGrpcClient
6
7
from a2a .grpc import a2a_pb2 , a2a_pb2_grpc
7
8
from a2a .types import (
8
9
AgentCapabilities ,
9
10
AgentCard ,
11
+ Artifact ,
10
12
Message ,
11
13
MessageSendParams ,
12
14
Part ,
15
+ PushNotificationAuthenticationInfo ,
16
+ PushNotificationConfig ,
13
17
Role ,
14
18
Task ,
19
+ TaskArtifactUpdateEvent ,
15
20
TaskIdParams ,
21
+ TaskPushNotificationConfig ,
16
22
TaskQueryParams ,
17
23
TaskState ,
18
24
TaskStatus ,
25
+ TaskStatusUpdateEvent ,
19
26
TextPart ,
20
27
)
21
28
from a2a .utils import proto_utils
29
+ from a2a .utils .errors import ServerError
22
30
23
31
24
32
# Fixtures
@@ -30,8 +38,8 @@ def mock_grpc_stub() -> AsyncMock:
30
38
stub .SendStreamingMessage = AsyncMock ()
31
39
stub .GetTask = AsyncMock ()
32
40
stub .CancelTask = AsyncMock ()
33
- stub .CreateTaskPushNotification = AsyncMock ()
34
- stub .GetTaskPushNotification = AsyncMock ()
41
+ stub .CreateTaskPushNotificationConfig = AsyncMock ()
42
+ stub .GetTaskPushNotificationConfig = AsyncMock ()
35
43
return stub
36
44
37
45
@@ -90,6 +98,78 @@ def sample_message() -> Message:
90
98
)
91
99
92
100
101
+ @pytest .fixture
102
+ def sample_artifact () -> Artifact :
103
+ """Provides a sample Artifact object."""
104
+ return Artifact (
105
+ artifactId = 'artifact-1' ,
106
+ name = 'example.txt' ,
107
+ description = 'An example artifact' ,
108
+ parts = [Part (root = TextPart (text = 'Hi there' ))],
109
+ metadata = {},
110
+ extensions = [],
111
+ )
112
+
113
+
114
+ @pytest .fixture
115
+ def sample_task_status_update_event () -> TaskStatusUpdateEvent :
116
+ """Provides a sample TaskStatusUpdateEvent."""
117
+ return TaskStatusUpdateEvent (
118
+ taskId = 'task-1' ,
119
+ contextId = 'ctx-1' ,
120
+ status = TaskStatus (state = TaskState .working ),
121
+ final = False ,
122
+ metadata = {},
123
+ )
124
+
125
+
126
+ @pytest .fixture
127
+ def sample_task_artifact_update_event (
128
+ sample_artifact ,
129
+ ) -> TaskArtifactUpdateEvent :
130
+ """Provides a sample TaskArtifactUpdateEvent."""
131
+ return TaskArtifactUpdateEvent (
132
+ taskId = 'task-1' ,
133
+ contextId = 'ctx-1' ,
134
+ artifact = sample_artifact ,
135
+ append = True ,
136
+ last_chunk = True ,
137
+ metadata = {},
138
+ )
139
+
140
+
141
+ @pytest .fixture
142
+ def sample_authentication_info () -> PushNotificationAuthenticationInfo :
143
+ """Provides a sample AuthenticationInfo object."""
144
+ return PushNotificationAuthenticationInfo (
145
+ schemes = ['apikey' , 'oauth2' ], credentials = 'secret-token'
146
+ )
147
+
148
+
149
+ @pytest .fixture
150
+ def sample_push_notification_config (
151
+ sample_authentication_info : PushNotificationAuthenticationInfo ,
152
+ ) -> PushNotificationConfig :
153
+ """Provides a sample PushNotificationConfig object."""
154
+ return PushNotificationConfig (
155
+ id = 'config-1' ,
156
+ url = 'https://example.com/notify' ,
157
+ token = 'example-token' ,
158
+ authentication = sample_authentication_info ,
159
+ )
160
+
161
+
162
+ @pytest .fixture
163
+ def sample_task_push_notification_config (
164
+ sample_push_notification_config : PushNotificationConfig ,
165
+ ) -> TaskPushNotificationConfig :
166
+ """Provides a sample TaskPushNotificationConfig object."""
167
+ return TaskPushNotificationConfig (
168
+ taskId = 'task-1' ,
169
+ pushNotificationConfig = sample_push_notification_config ,
170
+ )
171
+
172
+
93
173
@pytest .mark .asyncio
94
174
async def test_send_message_task_response (
95
175
grpc_client : A2AGrpcClient ,
@@ -109,6 +189,76 @@ async def test_send_message_task_response(
109
189
assert response .id == sample_task .id
110
190
111
191
192
+ @pytest .mark .asyncio
193
+ async def test_send_message_message_response (
194
+ grpc_client : A2AGrpcClient ,
195
+ mock_grpc_stub : AsyncMock ,
196
+ sample_message_send_params : MessageSendParams ,
197
+ sample_message : Message ,
198
+ ):
199
+ """Test send_message that returns a Message."""
200
+ mock_grpc_stub .SendMessage .return_value = a2a_pb2 .SendMessageResponse (
201
+ msg = proto_utils .ToProto .message (sample_message )
202
+ )
203
+
204
+ response = await grpc_client .send_message (sample_message_send_params )
205
+
206
+ mock_grpc_stub .SendMessage .assert_awaited_once ()
207
+ assert isinstance (response , Message )
208
+ assert response .messageId == sample_message .messageId
209
+
210
+
211
+ @pytest .mark .asyncio
212
+ async def test_send_message_streaming (
213
+ grpc_client : A2AGrpcClient ,
214
+ mock_grpc_stub : AsyncMock ,
215
+ sample_message_send_params : MessageSendParams ,
216
+ sample_message : Message ,
217
+ sample_task : Task ,
218
+ sample_task_status_update_event : TaskStatusUpdateEvent ,
219
+ sample_task_artifact_update_event : TaskArtifactUpdateEvent ,
220
+ ):
221
+ """Test send_message_streaming that yields responses."""
222
+ stream = MagicMock ()
223
+ stream .read = AsyncMock (
224
+ side_effect = [
225
+ a2a_pb2 .StreamResponse (
226
+ msg = proto_utils .ToProto .message (sample_message )
227
+ ),
228
+ a2a_pb2 .StreamResponse (task = proto_utils .ToProto .task (sample_task )),
229
+ a2a_pb2 .StreamResponse (
230
+ status_update = proto_utils .ToProto .task_status_update_event (
231
+ sample_task_status_update_event
232
+ )
233
+ ),
234
+ a2a_pb2 .StreamResponse (
235
+ artifact_update = proto_utils .ToProto .task_artifact_update_event (
236
+ sample_task_artifact_update_event
237
+ )
238
+ ),
239
+ grpc .aio .EOF ,
240
+ ]
241
+ )
242
+ mock_grpc_stub .SendStreamingMessage .return_value = stream
243
+
244
+ responses = [
245
+ response
246
+ async for response in grpc_client .send_message_streaming (
247
+ sample_message_send_params
248
+ )
249
+ ]
250
+
251
+ mock_grpc_stub .SendStreamingMessage .assert_awaited_once ()
252
+ assert isinstance (responses [0 ], Message )
253
+ assert responses [0 ].messageId == sample_message .messageId
254
+ assert isinstance (responses [1 ], Task )
255
+ assert responses [1 ].id == sample_task .id
256
+ assert isinstance (responses [2 ], TaskStatusUpdateEvent )
257
+ assert responses [2 ].taskId == sample_task_status_update_event .taskId
258
+ assert isinstance (responses [3 ], TaskArtifactUpdateEvent )
259
+ assert responses [3 ].taskId == sample_task_artifact_update_event .taskId
260
+
261
+
112
262
@pytest .mark .asyncio
113
263
async def test_get_task (
114
264
grpc_client : A2AGrpcClient , mock_grpc_stub : AsyncMock , sample_task : Task
@@ -143,3 +293,117 @@ async def test_cancel_task(
143
293
a2a_pb2 .CancelTaskRequest (name = f'tasks/{ sample_task .id } ' )
144
294
)
145
295
assert response .status .state == TaskState .canceled
296
+
297
+
298
+ @pytest .mark .asyncio
299
+ async def test_set_task_callback_with_valid_task (
300
+ grpc_client : A2AGrpcClient ,
301
+ mock_grpc_stub : AsyncMock ,
302
+ sample_task_push_notification_config : TaskPushNotificationConfig ,
303
+ ):
304
+ """Test setting a task push notification config with a valid task id."""
305
+ task_id = 'task-1'
306
+ config_id = 'config-1'
307
+ mock_grpc_stub .CreateTaskPushNotificationConfig .return_value = (
308
+ a2a_pb2 .CreateTaskPushNotificationConfigRequest (
309
+ parent = f'tasks/{ task_id } ' ,
310
+ config_id = config_id ,
311
+ config = proto_utils .ToProto .task_push_notification_config (
312
+ sample_task_push_notification_config
313
+ ),
314
+ )
315
+ )
316
+
317
+ response = await grpc_client .set_task_callback (
318
+ sample_task_push_notification_config
319
+ )
320
+
321
+ mock_grpc_stub .CreateTaskPushNotificationConfig .assert_awaited_once_with (
322
+ a2a_pb2 .CreateTaskPushNotificationConfigRequest (
323
+ config = proto_utils .ToProto .task_push_notification_config (
324
+ sample_task_push_notification_config
325
+ ),
326
+ )
327
+ )
328
+ assert response .taskId == task_id
329
+
330
+
331
+ @pytest .mark .asyncio
332
+ async def test_set_task_callback_with_invalid_task (
333
+ grpc_client : A2AGrpcClient ,
334
+ mock_grpc_stub : AsyncMock ,
335
+ sample_task_push_notification_config : TaskPushNotificationConfig ,
336
+ ):
337
+ """Test setting a task push notification config with a invalid task id."""
338
+ task_id = 'task-1'
339
+ config_id = 'config-1'
340
+ mock_grpc_stub .CreateTaskPushNotificationConfig .return_value = (
341
+ a2a_pb2 .CreateTaskPushNotificationConfigRequest (
342
+ parent = f'invalid-path-to-tasks/{ task_id } ' ,
343
+ config_id = config_id ,
344
+ config = proto_utils .ToProto .task_push_notification_config (
345
+ sample_task_push_notification_config
346
+ ),
347
+ )
348
+ )
349
+
350
+ with pytest .raises (ServerError ) as exc_info :
351
+ await grpc_client .set_task_callback (
352
+ sample_task_push_notification_config
353
+ )
354
+ assert 'No task for' in exc_info .value .error .message
355
+
356
+
357
+ @pytest .mark .asyncio
358
+ async def test_get_task_callback_with_valid_task (
359
+ grpc_client : A2AGrpcClient ,
360
+ mock_grpc_stub : AsyncMock ,
361
+ sample_task_push_notification_config : TaskPushNotificationConfig ,
362
+ ):
363
+ """Test retrieving a task push notification config with a valid task id."""
364
+ task_id = 'task-1'
365
+ config_id = 'config-1'
366
+ mock_grpc_stub .GetTaskPushNotificationConfig .return_value = (
367
+ a2a_pb2 .CreateTaskPushNotificationConfigRequest (
368
+ parent = f'tasks/{ task_id } ' ,
369
+ config_id = config_id ,
370
+ config = proto_utils .ToProto .task_push_notification_config (
371
+ sample_task_push_notification_config
372
+ ),
373
+ )
374
+ )
375
+ params = TaskIdParams (id = sample_task_push_notification_config .taskId )
376
+
377
+ response = await grpc_client .get_task_callback (params )
378
+
379
+ mock_grpc_stub .GetTaskPushNotificationConfig .assert_awaited_once_with (
380
+ a2a_pb2 .GetTaskPushNotificationConfigRequest (
381
+ name = f'tasks/{ params .id } /pushNotification/undefined' ,
382
+ )
383
+ )
384
+ assert response .taskId == task_id
385
+
386
+
387
+ @pytest .mark .asyncio
388
+ async def test_get_task_callback_with_invalid_task (
389
+ grpc_client : A2AGrpcClient ,
390
+ mock_grpc_stub : AsyncMock ,
391
+ sample_task_push_notification_config : TaskPushNotificationConfig ,
392
+ ):
393
+ """Test retrieving a task push notification config with a invalid task id."""
394
+ task_id = 'task-1'
395
+ config_id = 'config-1'
396
+ mock_grpc_stub .GetTaskPushNotificationConfig .return_value = (
397
+ a2a_pb2 .CreateTaskPushNotificationConfigRequest (
398
+ parent = f'invalid-path-to-tasks/{ task_id } ' ,
399
+ config_id = config_id ,
400
+ config = proto_utils .ToProto .task_push_notification_config (
401
+ sample_task_push_notification_config
402
+ ),
403
+ )
404
+ )
405
+ params = TaskIdParams (id = sample_task_push_notification_config .taskId )
406
+
407
+ with pytest .raises (ServerError ) as exc_info :
408
+ await grpc_client .get_task_callback (params )
409
+ assert 'No task for' in exc_info .value .error .message
0 commit comments