14
14
BaseEvent ,
15
15
DeveloperMessage ,
16
16
EventType ,
17
+ FunctionCall ,
17
18
Message ,
19
+ MessagesSnapshotEvent ,
18
20
RunAgentInput ,
19
21
RunErrorEvent ,
20
22
RunFinishedEvent ,
24
26
TextMessageEndEvent ,
25
27
TextMessageStartEvent ,
26
28
Tool as ToolAGUI ,
29
+ ToolCall ,
27
30
ToolCallArgsEvent ,
28
31
ToolCallEndEvent ,
29
32
ToolCallStartEvent ,
34
37
35
38
from pydantic_ai import Agent , ModelRequestNode , models
36
39
from pydantic_ai ._output import OutputType
40
+ from pydantic_ai ._parts_manager import ModelResponsePartsManager
37
41
from pydantic_ai .agent import RunOutputDataT
38
42
from pydantic_ai .mcp import ToolResult
39
43
from pydantic_ai .messages import (
43
47
ModelRequest ,
44
48
ModelRequestPart ,
45
49
ModelResponse ,
50
+ ModelResponsePart ,
46
51
PartDeltaEvent ,
47
52
PartStartEvent ,
48
53
SystemPromptPart ,
@@ -108,13 +113,6 @@ class Adapter(Generic[AgentDepsT, OutputDataT]):
108
113
an adapter for running agents with Server-Sent Event (SSE) streaming
109
114
responses using the AG-UI protocol.
110
115
111
- # Warning
112
-
113
- Agent requests which require a PydanticAI tool use followed by an AG-UI
114
- tool will currently fail to process correctly, as the response from AG-UI
115
- will not include details about the PydanticAI tool request / response.
116
- This will be fixed in a future release
117
-
118
116
Examples:
119
117
This is an example of base usage with FastAPI.
120
118
.. code-block:: python
@@ -239,14 +237,9 @@ async def run(
239
237
if isinstance (deps , StateHandler ):
240
238
deps .set_state (run_input .state )
241
239
242
- prompt : str | None = None
243
- if isinstance (run_input .messages [- 1 ], UserMessage ):
244
- prompt = run_input .messages [- 1 ].content
245
- run_input .messages .pop ()
246
-
247
240
run : AgentRun [AgentDepsT , Any ]
248
241
async with self .agent .iter (
249
- user_prompt = prompt ,
242
+ user_prompt = None ,
250
243
output_type = output_type ,
251
244
message_history = _convert_history (run_input .messages ),
252
245
model = model ,
@@ -257,10 +250,21 @@ async def run(
257
250
infer_name = infer_name ,
258
251
additional_tools = run_tools ,
259
252
) as run :
260
- async for event in self ._agent_stream (tool_names , run ):
253
+ parts_manager : ModelResponsePartsManager = ModelResponsePartsManager ()
254
+ async for event in self ._agent_stream (tool_names , run , parts_manager ):
261
255
if event is None :
262
256
# Tool call signals early return, so we stop processing.
263
257
self .logger .debug ('tool call early return' )
258
+
259
+ # TODO(steve): Remove this workaround, it's only needed as AG-UI doesn't
260
+ # currently have a way to add server side tool calls to the message history
261
+ # via events. To workaround this we create a full snapshot of the messages
262
+ # and send that.
263
+ snapshot : MessagesSnapshotEvent | None = self ._message_snapshot (
264
+ run , run_input .messages , parts_manager
265
+ )
266
+ if snapshot is not None :
267
+ yield encoder .encode (snapshot )
264
268
break
265
269
266
270
yield encoder .encode (event )
@@ -285,6 +289,102 @@ async def run(
285
289
286
290
self .logger .info ('done thread_id=%s run_id=%s' , run_input .thread_id , run_input .run_id )
287
291
292
+ def _message_snapshot (
293
+ self , run : AgentRun [AgentDepsT , Any ], messages : list [Message ], parts_manager : ModelResponsePartsManager
294
+ ) -> MessagesSnapshotEvent | None :
295
+ """Create a message snapshot to replicate the current state of the run.
296
+
297
+ This method collects all messages from the run's state and the parts
298
+ manager, converting them into AG-UI messages.
299
+
300
+ Args:
301
+ run: The agent run instance.
302
+ messages: The initial messages from the run input.
303
+ parts_manager: The parts manager containing the response parts.
304
+
305
+ Returns:
306
+ A full snapshot of the messages so far in the run if local tool
307
+ calls were made, otherwise `None`.
308
+ """
309
+ new_messages : list [ModelMessage ] = run .ctx .state .message_history [len (messages ) :]
310
+ if not any (
311
+ isinstance (request_part , ToolReturnPart )
312
+ for msg in new_messages
313
+ if isinstance (msg , ModelRequest )
314
+ for request_part in msg .parts
315
+ ):
316
+ # No tool calls were made, so we don't need a snapshot.
317
+ return None
318
+
319
+ # Tool calls were made, so we need to create a snapshot.
320
+ for msg in new_messages :
321
+ match msg :
322
+ case ModelRequest ():
323
+ for request_part in msg .parts :
324
+ if isinstance (request_part , ToolReturnPart ):
325
+ messages .append (
326
+ ToolMessage (
327
+ id = 'result-' + request_part .tool_call_id ,
328
+ role = Role .TOOL ,
329
+ content = request_part .content ,
330
+ tool_call_id = request_part .tool_call_id ,
331
+ )
332
+ )
333
+ case ModelResponse ():
334
+ self ._convert_response_parts (msg .parts , messages )
335
+
336
+ self ._convert_response_parts (parts_manager .get_parts (), messages )
337
+
338
+ return MessagesSnapshotEvent (
339
+ type = EventType .MESSAGES_SNAPSHOT ,
340
+ messages = messages ,
341
+ )
342
+
343
+ def _convert_response_parts (self , parts : list [ModelResponsePart ], messages : list [Message ]) -> None :
344
+ """Convert model response parts to AG-UI messages.
345
+
346
+ Args:
347
+ parts: The list of model response parts to convert.
348
+ messages: The list of messages to append the converted parts to.
349
+ """
350
+ response_part : ModelResponsePart
351
+ for response_part in parts :
352
+ match response_part :
353
+ case TextPart (): # pragma: no cover
354
+ # This is not expected, but we handle it gracefully.
355
+ messages .append (
356
+ AssistantMessage (
357
+ id = uuid .uuid4 ().hex ,
358
+ role = Role .ASSISTANT ,
359
+ content = response_part .content ,
360
+ )
361
+ )
362
+ case ToolCallPart ():
363
+ args : str = (
364
+ json .dumps (response_part .args )
365
+ if isinstance (response_part .args , dict )
366
+ else response_part .args or '{}'
367
+ )
368
+ messages .append (
369
+ AssistantMessage (
370
+ id = uuid .uuid4 ().hex ,
371
+ role = Role .ASSISTANT ,
372
+ tool_calls = [
373
+ ToolCall (
374
+ id = response_part .tool_call_id ,
375
+ type = 'function' ,
376
+ function = FunctionCall (
377
+ name = response_part .tool_name ,
378
+ arguments = args ,
379
+ ),
380
+ )
381
+ ],
382
+ ),
383
+ )
384
+ case ThinkingPart (): # pragma: no cover
385
+ # No AG-UI equivalent for thinking parts, so we skip them.
386
+ pass
387
+
288
388
async def _tool_events (self , parts : list [ModelRequestPart ]) -> AsyncGenerator [BaseEvent | None , None ]:
289
389
"""Check for tool call results that are AG-UI events.
290
390
@@ -371,12 +471,14 @@ async def _agent_stream(
371
471
self ,
372
472
tool_names : dict [str , str ],
373
473
run : AgentRun [AgentDepsT , Any ],
474
+ parts_manager : ModelResponsePartsManager ,
374
475
) -> AsyncGenerator [BaseEvent | None , None ]:
375
476
"""Run the agent streaming responses using AG-UI protocol events.
376
477
377
478
Args:
378
479
tool_names: A mapping of tool names to their AG-UI names.
379
480
run: The agent run to process.
481
+ parts_manager: The parts manager to handle tool call parts.
380
482
381
483
Yields:
382
484
AG-UI Server-Sent Events (SSE).
@@ -399,7 +501,7 @@ async def _agent_stream(
399
501
async with node .stream (run .ctx ) as request_stream :
400
502
agent_event : AgentStreamEvent
401
503
async for agent_event in request_stream :
402
- async for msg in self ._handle_agent_event (tool_names , stream_ctx , agent_event ):
504
+ async for msg in self ._handle_agent_event (tool_names , stream_ctx , agent_event , parts_manager ):
403
505
yield msg
404
506
405
507
for part_end in stream_ctx .part_ends :
@@ -410,6 +512,7 @@ async def _handle_agent_event(
410
512
tool_names : dict [str , str ],
411
513
stream_ctx : _RequestStreamContext ,
412
514
agent_event : AgentStreamEvent ,
515
+ parts_manager : ModelResponsePartsManager ,
413
516
) -> AsyncGenerator [BaseEvent | None , None ]:
414
517
"""Handle an agent event and yield AG-UI protocol events.
415
518
@@ -418,6 +521,7 @@ async def _handle_agent_event(
418
521
tool_names: A mapping of tool names to their AG-UI names.
419
522
stream_ctx: The request stream context to manage state.
420
523
agent_event: The agent event to process.
524
+ parts_manager: The parts manager to handle tool call parts.
421
525
422
526
Yields:
423
527
AG-UI Server-Sent Events (SSE) based on the agent event.
@@ -454,9 +558,16 @@ async def _handle_agent_event(
454
558
case ToolCallPart (): # pragma: no branch
455
559
tool_name : str | None = tool_names .get (agent_event .part .tool_name )
456
560
if not tool_name :
561
+ # Local tool calls are not sent as events to the UI.
457
562
stream_ctx .local_tool_calls .add (agent_event .part .tool_call_id )
458
563
return
459
564
565
+ parts_manager .handle_tool_call_part (
566
+ vendor_part_id = None ,
567
+ tool_name = agent_event .part .tool_name ,
568
+ args = agent_event .part .args ,
569
+ tool_call_id = agent_event .part .tool_call_id ,
570
+ )
460
571
stream_ctx .last_tool_call_id = agent_event .part .tool_call_id
461
572
yield ToolCallStartEvent (
462
573
type = EventType .TOOL_CALL_START ,
@@ -483,9 +594,15 @@ async def _handle_agent_event(
483
594
)
484
595
case ToolCallPartDelta (): # pragma: no branch
485
596
if agent_event .delta .tool_call_id in stream_ctx .local_tool_calls :
486
- # Local tool calls are not sent to the UI.
597
+ # Local tool calls are not sent as events to the UI.
487
598
return
488
599
600
+ parts_manager .handle_tool_call_delta (
601
+ vendor_part_id = None ,
602
+ tool_name = None ,
603
+ args = agent_event .delta .args_delta ,
604
+ tool_call_id = agent_event .delta .tool_call_id ,
605
+ )
489
606
yield ToolCallArgsEvent (
490
607
type = EventType .TOOL_CALL_ARGS ,
491
608
tool_call_id = agent_event .delta .tool_call_id
0 commit comments