55
55
TaskState .rejected ,
56
56
}
57
57
58
+
58
59
@trace_class (kind = SpanKind .SERVER )
59
60
class DefaultRequestHandler (RequestHandler ):
60
61
"""Default request handler for all incoming requests.
@@ -168,23 +169,25 @@ async def _run_event_stream(
168
169
await self .agent_executor .execute (request , queue )
169
170
await queue .close ()
170
171
171
- async def on_message_send (
172
+ async def _setup_message_execution (
172
173
self ,
173
174
params : MessageSendParams ,
174
175
context : ServerCallContext | None = None ,
175
- ) -> Message | Task :
176
- """Default handler for 'message/send' interface ( non-streaming) .
176
+ ) -> tuple [ TaskManager , str , EventQueue , ResultAggregator , asyncio . Task ] :
177
+ """Common setup logic for both streaming and non-streaming message handling .
177
178
178
- Starts the agent execution for the message and waits for the final
179
- result (Task or Message).
179
+ Returns:
180
+ A tuple of (task_manager, task_id, queue, result_aggregator, producer_task)
180
181
"""
182
+ # Create task manager and validate existing task
181
183
task_manager = TaskManager (
182
184
task_id = params .message .taskId ,
183
185
context_id = params .message .contextId ,
184
186
task_store = self .task_store ,
185
187
initial_message = params .message ,
186
188
)
187
189
task : Task | None = await task_manager .get_task ()
190
+
188
191
if task :
189
192
if task .status .state in TERMINAL_TASK_STATES :
190
193
raise ServerError (
@@ -206,6 +209,8 @@ async def on_message_send(
206
209
await self ._push_notifier .set_info (
207
210
task .id , params .configuration .pushNotificationConfig
208
211
)
212
+
213
+ # Build request context
209
214
request_context = await self ._request_context_builder .build (
210
215
params = params ,
211
216
task_id = task .id if task else None ,
@@ -222,13 +227,49 @@ async def on_message_send(
222
227
result_aggregator = ResultAggregator (task_manager )
223
228
# TODO: to manage the non-blocking flows.
224
229
producer_task = asyncio .create_task (
225
- self ._run_event_stream (
226
- request_context ,
227
- queue ,
228
- )
230
+ self ._run_event_stream (request_context , queue )
229
231
)
230
232
await self ._register_producer (task_id , producer_task )
231
233
234
+ return task_manager , task_id , queue , result_aggregator , producer_task
235
+
236
+ def _validate_task_id_match (self , task_id : str , event_task_id : str ) -> None :
237
+ """Validates that agent-generated task ID matches the expected task ID."""
238
+ if task_id != event_task_id :
239
+ logger .error (
240
+ f'Agent generated task_id={ event_task_id } does not match the RequestContext task_id={ task_id } .'
241
+ )
242
+ raise ServerError (
243
+ InternalError (message = 'Task ID mismatch in agent response' )
244
+ )
245
+
246
+ async def _send_push_notification_if_needed (
247
+ self , task_id : str , result_aggregator : ResultAggregator
248
+ ) -> None :
249
+ """Sends push notification if configured and task is available."""
250
+ if self ._push_notifier and task_id :
251
+ latest_task = await result_aggregator .current_result
252
+ if isinstance (latest_task , Task ):
253
+ await self ._push_notifier .send_notification (latest_task )
254
+
255
+ async def on_message_send (
256
+ self ,
257
+ params : MessageSendParams ,
258
+ context : ServerCallContext | None = None ,
259
+ ) -> Message | Task :
260
+ """Default handler for 'message/send' interface (non-streaming).
261
+
262
+ Starts the agent execution for the message and waits for the final
263
+ result (Task or Message).
264
+ """
265
+ (
266
+ task_manager ,
267
+ task_id ,
268
+ queue ,
269
+ result_aggregator ,
270
+ producer_task ,
271
+ ) = await self ._setup_message_execution (params , context )
272
+
232
273
consumer = EventConsumer (queue )
233
274
producer_task .add_done_callback (consumer .agent_task_callback )
234
275
@@ -241,13 +282,13 @@ async def on_message_send(
241
282
if not result :
242
283
raise ServerError (error = InternalError ())
243
284
244
- if isinstance (result , Task ) and task_id != result . id :
245
- logger . error (
246
- f'Agent generated task_id= { result . id } does not match the RequestContext task_id= { task_id } .'
247
- )
248
- raise ServerError (
249
- InternalError ( message = 'Task ID mismatch in agent response' )
250
- )
285
+ if isinstance (result , Task ):
286
+ self . _validate_task_id_match ( task_id , result . id )
287
+
288
+ await self . _send_push_notification_if_needed (
289
+ task_id , result_aggregator
290
+ )
291
+
251
292
except Exception as e :
252
293
logger .error (f'Agent execution failed. Error: { e } ' )
253
294
raise
@@ -272,85 +313,34 @@ async def on_message_send_stream(
272
313
Starts the agent execution and yields events as they are produced
273
314
by the agent.
274
315
"""
275
- task_manager = TaskManager (
276
- task_id = params .message .taskId ,
277
- context_id = params .message .contextId ,
278
- task_store = self .task_store ,
279
- initial_message = params .message ,
280
- )
281
- task : Task | None = await task_manager .get_task ()
282
-
283
- if task :
284
- if task .status .state in TERMINAL_TASK_STATES :
285
- raise ServerError (
286
- error = InvalidParamsError (
287
- message = f'Task { task .id } is in terminal state: { task .status .state } '
288
- )
289
- )
290
-
291
- task = task_manager .update_with_message (params .message , task )
292
- if self .should_add_push_info (params ):
293
- assert isinstance (self ._push_notifier , PushNotifier )
294
- assert isinstance (
295
- params .configuration , MessageSendConfiguration
296
- )
297
- assert isinstance (
298
- params .configuration .pushNotificationConfig ,
299
- PushNotificationConfig ,
300
- )
301
- await self ._push_notifier .set_info (
302
- task .id , params .configuration .pushNotificationConfig
303
- )
304
- else :
305
- queue = EventQueue ()
306
- result_aggregator = ResultAggregator (task_manager )
307
- request_context = await self ._request_context_builder .build (
308
- params = params ,
309
- task_id = task .id if task else None ,
310
- context_id = params .message .contextId ,
311
- task = task ,
312
- context = context ,
313
- )
314
-
315
- task_id = cast ('str' , request_context .task_id )
316
- queue = await self ._queue_manager .create_or_tap (task_id )
317
- producer_task = asyncio .create_task (
318
- self ._run_event_stream (
319
- request_context ,
320
- queue ,
321
- )
322
- )
323
- await self ._register_producer (task_id , producer_task )
316
+ (
317
+ task_manager ,
318
+ task_id ,
319
+ queue ,
320
+ result_aggregator ,
321
+ producer_task ,
322
+ ) = await self ._setup_message_execution (params , context )
324
323
325
324
try :
326
325
consumer = EventConsumer (queue )
327
326
producer_task .add_done_callback (consumer .agent_task_callback )
328
327
async for event in result_aggregator .consume_and_emit (consumer ):
329
328
if isinstance (event , Task ):
330
- if task_id != event .id :
331
- logger .error (
332
- f'Agent generated task_id={ event .id } does not match the RequestContext task_id={ task_id } .'
333
- )
334
- raise ServerError (
335
- InternalError (
336
- message = 'Task ID mismatch in agent response'
337
- )
338
- )
339
-
340
- if (
341
- self ._push_notifier
342
- and params .configuration
343
- and params .configuration .pushNotificationConfig
344
- ):
345
- await self ._push_notifier .set_info (
346
- task_id ,
347
- params .configuration .pushNotificationConfig ,
348
- )
349
-
350
- if self ._push_notifier and task_id :
351
- latest_task = await result_aggregator .current_result
352
- if isinstance (latest_task , Task ):
353
- await self ._push_notifier .send_notification (latest_task )
329
+ self ._validate_task_id_match (task_id , event .id )
330
+
331
+ if (
332
+ self ._push_notifier
333
+ and params .configuration
334
+ and params .configuration .pushNotificationConfig
335
+ ):
336
+ await self ._push_notifier .set_info (
337
+ task_id ,
338
+ params .configuration .pushNotificationConfig ,
339
+ )
340
+
341
+ await self ._send_push_notification_if_needed (
342
+ task_id , result_aggregator
343
+ )
354
344
yield event
355
345
finally :
356
346
await self ._cleanup_producer (producer_task , task_id )
0 commit comments