Commit 95857f5

fix(ag-ui): mixed tool handling
Leverage message snapshots to ensure agents have access to the entire context of a conversation, including the details of non-AG-UI tool requests in the message history. This is needed when the agent processes a non-AG-UI tool request before an AG-UI tool request, which requires multiple UI-to-agent interactions to complete.
1 parent a2a5674 commit 95857f5
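
For orientation, here is a minimal sketch (not part of the commit) of the conversation shape the commit message describes, built with pydantic_ai's public message classes. The tool names, arguments and call IDs are invented for illustration.

from pydantic_ai.messages import (
    ModelRequest,
    ModelResponse,
    ToolCallPart,
    ToolReturnPart,
    UserPromptPart,
)

# Hypothetical history for a request that needs a server-side tool *before* an AG-UI tool.
history = [
    # 1. The user asks for something that needs both kinds of tools.
    ModelRequest(parts=[UserPromptPart(content='Look up the forecast and show it in a panel')]),
    # 2. The model first calls a server-side (PydanticAI) tool...
    ModelResponse(parts=[ToolCallPart(tool_name='get_forecast', args={'city': 'Paris'}, tool_call_id='call_1')]),
    # ...which is resolved entirely on the server.
    ModelRequest(parts=[ToolReturnPart(tool_name='get_forecast', content='Sunny, 24C', tool_call_id='call_1')]),
    # 3. The model then calls an AG-UI (frontend) tool, which hands control back to the UI.
    #    Previously the frontend never learned about step 2, so the follow-up AG-UI request
    #    lacked that context; this commit emits a MessagesSnapshotEvent so the full history
    #    round-trips through the UI.
    ModelResponse(parts=[ToolCallPart(tool_name='show_panel', args={'text': 'Sunny, 24C'}, tool_call_id='call_2')]),
]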

File tree

4 files changed: +326 additions, -35 deletions

docs/ag-ui.md

Lines changed: 0 additions & 6 deletions
@@ -194,12 +194,6 @@ convenience method, it accepts the same arguments as the
 AG-UI tools are seamlessly provided to the PydanticAI agent, enabling rich
 use experiences with frontend user interfaces.

-!!! warning "Requests requiring PydanticAI and AG-UI tools"
-    Agent requests which require a PydanticAI tool use followed by an
-    AG-UI tool will currently fail to process correctly, as the response from
-    AG-UI will not include details about the PydanticAI tool request / response.
-    This will be fixed in a future release
-
 #### Events

 The adapter provides the ability for PydanticAI tools to send

pydantic_ai_ag_ui/pydantic_ai_ag_ui/adapter.py

Lines changed: 133 additions & 16 deletions
@@ -14,7 +14,9 @@
     BaseEvent,
     DeveloperMessage,
     EventType,
+    FunctionCall,
     Message,
+    MessagesSnapshotEvent,
     RunAgentInput,
     RunErrorEvent,
     RunFinishedEvent,
@@ -24,6 +26,7 @@
     TextMessageEndEvent,
     TextMessageStartEvent,
     Tool as ToolAGUI,
+    ToolCall,
     ToolCallArgsEvent,
     ToolCallEndEvent,
     ToolCallStartEvent,
@@ -34,6 +37,7 @@

 from pydantic_ai import Agent, ModelRequestNode, models
 from pydantic_ai._output import OutputType
+from pydantic_ai._parts_manager import ModelResponsePartsManager
 from pydantic_ai.agent import RunOutputDataT
 from pydantic_ai.mcp import ToolResult
 from pydantic_ai.messages import (
@@ -43,6 +47,7 @@
     ModelRequest,
     ModelRequestPart,
     ModelResponse,
+    ModelResponsePart,
     PartDeltaEvent,
     PartStartEvent,
     SystemPromptPart,
@@ -108,13 +113,6 @@ class Adapter(Generic[AgentDepsT, OutputDataT]):
     an adapter for running agents with Server-Sent Event (SSE) streaming
     responses using the AG-UI protocol.

-    # Warning
-
-    Agent requests which require a PydanticAI tool use followed by an AG-UI
-    tool will currently fail to process correctly, as the response from AG-UI
-    will not include details about the PydanticAI tool request / response.
-    This will be fixed in a future release
-
     Examples:
         This is an example of base usage with FastAPI.
         .. code-block:: python
@@ -239,14 +237,9 @@ async def run(
         if isinstance(deps, StateHandler):
             deps.set_state(run_input.state)

-        prompt: str | None = None
-        if isinstance(run_input.messages[-1], UserMessage):
-            prompt = run_input.messages[-1].content
-            run_input.messages.pop()
-
         run: AgentRun[AgentDepsT, Any]
         async with self.agent.iter(
-            user_prompt=prompt,
+            user_prompt=None,
             output_type=output_type,
             message_history=_convert_history(run_input.messages),
             model=model,
@@ -257,10 +250,21 @@
             infer_name=infer_name,
             additional_tools=run_tools,
         ) as run:
-            async for event in self._agent_stream(tool_names, run):
+            parts_manager: ModelResponsePartsManager = ModelResponsePartsManager()
+            async for event in self._agent_stream(tool_names, run, parts_manager):
                 if event is None:
                     # Tool call signals early return, so we stop processing.
                     self.logger.debug('tool call early return')
+
+                    # TODO(steve): Remove this workaround, it's only needed as AG-UI doesn't
+                    # currently have a way to add server side tool calls to the message history
+                    # via events. To workaround this we create a full snapshot of the messages
+                    # and send that.
+                    snapshot: MessagesSnapshotEvent | None = self._message_snapshot(
+                        run, run_input.messages, parts_manager
+                    )
+                    if snapshot is not None:
+                        yield encoder.encode(snapshot)
                     break

                 yield encoder.encode(event)
@@ -285,6 +289,102 @@ async def run(

         self.logger.info('done thread_id=%s run_id=%s', run_input.thread_id, run_input.run_id)

+    def _message_snapshot(
+        self, run: AgentRun[AgentDepsT, Any], messages: list[Message], parts_manager: ModelResponsePartsManager
+    ) -> MessagesSnapshotEvent | None:
+        """Create a message snapshot to replicate the current state of the run.
+
+        This method collects all messages from the run's state and the parts
+        manager, converting them into AG-UI messages.
+
+        Args:
+            run: The agent run instance.
+            messages: The initial messages from the run input.
+            parts_manager: The parts manager containing the response parts.
+
+        Returns:
+            A full snapshot of the messages so far in the run if local tool
+            calls were made, otherwise `None`.
+        """
+        new_messages: list[ModelMessage] = run.ctx.state.message_history[len(messages) :]
+        if not any(
+            isinstance(request_part, ToolReturnPart)
+            for msg in new_messages
+            if isinstance(msg, ModelRequest)
+            for request_part in msg.parts
+        ):
+            # No tool calls were made, so we don't need a snapshot.
+            return None
+
+        # Tool calls were made, so we need to create a snapshot.
+        for msg in new_messages:
+            match msg:
+                case ModelRequest():
+                    for request_part in msg.parts:
+                        if isinstance(request_part, ToolReturnPart):
+                            messages.append(
+                                ToolMessage(
+                                    id='result-' + request_part.tool_call_id,
+                                    role=Role.TOOL,
+                                    content=request_part.content,
+                                    tool_call_id=request_part.tool_call_id,
+                                )
+                            )
+                case ModelResponse():
+                    self._convert_response_parts(msg.parts, messages)

+        self._convert_response_parts(parts_manager.get_parts(), messages)
+
+        return MessagesSnapshotEvent(
+            type=EventType.MESSAGES_SNAPSHOT,
+            messages=messages,
+        )
+
+    def _convert_response_parts(self, parts: list[ModelResponsePart], messages: list[Message]) -> None:
+        """Convert model response parts to AG-UI messages.
+
+        Args:
+            parts: The list of model response parts to convert.
+            messages: The list of messages to append the converted parts to.
+        """
+        response_part: ModelResponsePart
+        for response_part in parts:
+            match response_part:
+                case TextPart():  # pragma: no cover
+                    # This is not expected, but we handle it gracefully.
+                    messages.append(
+                        AssistantMessage(
+                            id=uuid.uuid4().hex,
+                            role=Role.ASSISTANT,
+                            content=response_part.content,
+                        )
+                    )
+                case ToolCallPart():
+                    args: str = (
+                        json.dumps(response_part.args)
+                        if isinstance(response_part.args, dict)
+                        else response_part.args or '{}'
+                    )
+                    messages.append(
+                        AssistantMessage(
+                            id=uuid.uuid4().hex,
+                            role=Role.ASSISTANT,
+                            tool_calls=[
+                                ToolCall(
+                                    id=response_part.tool_call_id,
+                                    type='function',
+                                    function=FunctionCall(
+                                        name=response_part.tool_name,
+                                        arguments=args,
+                                    ),
+                                )
+                            ],
+                        ),
+                    )
+                case ThinkingPart():  # pragma: no cover
+                    # No AG-UI equivalent for thinking parts, so we skip them.
+                    pass
+
     async def _tool_events(self, parts: list[ModelRequestPart]) -> AsyncGenerator[BaseEvent | None, None]:
         """Check for tool call results that are AG-UI events.

@@ -371,12 +471,14 @@ async def _agent_stream(
         self,
         tool_names: dict[str, str],
         run: AgentRun[AgentDepsT, Any],
+        parts_manager: ModelResponsePartsManager,
     ) -> AsyncGenerator[BaseEvent | None, None]:
         """Run the agent streaming responses using AG-UI protocol events.

         Args:
             tool_names: A mapping of tool names to their AG-UI names.
             run: The agent run to process.
+            parts_manager: The parts manager to handle tool call parts.

         Yields:
             AG-UI Server-Sent Events (SSE).
@@ -399,7 +501,7 @@ async def _agent_stream(
         async with node.stream(run.ctx) as request_stream:
             agent_event: AgentStreamEvent
             async for agent_event in request_stream:
-                async for msg in self._handle_agent_event(tool_names, stream_ctx, agent_event):
+                async for msg in self._handle_agent_event(tool_names, stream_ctx, agent_event, parts_manager):
                     yield msg

         for part_end in stream_ctx.part_ends:
@@ -410,6 +512,7 @@ async def _handle_agent_event(
         tool_names: dict[str, str],
         stream_ctx: _RequestStreamContext,
         agent_event: AgentStreamEvent,
+        parts_manager: ModelResponsePartsManager,
     ) -> AsyncGenerator[BaseEvent | None, None]:
         """Handle an agent event and yield AG-UI protocol events.

@@ -418,6 +521,7 @@ async def _handle_agent_event(
             tool_names: A mapping of tool names to their AG-UI names.
             stream_ctx: The request stream context to manage state.
             agent_event: The agent event to process.
+            parts_manager: The parts manager to handle tool call parts.

         Yields:
             AG-UI Server-Sent Events (SSE) based on the agent event.
@@ -454,9 +558,16 @@
             case ToolCallPart():  # pragma: no branch
                 tool_name: str | None = tool_names.get(agent_event.part.tool_name)
                 if not tool_name:
+                    # Local tool calls are not sent as events to the UI.
                     stream_ctx.local_tool_calls.add(agent_event.part.tool_call_id)
                     return

+                parts_manager.handle_tool_call_part(
+                    vendor_part_id=None,
+                    tool_name=agent_event.part.tool_name,
+                    args=agent_event.part.args,
+                    tool_call_id=agent_event.part.tool_call_id,
+                )
                 stream_ctx.last_tool_call_id = agent_event.part.tool_call_id
                 yield ToolCallStartEvent(
                     type=EventType.TOOL_CALL_START,
@@ -483,9 +594,15 @@
                 )
             case ToolCallPartDelta():  # pragma: no branch
                 if agent_event.delta.tool_call_id in stream_ctx.local_tool_calls:
-                    # Local tool calls are not sent to the UI.
+                    # Local tool calls are not sent as events to the UI.
                     return

+                parts_manager.handle_tool_call_delta(
+                    vendor_part_id=None,
+                    tool_name=None,
+                    args=agent_event.delta.args_delta,
+                    tool_call_id=agent_event.delta.tool_call_id,
+                )
                 yield ToolCallArgsEvent(
                     type=EventType.TOOL_CALL_ARGS,
                     tool_call_id=agent_event.delta.tool_call_id
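
To make the snapshot path concrete, here is a rough sketch (not from the diff) of what _message_snapshot and _convert_response_parts produce for one server-side tool exchange. It assumes the AG-UI message and event types used above are importable from ag_ui.core; the tool name, arguments and call ID are invented.

import json
import uuid

from ag_ui.core import (
    AssistantMessage,
    EventType,
    FunctionCall,
    MessagesSnapshotEvent,
    Role,
    ToolCall,
    ToolMessage,
)

# The server-side tool call becomes an assistant message carrying a ToolCall...
assistant = AssistantMessage(
    id=uuid.uuid4().hex,
    role=Role.ASSISTANT,
    tool_calls=[
        ToolCall(
            id='call_1',
            type='function',
            function=FunctionCall(name='get_forecast', arguments=json.dumps({'city': 'Paris'})),
        )
    ],
)
# ...and its return value becomes a tool message with the 'result-' id prefix used above.
result = ToolMessage(
    id='result-call_1',
    role=Role.TOOL,
    content='Sunny, 24C',
    tool_call_id='call_1',
)

# Appended to the run_input messages, this is what the adapter yields as the snapshot event.
snapshot = MessagesSnapshotEvent(type=EventType.MESSAGES_SNAPSHOT, messages=[assistant, result])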

pydantic_ai_slim/pydantic_ai/models/test.py

Lines changed: 81 additions & 4 deletions
@@ -45,6 +45,33 @@ class _WrappedToolOutput:
     value: Any | None


+@dataclass
+class TestToolCallPart:
+    """Represents a tool call in the test model."""
+
+    call_tools: list[str] | Literal['all'] = 'all'
+    deltas: bool = False
+
+
+@dataclass
+class TestTextPart:
+    """Represents a text part in the test model."""
+
+    text: str
+
+
+TestPart = TestTextPart | TestToolCallPart
+"""A part of the test model response."""
+
+
+@dataclass
+class TestNode:
+    """A node in the test model."""
+
+    parts: list[TestPart]
+    id: str = field(default_factory=_utils.generate_tool_call_id)
+
+
 @dataclass
 class TestModel(Model):
     """A model specifically for testing purposes.
@@ -65,6 +92,8 @@ class TestModel(Model):
     """List of tools to call. If `'all'`, all tools will be called."""
     tool_call_deltas: set[str] = field(default_factory=set)
     """A set of tool call names which should result in tool call part deltas."""
+    custom_response_nodes: list[TestNode] | None = None
+    """A list of nodes which defines a custom model response."""
     custom_output_text: str | None = None
     """If set, this text is returned as the final output."""
     custom_output_args: Any | None = None
@@ -155,23 +184,71 @@ def _get_output(self, model_request_parameters: ModelRequestParameters) -> _Wrap
         else:
             return _WrappedTextOutput(None)

+    def _node_response(
+        self,
+        messages: list[ModelMessage],
+        model_request_parameters: ModelRequestParameters,
+    ) -> ModelResponse | None:
+        """Returns a ModelResponse based on configured nodes.
+
+        Args:
+            messages: The messages sent to the model.
+            model_request_parameters: The parameters for the model request.
+
+        Returns:
+            The response from the model, or `None` if no nodes configured or
+            all nodes have already been processed.
+        """
+        if not self.custom_response_nodes:
+            # No nodes configured, follow the default behaviour.
+            return None
+
+        # Pick up where we left off by counting the number of ModelResponse messages in the stream.
+        # This allows us to stream the response in chunks, simulating a real model response.
+        node: TestNode
+        count: int = sum(isinstance(m, ModelResponse) for m in messages)
+        if count < len(self.custom_response_nodes):
+            node: TestNode = self.custom_response_nodes[count]
+            assert node.parts, 'Node parts should not be empty.'
+
+            parts: list[ModelResponsePart] = []
+            part: TestPart
+            for part in node.parts:
+                if isinstance(part, TestTextPart):
+                    assert model_request_parameters.allow_text_output, (
+                        'Plain response not allowed, but `part` is a `TestText`.'
+                    )
+                    parts.append(TextPart(part.text))
+                elif isinstance(part, TestToolCallPart):
+                    tool_calls = self._get_tool_calls(model_request_parameters)
+                    if part.call_tools == 'all':
+                        parts.extend(ToolCallPart(name, self.gen_tool_args(args)) for name, args in tool_calls)
+                    else:
+                        parts.extend(
+                            ToolCallPart(name, self.gen_tool_args(args))
+                            for name, args in tool_calls
+                            if name in part.call_tools
+                        )
+            return ModelResponse(vendor_id=node.id, parts=parts, model_name=self._model_name)
+
     def _request(
         self,
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
     ) -> ModelResponse:
-        tool_calls = self._get_tool_calls(model_request_parameters)
-        output_wrapper = self._get_output(model_request_parameters)
-        output_tools = model_request_parameters.output_tools
+        if (response := self._node_response(messages, model_request_parameters)) is not None:
+            return response

-        # if there are tools, the first thing we want to do is call all of them
+        tool_calls = self._get_tool_calls(model_request_parameters)
         if tool_calls and not any(isinstance(m, ModelResponse) for m in messages):
             return ModelResponse(
                 parts=[ToolCallPart(name, self.gen_tool_args(args)) for name, args in tool_calls],
                 model_name=self._model_name,
             )

+        output_wrapper = self._get_output(model_request_parameters)
+        output_tools = model_request_parameters.output_tools
         if messages:  # pragma: no branch
             last_message = messages[-1]
             assert isinstance(last_message, ModelRequest), 'Expected last message to be a `ModelRequest`.'
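
A hypothetical usage sketch of the new test hooks above; the agent, tool and prompt are invented, and only TestModel, TestNode, TestTextPart and TestToolCallPart come from this diff. Each TestNode scripts one ModelResponse, so a test can force a "tool call first, text afterwards" sequence without a real model.

from pydantic_ai import Agent
from pydantic_ai.models.test import TestModel, TestNode, TestTextPart, TestToolCallPart

model = TestModel(
    custom_response_nodes=[
        # First scripted response: call only the named server-side tool.
        TestNode(parts=[TestToolCallPart(call_tools=['get_forecast'])]),
        # Second scripted response: plain text once the tool result is back.
        TestNode(parts=[TestTextPart(text='All done')]),
    ]
)
agent = Agent(model)


@agent.tool_plain
def get_forecast(city: str) -> str:
    """Return a canned forecast for the test."""
    return f'Sunny in {city}'


result = agent.run_sync('What is the weather like?')
print(result.output)  # expected: 'All done'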
