Skip to content

Commit 0274112

Browse files
authored
fix: Send push notifications for message/send (#298)
# Description Push notifications are not being triggered for message/send operations without an initial task.
1 parent 75aa4ed commit 0274112

File tree

2 files changed

+94
-32
lines changed

2 files changed

+94
-32
lines changed

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 10 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,7 @@
3333
InvalidParamsError,
3434
ListTaskPushNotificationConfigParams,
3535
Message,
36-
MessageSendConfiguration,
3736
MessageSendParams,
38-
PushNotificationConfig,
3937
Task,
4038
TaskIdParams,
4139
TaskNotFoundError,
@@ -202,18 +200,6 @@ async def _setup_message_execution(
202200
)
203201

204202
task = task_manager.update_with_message(params.message, task)
205-
if self.should_add_push_info(params):
206-
assert self._push_config_store is not None
207-
assert isinstance(
208-
params.configuration, MessageSendConfiguration
209-
)
210-
assert isinstance(
211-
params.configuration.pushNotificationConfig,
212-
PushNotificationConfig,
213-
)
214-
await self._push_config_store.set_info(
215-
task.id, params.configuration.pushNotificationConfig
216-
)
217203

218204
# Build request context
219205
request_context = await self._request_context_builder.build(
@@ -228,6 +214,16 @@ async def _setup_message_execution(
228214
# Always assign a task ID. We may not actually upgrade to a task, but
229215
# dictating the task ID at this layer is useful for tracking running
230216
# agents.
217+
218+
if (
219+
self._push_config_store
220+
and params.configuration
221+
and params.configuration.pushNotificationConfig
222+
):
223+
await self._push_config_store.set_info(
224+
task_id, params.configuration.pushNotificationConfig
225+
)
226+
231227
queue = await self._queue_manager.create_or_tap(task_id)
232228
result_aggregator = ResultAggregator(task_manager)
233229
# TODO: to manage the non-blocking flows.
@@ -333,16 +329,6 @@ async def on_message_send_stream(
333329
if isinstance(event, Task):
334330
self._validate_task_id_match(task_id, event.id)
335331

336-
if (
337-
self._push_config_store
338-
and params.configuration
339-
and params.configuration.pushNotificationConfig
340-
):
341-
await self._push_config_store.set_info(
342-
task_id,
343-
params.configuration.pushNotificationConfig,
344-
)
345-
346332
await self._send_push_notification_if_needed(
347333
task_id, result_aggregator
348334
)
@@ -509,11 +495,3 @@ async def on_delete_task_push_notification_config(
509495
await self._push_config_store.delete_info(
510496
params.id, params.pushNotificationConfigId
511497
)
512-
513-
def should_add_push_info(self, params: MessageSendParams) -> bool:
514-
"""Determines if push notification info should be set for a task."""
515-
return bool(
516-
self._push_config_store
517-
and params.configuration
518-
and params.configuration.pushNotificationConfig
519-
)

tests/server/request_handlers/test_default_request_handler.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,90 @@ async def get_current_result():
401401
mock_agent_executor.execute.assert_awaited_once()
402402

403403

404+
@pytest.mark.asyncio
405+
async def test_on_message_send_with_push_notification_no_existing_Task():
406+
"""Test on_message_send for new task sets push notification info if provided."""
407+
mock_task_store = AsyncMock(spec=TaskStore)
408+
mock_push_notification_store = AsyncMock(spec=PushNotificationConfigStore)
409+
mock_agent_executor = AsyncMock(spec=AgentExecutor)
410+
mock_request_context_builder = AsyncMock(spec=RequestContextBuilder)
411+
412+
task_id = 'push_task_1'
413+
context_id = 'push_ctx_1'
414+
415+
mock_task_store.get.return_value = (
416+
None # Simulate new task scenario for TaskManager
417+
)
418+
419+
# Mock _request_context_builder.build to return a context with the generated/confirmed IDs
420+
mock_request_context = MagicMock(spec=RequestContext)
421+
mock_request_context.task_id = task_id
422+
mock_request_context.context_id = context_id
423+
mock_request_context_builder.build.return_value = mock_request_context
424+
425+
request_handler = DefaultRequestHandler(
426+
agent_executor=mock_agent_executor,
427+
task_store=mock_task_store,
428+
push_config_store=mock_push_notification_store,
429+
request_context_builder=mock_request_context_builder,
430+
)
431+
432+
push_config = PushNotificationConfig(url='http://callback.com/push')
433+
message_config = MessageSendConfiguration(
434+
pushNotificationConfig=push_config,
435+
acceptedOutputModes=['text/plain'], # Added required field
436+
)
437+
params = MessageSendParams(
438+
message=Message(
439+
role=Role.user,
440+
messageId='msg_push',
441+
parts=[],
442+
taskId=task_id,
443+
contextId=context_id,
444+
),
445+
configuration=message_config,
446+
)
447+
448+
# Mock ResultAggregator and its consume_and_break_on_interrupt
449+
mock_result_aggregator_instance = AsyncMock(spec=ResultAggregator)
450+
final_task_result = create_sample_task(
451+
task_id=task_id, context_id=context_id, status_state=TaskState.completed
452+
)
453+
mock_result_aggregator_instance.consume_and_break_on_interrupt.return_value = (
454+
final_task_result,
455+
False,
456+
)
457+
458+
# Mock the current_result property to return the final task result
459+
async def get_current_result():
460+
return final_task_result
461+
462+
# Configure the 'current_result' property on the type of the mock instance
463+
type(mock_result_aggregator_instance).current_result = PropertyMock(
464+
return_value=get_current_result()
465+
)
466+
467+
with (
468+
patch(
469+
'a2a.server.request_handlers.default_request_handler.ResultAggregator',
470+
return_value=mock_result_aggregator_instance,
471+
),
472+
patch(
473+
'a2a.server.request_handlers.default_request_handler.TaskManager.get_task',
474+
return_value=None,
475+
),
476+
):
477+
await request_handler.on_message_send(
478+
params, create_server_call_context()
479+
)
480+
481+
mock_push_notification_store.set_info.assert_awaited_once_with(
482+
task_id, push_config
483+
)
484+
# Other assertions for full flow if needed (e.g., agent execution)
485+
mock_agent_executor.execute.assert_awaited_once()
486+
487+
404488
@pytest.mark.asyncio
405489
async def test_on_message_send_no_result_from_aggregator():
406490
"""Test on_message_send when aggregator returns (None, False)."""

0 commit comments

Comments
 (0)