Skip to content

Commit 0360e77

Browse files
committed
Make DeferredToolCalls work with streaming
1 parent 74a56ae commit 0360e77

File tree

6 files changed

+178
-53
lines changed

6 files changed

+178
-53
lines changed

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 15 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from pydantic_graph.nodes import End, NodeRunEndT
2222

2323
from . import _output, _system_prompt, exceptions, messages as _messages, models, result, usage as _usage
24-
from .output import DeferredToolCalls, OutputDataT, OutputSpec
24+
from .output import OutputDataT, OutputSpec
2525
from .settings import ModelSettings, merge_model_settings
2626
from .tools import RunContext, ToolDefinition, ToolKind
2727

@@ -310,6 +310,7 @@ async def stream(
310310
ctx.deps.output_validators,
311311
build_run_context(ctx),
312312
ctx.deps.usage_limits,
313+
ctx.deps.toolset,
313314
)
314315
yield agent_stream
315316
# In case the user didn't manually consume the full stream, ensure it is fully consumed here,
@@ -497,6 +498,13 @@ async def _handle_tool_calls(
497498
if final_result_holder:
498499
final_result = final_result_holder[0]
499500
self._next_node = self._handle_final_result(ctx, final_result, parts)
501+
elif deferred_tool_calls := ctx.deps.toolset.get_deferred_tool_calls(tool_calls):
502+
if not ctx.deps.output_schema.deferred_tool_calls:
503+
raise exceptions.UserError(
504+
'There are deferred tool calls but DeferredToolCalls is not among output types.'
505+
)
506+
final_result = result.FinalResult(cast(NodeRunEndT, deferred_tool_calls), None, None)
507+
self._next_node = self._handle_final_result(ctx, final_result, parts)
500508
else:
501509
instructions = await ctx.deps.get_instructions(run_context)
502510
self._next_node = ModelRequestNode[DepsT, NodeRunEndT](
@@ -578,16 +586,11 @@ async def process_function_tools( # noqa: C901
578586
"""
579587
run_context = build_run_context(ctx)
580588

581-
unknown_calls: list[_messages.ToolCallPart] = []
582-
tool_calls_by_kind: dict[ToolKind, list[_messages.ToolCallPart]] = defaultdict(list)
583-
# TODO: Make Toolset.tool_defs a dict
584-
tool_defs_by_name: dict[str, ToolDefinition] = {tool_def.name: tool_def for tool_def in toolset.tool_defs}
589+
tool_calls_by_kind: dict[ToolKind | Literal['unknown'], list[_messages.ToolCallPart]] = defaultdict(list)
585590
for call in tool_calls:
586-
try:
587-
tool_def = tool_defs_by_name[call.tool_name]
588-
tool_calls_by_kind[tool_def.kind].append(call)
589-
except KeyError:
590-
unknown_calls.append(call)
591+
tool_def = toolset.get_tool_def(call.tool_name)
592+
kind = tool_def.kind if tool_def else 'unknown'
593+
tool_calls_by_kind[kind].append(call)
591594

592595
# first, look for the output tool call
593596
for call in tool_calls_by_kind['output']:
@@ -642,9 +645,9 @@ async def process_function_tools( # noqa: C901
642645
else:
643646
calls_to_run.extend(tool_calls_by_kind['function'])
644647

645-
if unknown_calls:
648+
if tool_calls_by_kind['unknown']:
646649
ctx.state.increment_retries(ctx.deps.max_result_retries)
647-
calls_to_run.extend(unknown_calls)
650+
calls_to_run.extend(tool_calls_by_kind['unknown'])
648651

649652
for call in calls_to_run:
650653
yield _messages.FunctionToolCallEvent(call)
@@ -718,7 +721,6 @@ def process_content(content: Any) -> Any:
718721
for k in sorted(results_by_index):
719722
parts.append(results_by_index[k])
720723

721-
deferred_calls: list[_messages.ToolCallPart] = []
722724
for call in tool_calls_by_kind['deferred']:
723725
if final_result:
724726
parts.append(
@@ -730,18 +732,6 @@ def process_content(content: Any) -> Any:
730732
)
731733
else:
732734
yield _messages.FunctionToolCallEvent(call)
733-
deferred_calls.append(call)
734-
735-
if deferred_calls:
736-
if not ctx.deps.output_schema.deferred_tool_calls:
737-
raise exceptions.UserError('There are pending tool calls but DeferredToolCalls is not among output types.')
738-
739-
deferred_tool_names = [call.tool_name for call in deferred_calls]
740-
deferred_tool_defs = {
741-
tool_def.name: tool_def for tool_def in toolset.tool_defs if tool_def.name in deferred_tool_names
742-
}
743-
output_data = cast(NodeRunEndT, DeferredToolCalls(deferred_calls, deferred_tool_defs))
744-
final_result = result.FinalResult(output_data)
745735

746736
parts.extend(user_parts)
747737

pydantic_ai_slim/pydantic_ai/_output.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import inspect
44
import json
55
from abc import ABC, abstractmethod
6-
from collections.abc import Awaitable, Iterable, Iterator, Sequence
6+
from collections.abc import Awaitable, Iterable, Sequence
77
from dataclasses import dataclass, field, replace
88
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast, overload
99

@@ -549,16 +549,6 @@ def find_named_tool(
549549
if part.tool_name == tool_name:
550550
return part, self.tools[tool_name]
551551

552-
def find_tool(
553-
self,
554-
parts: Iterable[_messages.ModelResponsePart],
555-
) -> Iterator[tuple[_messages.ToolCallPart, OutputTool[OutputDataT]]]:
556-
"""Find a tool that matches one of the calls."""
557-
for part in parts:
558-
if isinstance(part, _messages.ToolCallPart): # pragma: no branch
559-
if result := self.tools.get(part.tool_name):
560-
yield part, result
561-
562552

563553
@dataclass(init=False)
564554
class ToolOrTextOutputSchema(ToolOutputSchema[OutputDataT], PlainTextOutputSchema[OutputDataT]):

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1042,12 +1042,13 @@ async def stream_to_final(
10421042
output_schema, _output.TextOutputSchema
10431043
):
10441044
return FinalResult(s, None, None)
1045-
elif isinstance(new_part, _messages.ToolCallPart) and isinstance(
1046-
output_schema, _output.ToolOutputSchema
1047-
): # pragma: no branch
1048-
for call, _ in output_schema.find_tool([new_part]):
1049-
return FinalResult(s, call.tool_name, call.tool_call_id)
1050-
# TODO: Handle DeferredToolCalls
1045+
elif isinstance(new_part, _messages.ToolCallPart) and (
1046+
tool_def := graph_ctx.deps.toolset.get_tool_def(new_part.tool_name)
1047+
):
1048+
if tool_def.kind == 'output':
1049+
return FinalResult(s, new_part.tool_name, new_part.tool_call_id)
1050+
elif tool_def.kind == 'deferred':
1051+
return FinalResult(s, None, None)
10511052
return None
10521053

10531054
final_result = await stream_to_final(streamed_response)
@@ -1072,16 +1073,20 @@ async def on_complete() -> None:
10721073

10731074
# TODO: Should we move on to the CallToolsNode here, instead of doing this ourselves?
10741075
parts: list[_messages.ModelRequestPart] = []
1076+
# final_result_holder: list[result.FinalResult[models.StreamedResponse]] = []
10751077
async for _event in _agent_graph.process_function_tools(
10761078
graph_ctx.deps.toolset,
10771079
tool_calls,
10781080
final_result,
10791081
graph_ctx,
10801082
parts,
1083+
# final_result_holder,
10811084
):
10821085
pass
10831086
if parts:
10841087
messages.append(_messages.ModelRequest(parts))
1088+
# if final_result_holder:
1089+
# final_result = final_result_holder[0]
10851090

10861091
yield StreamedRunResult(
10871092
messages,
@@ -1093,6 +1098,7 @@ async def on_complete() -> None:
10931098
graph_ctx.deps.output_validators,
10941099
final_result.tool_name,
10951100
on_complete,
1101+
graph_ctx.deps.toolset,
10961102
)
10971103
break
10981104
# TODO: There may be deferred tool calls, process those.

pydantic_ai_slim/pydantic_ai/result.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@
55
from copy import copy
66
from dataclasses import dataclass, field
77
from datetime import datetime
8-
from typing import Generic
8+
from typing import Generic, cast
99

1010
from pydantic import ValidationError
1111
from typing_extensions import TypeVar, deprecated, overload
1212

13+
from pydantic_ai.toolset import RunToolset
14+
1315
from . import _utils, exceptions, messages as _messages, models
1416
from ._output import (
1517
OutputDataT_inv,
@@ -47,6 +49,7 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
4749
_output_validators: list[OutputValidator[AgentDepsT, OutputDataT]]
4850
_run_ctx: RunContext[AgentDepsT]
4951
_usage_limits: UsageLimits | None
52+
_toolset: RunToolset[AgentDepsT]
5053

5154
_agent_stream_iterator: AsyncIterator[AgentStreamEvent] | None = field(default=None, init=False)
5255
_final_result_event: FinalResultEvent | None = field(default=None, init=False)
@@ -102,6 +105,12 @@ async def _validate_response(
102105
result_data = await output_tool.process(
103106
call, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False
104107
)
108+
elif deferred_tool_calls := self._toolset.get_deferred_tool_calls(message.parts):
109+
if not self._output_schema.deferred_tool_calls:
110+
raise exceptions.UserError(
111+
'There are deferred tool calls but DeferredToolCalls is not among output types.'
112+
)
113+
return cast(OutputDataT, deferred_tool_calls)
105114
elif isinstance(self._output_schema, TextOutputSchema):
106115
text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart))
107116

@@ -113,8 +122,6 @@ async def _validate_response(
113122
'Invalid response, unable to process text output'
114123
)
115124

116-
# TODO: Possibly return DeferredToolCalls here?
117-
118125
for validator in self._output_validators:
119126
result_data = await validator.validate(result_data, call, self._run_ctx)
120127
return result_data
@@ -136,13 +143,19 @@ def _get_final_result_event(e: _messages.ModelResponseStreamEvent) -> _messages.
136143
"""Return an appropriate FinalResultEvent if `e` corresponds to a part that will produce a final result."""
137144
if isinstance(e, _messages.PartStartEvent):
138145
new_part = e.part
139-
if isinstance(new_part, _messages.ToolCallPart) and isinstance(output_schema, ToolOutputSchema):
140-
for call, _ in output_schema.find_tool([new_part]): # pragma: no branch
141-
return _messages.FinalResultEvent(tool_name=call.tool_name, tool_call_id=call.tool_call_id)
142-
elif isinstance(new_part, _messages.TextPart) and isinstance(
146+
if isinstance(new_part, _messages.TextPart) and isinstance(
143147
output_schema, TextOutputSchema
144148
): # pragma: no branch
145149
return _messages.FinalResultEvent(tool_name=None, tool_call_id=None)
150+
elif isinstance(new_part, _messages.ToolCallPart) and (
151+
tool_def := self._toolset.get_tool_def(new_part.tool_name)
152+
):
153+
if tool_def.kind == 'output':
154+
return _messages.FinalResultEvent(
155+
tool_name=new_part.tool_name, tool_call_id=new_part.tool_call_id
156+
)
157+
elif tool_def.kind == 'deferred':
158+
return _messages.FinalResultEvent(tool_name=None, tool_call_id=None)
146159

147160
usage_checking_stream = _get_usage_checking_stream_response(
148161
self._raw_stream_response, self._usage_limits, self.usage
@@ -177,6 +190,7 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
177190
_output_validators: list[OutputValidator[AgentDepsT, OutputDataT]]
178191
_output_tool_name: str | None
179192
_on_complete: Callable[[], Awaitable[None]]
193+
_toolset: RunToolset[AgentDepsT]
180194

181195
_initial_run_ctx_usage: Usage = field(init=False)
182196
is_complete: bool = field(default=False, init=False)
@@ -382,7 +396,6 @@ async def get_output(self) -> OutputDataT:
382396
pass
383397
message = self._stream_response.get()
384398
await self._marked_completed(message)
385-
# TODO: Possibly return DeferredToolCalls here?
386399
return await self.validate_structured_output(message)
387400

388401
@deprecated('`get_data` is deprecated, use `get_output` instead.')
@@ -423,6 +436,12 @@ async def validate_structured_output(
423436
result_data = await output_tool.process(
424437
call, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False
425438
)
439+
elif deferred_tool_calls := self._toolset.get_deferred_tool_calls(message.parts):
440+
if not self._output_schema.deferred_tool_calls:
441+
raise exceptions.UserError(
442+
'There are deferred tool calls but DeferredToolCalls is not among output types.'
443+
)
444+
return cast(OutputDataT, deferred_tool_calls)
426445
elif isinstance(self._output_schema, TextOutputSchema):
427446
text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart))
428447

pydantic_ai_slim/pydantic_ai/toolset.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import asyncio
44
from abc import ABC, abstractmethod
5-
from collections.abc import Awaitable, Iterator, Sequence
5+
from collections.abc import Awaitable, Iterable, Iterator, Sequence
66
from contextlib import AsyncExitStack, contextmanager
77
from dataclasses import dataclass, field, replace
88
from functools import partial
@@ -14,6 +14,8 @@
1414
from pydantic_core import SchemaValidator
1515
from typing_extensions import Self
1616

17+
from pydantic_ai.output import DeferredToolCalls
18+
1719
from . import messages as _messages
1820
from ._output import BaseOutputSchema, OutputValidator, ToolRetryError
1921
from ._run_context import AgentDepsT, RunContext
@@ -70,6 +72,9 @@ def tool_defs(self) -> list[ToolDefinition]:
7072
def tool_names(self) -> list[str]:
7173
return [tool_def.name for tool_def in self.tool_defs]
7274

75+
def get_tool_def(self, name: str) -> ToolDefinition | None:
76+
return next((tool_def for tool_def in self.tool_defs if tool_def.name == name), None)
77+
7378
@abstractmethod
7479
def _get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator:
7580
raise NotImplementedError()
@@ -676,6 +681,25 @@ async def call_tool(
676681
self._retries.pop(name, None)
677682
return output
678683

684+
def get_deferred_tool_calls(self, parts: Iterable[_messages.ModelResponsePart]) -> DeferredToolCalls | None:
685+
deferred_calls_and_defs = [
686+
(part, tool_def)
687+
for part in parts
688+
if isinstance(part, _messages.ToolCallPart)
689+
and (tool_def := self.get_tool_def(part.tool_name))
690+
and tool_def.kind == 'deferred'
691+
]
692+
if not deferred_calls_and_defs:
693+
return None
694+
695+
deferred_calls: list[_messages.ToolCallPart] = []
696+
deferred_tool_defs: dict[str, ToolDefinition] = {}
697+
for part, tool_def in deferred_calls_and_defs:
698+
deferred_calls.append(part)
699+
deferred_tool_defs[part.tool_name] = tool_def
700+
701+
return DeferredToolCalls(deferred_calls, deferred_tool_defs)
702+
679703
@contextmanager
680704
def _with_retry(self, name: str, ctx: RunContext[AgentDepsT]) -> Iterator[RunContext[AgentDepsT]]:
681705
try:

0 commit comments

Comments
 (0)