Skip to content

Commit 91539d6

Browse files
authored
fix: send notifications on message not streaming (#219)
# Description The proposed fix, if the team does want push notifications to be supported in a non-streaming setup Fixes #218
1 parent 9dd7783 commit 91539d6

File tree

2 files changed

+94
-90
lines changed

2 files changed

+94
-90
lines changed

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 79 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
TaskState.rejected,
5656
}
5757

58+
5859
@trace_class(kind=SpanKind.SERVER)
5960
class DefaultRequestHandler(RequestHandler):
6061
"""Default request handler for all incoming requests.
@@ -168,23 +169,25 @@ async def _run_event_stream(
168169
await self.agent_executor.execute(request, queue)
169170
await queue.close()
170171

171-
async def on_message_send(
172+
async def _setup_message_execution(
172173
self,
173174
params: MessageSendParams,
174175
context: ServerCallContext | None = None,
175-
) -> Message | Task:
176-
"""Default handler for 'message/send' interface (non-streaming).
176+
) -> tuple[TaskManager, str, EventQueue, ResultAggregator, asyncio.Task]:
177+
"""Common setup logic for both streaming and non-streaming message handling.
177178
178-
Starts the agent execution for the message and waits for the final
179-
result (Task or Message).
179+
Returns:
180+
A tuple of (task_manager, task_id, queue, result_aggregator, producer_task)
180181
"""
182+
# Create task manager and validate existing task
181183
task_manager = TaskManager(
182184
task_id=params.message.taskId,
183185
context_id=params.message.contextId,
184186
task_store=self.task_store,
185187
initial_message=params.message,
186188
)
187189
task: Task | None = await task_manager.get_task()
190+
188191
if task:
189192
if task.status.state in TERMINAL_TASK_STATES:
190193
raise ServerError(
@@ -206,6 +209,8 @@ async def on_message_send(
206209
await self._push_notifier.set_info(
207210
task.id, params.configuration.pushNotificationConfig
208211
)
212+
213+
# Build request context
209214
request_context = await self._request_context_builder.build(
210215
params=params,
211216
task_id=task.id if task else None,
@@ -222,13 +227,49 @@ async def on_message_send(
222227
result_aggregator = ResultAggregator(task_manager)
223228
# TODO: to manage the non-blocking flows.
224229
producer_task = asyncio.create_task(
225-
self._run_event_stream(
226-
request_context,
227-
queue,
228-
)
230+
self._run_event_stream(request_context, queue)
229231
)
230232
await self._register_producer(task_id, producer_task)
231233

234+
return task_manager, task_id, queue, result_aggregator, producer_task
235+
236+
def _validate_task_id_match(self, task_id: str, event_task_id: str) -> None:
237+
"""Validates that agent-generated task ID matches the expected task ID."""
238+
if task_id != event_task_id:
239+
logger.error(
240+
f'Agent generated task_id={event_task_id} does not match the RequestContext task_id={task_id}.'
241+
)
242+
raise ServerError(
243+
InternalError(message='Task ID mismatch in agent response')
244+
)
245+
246+
async def _send_push_notification_if_needed(
247+
self, task_id: str, result_aggregator: ResultAggregator
248+
) -> None:
249+
"""Sends push notification if configured and task is available."""
250+
if self._push_notifier and task_id:
251+
latest_task = await result_aggregator.current_result
252+
if isinstance(latest_task, Task):
253+
await self._push_notifier.send_notification(latest_task)
254+
255+
async def on_message_send(
256+
self,
257+
params: MessageSendParams,
258+
context: ServerCallContext | None = None,
259+
) -> Message | Task:
260+
"""Default handler for 'message/send' interface (non-streaming).
261+
262+
Starts the agent execution for the message and waits for the final
263+
result (Task or Message).
264+
"""
265+
(
266+
task_manager,
267+
task_id,
268+
queue,
269+
result_aggregator,
270+
producer_task,
271+
) = await self._setup_message_execution(params, context)
272+
232273
consumer = EventConsumer(queue)
233274
producer_task.add_done_callback(consumer.agent_task_callback)
234275

@@ -241,13 +282,13 @@ async def on_message_send(
241282
if not result:
242283
raise ServerError(error=InternalError())
243284

244-
if isinstance(result, Task) and task_id != result.id:
245-
logger.error(
246-
f'Agent generated task_id={result.id} does not match the RequestContext task_id={task_id}.'
247-
)
248-
raise ServerError(
249-
InternalError(message='Task ID mismatch in agent response')
250-
)
285+
if isinstance(result, Task):
286+
self._validate_task_id_match(task_id, result.id)
287+
288+
await self._send_push_notification_if_needed(
289+
task_id, result_aggregator
290+
)
291+
251292
except Exception as e:
252293
logger.error(f'Agent execution failed. Error: {e}')
253294
raise
@@ -272,85 +313,34 @@ async def on_message_send_stream(
272313
Starts the agent execution and yields events as they are produced
273314
by the agent.
274315
"""
275-
task_manager = TaskManager(
276-
task_id=params.message.taskId,
277-
context_id=params.message.contextId,
278-
task_store=self.task_store,
279-
initial_message=params.message,
280-
)
281-
task: Task | None = await task_manager.get_task()
282-
283-
if task:
284-
if task.status.state in TERMINAL_TASK_STATES:
285-
raise ServerError(
286-
error=InvalidParamsError(
287-
message=f'Task {task.id} is in terminal state: {task.status.state}'
288-
)
289-
)
290-
291-
task = task_manager.update_with_message(params.message, task)
292-
if self.should_add_push_info(params):
293-
assert isinstance(self._push_notifier, PushNotifier)
294-
assert isinstance(
295-
params.configuration, MessageSendConfiguration
296-
)
297-
assert isinstance(
298-
params.configuration.pushNotificationConfig,
299-
PushNotificationConfig,
300-
)
301-
await self._push_notifier.set_info(
302-
task.id, params.configuration.pushNotificationConfig
303-
)
304-
else:
305-
queue = EventQueue()
306-
result_aggregator = ResultAggregator(task_manager)
307-
request_context = await self._request_context_builder.build(
308-
params=params,
309-
task_id=task.id if task else None,
310-
context_id=params.message.contextId,
311-
task=task,
312-
context=context,
313-
)
314-
315-
task_id = cast('str', request_context.task_id)
316-
queue = await self._queue_manager.create_or_tap(task_id)
317-
producer_task = asyncio.create_task(
318-
self._run_event_stream(
319-
request_context,
320-
queue,
321-
)
322-
)
323-
await self._register_producer(task_id, producer_task)
316+
(
317+
task_manager,
318+
task_id,
319+
queue,
320+
result_aggregator,
321+
producer_task,
322+
) = await self._setup_message_execution(params, context)
324323

325324
try:
326325
consumer = EventConsumer(queue)
327326
producer_task.add_done_callback(consumer.agent_task_callback)
328327
async for event in result_aggregator.consume_and_emit(consumer):
329328
if isinstance(event, Task):
330-
if task_id != event.id:
331-
logger.error(
332-
f'Agent generated task_id={event.id} does not match the RequestContext task_id={task_id}.'
333-
)
334-
raise ServerError(
335-
InternalError(
336-
message='Task ID mismatch in agent response'
337-
)
338-
)
339-
340-
if (
341-
self._push_notifier
342-
and params.configuration
343-
and params.configuration.pushNotificationConfig
344-
):
345-
await self._push_notifier.set_info(
346-
task_id,
347-
params.configuration.pushNotificationConfig,
348-
)
349-
350-
if self._push_notifier and task_id:
351-
latest_task = await result_aggregator.current_result
352-
if isinstance(latest_task, Task):
353-
await self._push_notifier.send_notification(latest_task)
329+
self._validate_task_id_match(task_id, event.id)
330+
331+
if (
332+
self._push_notifier
333+
and params.configuration
334+
and params.configuration.pushNotificationConfig
335+
):
336+
await self._push_notifier.set_info(
337+
task_id,
338+
params.configuration.pushNotificationConfig,
339+
)
340+
341+
await self._send_push_notification_if_needed(
342+
task_id, result_aggregator
343+
)
354344
yield event
355345
finally:
356346
await self._cleanup_producer(producer_task, task_id)

tests/server/request_handlers/test_default_request_handler.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,15 @@ async def test_on_message_send_with_push_notification():
361361
False,
362362
)
363363

364+
# Mock the current_result property to return the final task result
365+
async def get_current_result():
366+
return final_task_result
367+
368+
# Configure the 'current_result' property on the type of the mock instance
369+
type(mock_result_aggregator_instance).current_result = PropertyMock(
370+
return_value=get_current_result()
371+
)
372+
364373
with (
365374
patch(
366375
'a2a.server.request_handlers.default_request_handler.ResultAggregator',
@@ -380,6 +389,9 @@ async def test_on_message_send_with_push_notification():
380389
)
381390

382391
mock_push_notifier.set_info.assert_awaited_once_with(task_id, push_config)
392+
mock_push_notifier.send_notification.assert_awaited_once_with(
393+
final_task_result
394+
)
383395
# Other assertions for full flow if needed (e.g., agent execution)
384396
mock_agent_executor.execute.assert_awaited_once()
385397

@@ -1139,12 +1151,14 @@ async def consume_stream():
11391151
texts = [p.root.text for e in events for p in e.status.message.parts]
11401152
assert texts == ['Event 0', 'Event 1', 'Event 2']
11411153

1154+
11421155
TERMINAL_TASK_STATES = {
11431156
TaskState.completed,
11441157
TaskState.canceled,
11451158
TaskState.failed,
11461159
TaskState.rejected,
1147-
}
1160+
}
1161+
11481162

11491163
@pytest.mark.asyncio
11501164
@pytest.mark.parametrize('terminal_state', TERMINAL_TASK_STATES)

0 commit comments

Comments
 (0)