Skip to content

Fix AG-UI parallel tool calls #2301

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 60 additions & 77 deletions pydantic_ai_slim/pydantic_ai/ag_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,12 +291,12 @@ async def run(
if isinstance(deps, StateHandler):
deps.state = run_input.state

history = _History.from_ag_ui(run_input.messages)
messages = _messages_from_ag_ui(run_input.messages)

async with self.agent.iter(
user_prompt=None,
output_type=[output_type or self.agent.output_type, DeferredToolCalls],
message_history=history.messages,
message_history=messages,
model=model,
deps=deps,
model_settings=model_settings,
Expand All @@ -305,7 +305,7 @@ async def run(
infer_name=infer_name,
toolsets=toolsets,
) as run:
async for event in self._agent_stream(run, history):
async for event in self._agent_stream(run):
yield encoder.encode(event)
except _RunError as e:
yield encoder.encode(
Expand All @@ -327,20 +327,18 @@ async def run(
async def _agent_stream(
self,
run: AgentRun[AgentDepsT, Any],
history: _History,
) -> AsyncGenerator[BaseEvent, None]:
"""Run the agent streaming responses using AG-UI protocol events.

Args:
run: The agent run to process.
history: The history of messages and tool calls to use for the run.

Yields:
AG-UI Server-Sent Events (SSE).
"""
async for node in run:
stream_ctx = _RequestStreamContext()
if isinstance(node, ModelRequestNode):
stream_ctx = _RequestStreamContext()
async with node.stream(run.ctx) as request_stream:
async for agent_event in request_stream:
async for msg in self._handle_model_request_event(stream_ctx, agent_event):
Expand All @@ -352,8 +350,8 @@ async def _agent_stream(
elif isinstance(node, CallToolsNode):
async with node.stream(run.ctx) as handle_stream:
async for event in handle_stream:
if isinstance(event, FunctionToolResultEvent) and isinstance(event.result, ToolReturnPart):
async for msg in self._handle_tool_result_event(event.result, history.prompt_message_id):
if isinstance(event, FunctionToolResultEvent):
async for msg in self._handle_tool_result_event(stream_ctx, event):
yield msg

async def _handle_model_request_event(
Expand Down Expand Up @@ -391,9 +389,11 @@ async def _handle_model_request_event(
delta=part.content,
)
elif isinstance(part, ToolCallPart): # pragma: no branch
message_id = stream_ctx.message_id or stream_ctx.new_message_id()
yield ToolCallStartEvent(
tool_call_id=part.tool_call_id,
tool_call_name=part.tool_name,
parent_message_id=message_id,
)
stream_ctx.part_end = ToolCallEndEvent(
tool_call_id=part.tool_call_id,
Expand All @@ -403,11 +403,9 @@ async def _handle_model_request_event(
yield ThinkingTextMessageStartEvent(
type=EventType.THINKING_TEXT_MESSAGE_START,
)
# Always send the content even if it's empty, as it may be
# used to indicate the start of thinking.
yield ThinkingTextMessageContentEvent(
type=EventType.THINKING_TEXT_MESSAGE_CONTENT,
delta=part.content or '',
delta=part.content,
)
stream_ctx.part_end = ThinkingTextMessageEndEvent(
type=EventType.THINKING_TEXT_MESSAGE_END,
Expand Down Expand Up @@ -435,20 +433,25 @@ async def _handle_model_request_event(

async def _handle_tool_result_event(
self,
result: ToolReturnPart,
prompt_message_id: str,
stream_ctx: _RequestStreamContext,
event: FunctionToolResultEvent,
) -> AsyncGenerator[BaseEvent, None]:
"""Convert a tool call result to AG-UI events.

Args:
result: The tool call result to process.
prompt_message_id: The message ID of the prompt that initiated the tool call.
stream_ctx: The request stream context to manage state.
event: The tool call result event to process.

Yields:
AG-UI Server-Sent Events (SSE).
"""
result = event.result
if not isinstance(result, ToolReturnPart):
return

message_id = stream_ctx.new_message_id()
yield ToolCallResultEvent(
message_id=prompt_message_id,
message_id=message_id,
type=EventType.TOOL_CALL_RESULT,
role='tool',
tool_call_id=result.tool_call_id,
Expand All @@ -468,75 +471,55 @@ async def _handle_tool_result_event(
yield item


@dataclass
class _History:
"""A simple history representation for AG-UI protocol."""

prompt_message_id: str # The ID of the last user message.
messages: list[ModelMessage]

@classmethod
def from_ag_ui(cls, messages: list[Message]) -> _History:
"""Convert a AG-UI history to a Pydantic AI one.

Args:
messages: List of AG-UI messages to convert.

Returns:
List of Pydantic AI model messages.
"""
prompt_message_id = ''
result: list[ModelMessage] = []
tool_calls: dict[str, str] = {} # Tool call ID to tool name mapping.
for msg in messages:
if isinstance(msg, UserMessage):
prompt_message_id = msg.id
result.append(ModelRequest(parts=[UserPromptPart(content=msg.content)]))
elif isinstance(msg, AssistantMessage):
if msg.tool_calls:
for tool_call in msg.tool_calls:
tool_calls[tool_call.id] = tool_call.function.name

result.append(
ModelResponse(
parts=[
ToolCallPart(
tool_name=tool_call.function.name,
tool_call_id=tool_call.id,
args=tool_call.function.arguments,
)
for tool_call in msg.tool_calls
]
)
)

if msg.content:
result.append(ModelResponse(parts=[TextPart(content=msg.content)]))
elif isinstance(msg, SystemMessage):
result.append(ModelRequest(parts=[SystemPromptPart(content=msg.content)]))
elif isinstance(msg, ToolMessage):
tool_name = tool_calls.get(msg.tool_call_id)
if tool_name is None: # pragma: no cover
raise _ToolCallNotFoundError(tool_call_id=msg.tool_call_id)
def _messages_from_ag_ui(messages: list[Message]) -> list[ModelMessage]:
"""Convert a AG-UI history to a Pydantic AI one."""
result: list[ModelMessage] = []
tool_calls: dict[str, str] = {} # Tool call ID to tool name mapping.
for msg in messages:
if isinstance(msg, UserMessage):
result.append(ModelRequest(parts=[UserPromptPart(content=msg.content)]))
elif isinstance(msg, AssistantMessage):
if msg.tool_calls:
for tool_call in msg.tool_calls:
tool_calls[tool_call.id] = tool_call.function.name

result.append(
ModelRequest(
ModelResponse(
parts=[
ToolReturnPart(
tool_name=tool_name,
content=msg.content,
tool_call_id=msg.tool_call_id,
ToolCallPart(
tool_name=tool_call.function.name,
tool_call_id=tool_call.id,
args=tool_call.function.arguments,
)
for tool_call in msg.tool_calls
]
)
)
elif isinstance(msg, DeveloperMessage): # pragma: no branch
result.append(ModelRequest(parts=[SystemPromptPart(content=msg.content)]))

return cls(
prompt_message_id=prompt_message_id,
messages=result,
)
if msg.content:
result.append(ModelResponse(parts=[TextPart(content=msg.content)]))
elif isinstance(msg, SystemMessage):
result.append(ModelRequest(parts=[SystemPromptPart(content=msg.content)]))
elif isinstance(msg, ToolMessage):
tool_name = tool_calls.get(msg.tool_call_id)
if tool_name is None: # pragma: no cover
raise _ToolCallNotFoundError(tool_call_id=msg.tool_call_id)

result.append(
ModelRequest(
parts=[
ToolReturnPart(
tool_name=tool_name,
content=msg.content,
tool_call_id=msg.tool_call_id,
)
]
)
)
elif isinstance(msg, DeveloperMessage): # pragma: no branch
result.append(ModelRequest(parts=[SystemPromptPart(content=msg.content)]))

return result


@runtime_checkable
Expand Down
39 changes: 39 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def IsFloat(*args: Any, **kwargs: Any) -> float: ...
def IsInt(*args: Any, **kwargs: Any) -> int: ...
def IsNow(*args: Any, **kwargs: Any) -> datetime: ...
def IsStr(*args: Any, **kwargs: Any) -> str: ...
def IsSameStr(*args: Any, **kwargs: Any) -> str: ...
else:
from dirty_equals import IsDatetime, IsFloat, IsInstance, IsInt, IsNow as _IsNow, IsStr

Expand All @@ -59,6 +60,44 @@ def IsNow(*args: Any, **kwargs: Any):
kwargs['delta'] = 10
return _IsNow(*args, **kwargs)

class IsSameStr(IsStr):
"""
Checks if the value is a string, and that subsequent uses have the same value as the first one.

Example:
```python {test="skip"}
assert events == [
{
'type': 'RUN_STARTED',
'threadId': (thread_id := IsSameStr()),
'runId': (run_id := IsSameStr()),
},
{'type': 'TEXT_MESSAGE_START', 'messageId': (message_id := IsSameStr()), 'role': 'assistant'},
{'type': 'TEXT_MESSAGE_CONTENT', 'messageId': message_id, 'delta': 'success '},
{
'type': 'TEXT_MESSAGE_CONTENT',
'messageId': message_id,
'delta': '(no tool calls)',
},
{'type': 'TEXT_MESSAGE_END', 'messageId': message_id},
{
'type': 'RUN_FINISHED',
'threadId': thread_id,
'runId': run_id,
},
]
```
"""

_first_other: str | None = None

def equals(self, other: Any) -> bool:
if self._first_other is None:
self._first_other = other
return super().equals(other)
else:
return other == self._first_other


class TestEnv:
__test__ = False
Expand Down
Loading