@@ -401,6 +401,90 @@ async def get_current_result():
401
401
mock_agent_executor .execute .assert_awaited_once ()
402
402
403
403
404
+ @pytest .mark .asyncio
405
+ async def test_on_message_send_with_push_notification_no_existing_Task ():
406
+ """Test on_message_send for new task sets push notification info if provided."""
407
+ mock_task_store = AsyncMock (spec = TaskStore )
408
+ mock_push_notification_store = AsyncMock (spec = PushNotificationConfigStore )
409
+ mock_agent_executor = AsyncMock (spec = AgentExecutor )
410
+ mock_request_context_builder = AsyncMock (spec = RequestContextBuilder )
411
+
412
+ task_id = 'push_task_1'
413
+ context_id = 'push_ctx_1'
414
+
415
+ mock_task_store .get .return_value = (
416
+ None # Simulate new task scenario for TaskManager
417
+ )
418
+
419
+ # Mock _request_context_builder.build to return a context with the generated/confirmed IDs
420
+ mock_request_context = MagicMock (spec = RequestContext )
421
+ mock_request_context .task_id = task_id
422
+ mock_request_context .context_id = context_id
423
+ mock_request_context_builder .build .return_value = mock_request_context
424
+
425
+ request_handler = DefaultRequestHandler (
426
+ agent_executor = mock_agent_executor ,
427
+ task_store = mock_task_store ,
428
+ push_config_store = mock_push_notification_store ,
429
+ request_context_builder = mock_request_context_builder ,
430
+ )
431
+
432
+ push_config = PushNotificationConfig (url = 'http://callback.com/push' )
433
+ message_config = MessageSendConfiguration (
434
+ pushNotificationConfig = push_config ,
435
+ acceptedOutputModes = ['text/plain' ], # Added required field
436
+ )
437
+ params = MessageSendParams (
438
+ message = Message (
439
+ role = Role .user ,
440
+ messageId = 'msg_push' ,
441
+ parts = [],
442
+ taskId = task_id ,
443
+ contextId = context_id ,
444
+ ),
445
+ configuration = message_config ,
446
+ )
447
+
448
+ # Mock ResultAggregator and its consume_and_break_on_interrupt
449
+ mock_result_aggregator_instance = AsyncMock (spec = ResultAggregator )
450
+ final_task_result = create_sample_task (
451
+ task_id = task_id , context_id = context_id , status_state = TaskState .completed
452
+ )
453
+ mock_result_aggregator_instance .consume_and_break_on_interrupt .return_value = (
454
+ final_task_result ,
455
+ False ,
456
+ )
457
+
458
+ # Mock the current_result property to return the final task result
459
+ async def get_current_result ():
460
+ return final_task_result
461
+
462
+ # Configure the 'current_result' property on the type of the mock instance
463
+ type(mock_result_aggregator_instance ).current_result = PropertyMock (
464
+ return_value = get_current_result ()
465
+ )
466
+
467
+ with (
468
+ patch (
469
+ 'a2a.server.request_handlers.default_request_handler.ResultAggregator' ,
470
+ return_value = mock_result_aggregator_instance ,
471
+ ),
472
+ patch (
473
+ 'a2a.server.request_handlers.default_request_handler.TaskManager.get_task' ,
474
+ return_value = None ,
475
+ ),
476
+ ):
477
+ await request_handler .on_message_send (
478
+ params , create_server_call_context ()
479
+ )
480
+
481
+ mock_push_notification_store .set_info .assert_awaited_once_with (
482
+ task_id , push_config
483
+ )
484
+ # Other assertions for full flow if needed (e.g., agent execution)
485
+ mock_agent_executor .execute .assert_awaited_once ()
486
+
487
+
404
488
@pytest .mark .asyncio
405
489
async def test_on_message_send_no_result_from_aggregator ():
406
490
"""Test on_message_send when aggregator returns (None, False)."""
0 commit comments