Skip to content

Commit 05aa972

Browse files
committed
WIP
1 parent 8745a7a commit 05aa972

File tree

3 files changed

+127
-124
lines changed

3 files changed

+127
-124
lines changed

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 117 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -480,109 +480,21 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:
480480
async for event in self._events_iterator:
481481
yield event
482482

483-
async def _handle_tool_calls( # noqa: C901
483+
async def _handle_tool_calls(
484484
self,
485485
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
486486
tool_calls: list[_messages.ToolCallPart],
487487
) -> AsyncIterator[_messages.HandleResponseEvent]:
488488
run_context = build_run_context(ctx)
489489

490-
final_result: result.FinalResult[NodeRunEndT] | None = None
491490
parts: list[_messages.ModelRequestPart] = []
491+
final_result_holder: list[result.FinalResult[NodeRunEndT]] = []
492492

493-
toolset = ctx.deps.toolset
494-
495-
unknown_calls: list[_messages.ToolCallPart] = []
496-
tool_calls_by_kind: dict[ToolKind, list[_messages.ToolCallPart]] = defaultdict(list)
497-
# TODO: Make Toolset.tool_defs a dict
498-
tool_defs_by_name: dict[str, ToolDefinition] = {tool_def.name: tool_def for tool_def in toolset.tool_defs}
499-
for call in tool_calls:
500-
try:
501-
tool_def = tool_defs_by_name[call.tool_name]
502-
tool_calls_by_kind[tool_def.kind].append(call)
503-
except KeyError:
504-
unknown_calls.append(call)
505-
506-
# first, look for the output tool call
507-
for call in tool_calls_by_kind['output']:
508-
if final_result:
509-
part = _messages.ToolReturnPart(
510-
tool_name=call.tool_name,
511-
content='Output tool not used - a final result was already processed.',
512-
tool_call_id=call.tool_call_id,
513-
)
514-
parts.append(part)
515-
else:
516-
try:
517-
result_data = await _call_tool(toolset, call, run_context)
518-
except _output.ToolRetryError as e:
519-
parts.append(e.tool_retry)
520-
else:
521-
part = _messages.ToolReturnPart(
522-
tool_name=call.tool_name,
523-
content='Final result processed.',
524-
tool_call_id=call.tool_call_id,
525-
)
526-
parts.append(part)
527-
final_result = result.FinalResult(result_data, call.tool_name, call.tool_call_id)
528-
529-
# Then build the other request parts based on end strategy
530-
if final_result and ctx.deps.end_strategy == 'early':
531-
for call in tool_calls_by_kind['function']:
532-
parts.append(
533-
_messages.ToolReturnPart(
534-
tool_name=call.tool_name,
535-
content='Tool not executed - a final result was already processed.',
536-
tool_call_id=call.tool_call_id,
537-
)
538-
)
539-
else:
540-
async for event in process_function_tools(
541-
toolset,
542-
tool_calls_by_kind['function'],
543-
ctx,
544-
parts,
545-
):
546-
yield event
547-
548-
if unknown_calls:
549-
ctx.state.increment_retries(ctx.deps.max_result_retries)
550-
async for event in process_function_tools(
551-
toolset,
552-
unknown_calls,
553-
ctx,
554-
parts,
555-
):
556-
yield event
557-
558-
deferred_calls: list[_messages.ToolCallPart] = []
559-
for call in tool_calls_by_kind['deferred']:
560-
if final_result:
561-
parts.append(
562-
_messages.ToolReturnPart(
563-
tool_name=call.tool_name,
564-
content='Tool not executed - a final result was already processed.',
565-
tool_call_id=call.tool_call_id,
566-
)
567-
)
568-
else:
569-
yield _messages.FunctionToolCallEvent(call)
570-
deferred_calls.append(call)
571-
572-
if deferred_calls:
573-
if not ctx.deps.output_schema.deferred_tool_calls:
574-
raise exceptions.UserError(
575-
'There are pending tool calls but DeferredToolCalls is not among output types.'
576-
)
577-
578-
deferred_tool_names = [call.tool_name for call in deferred_calls]
579-
deferred_tool_defs = {
580-
tool_def.name: tool_def for tool_def in toolset.tool_defs if tool_def.name in deferred_tool_names
581-
}
582-
output_data = cast(NodeRunEndT, DeferredToolCalls(deferred_calls, deferred_tool_defs))
583-
final_result = result.FinalResult(output_data)
493+
async for event in process_function_tools(ctx.deps.toolset, tool_calls, None, ctx, parts, final_result_holder):
494+
yield event
584495

585-
if final_result:
496+
if final_result_holder:
497+
final_result = final_result_holder[0]
586498
self._next_node = self._handle_final_result(ctx, final_result, parts)
587499
else:
588500
instructions = await ctx.deps.get_instructions(run_context)
@@ -652,24 +564,85 @@ def multi_modal_content_identifier(identifier: str | bytes) -> str:
652564
async def process_function_tools(
653565
toolset: AbstractToolset[DepsT],
654566
tool_calls: list[_messages.ToolCallPart],
567+
final_result: result.FinalResult[NodeRunEndT] | None,
655568
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
656-
output_parts: list[_messages.ModelRequestPart],
569+
parts: list[_messages.ModelRequestPart],
570+
final_result_holder: list[result.FinalResult[NodeRunEndT]] = [],
657571
) -> AsyncIterator[_messages.HandleResponseEvent]:
658572
"""Process function (i.e., non-result) tool calls in parallel.
659573
660574
Also add stub return parts for any other tools that need it.
661575
662-
Because async iterators can't have return values, we use `output_parts` as an output argument.
576+
Because async iterators can't have return values, we use `parts` as an output argument.
663577
"""
664578
run_context = build_run_context(ctx)
665579

666-
calls_to_run: list[_messages.ToolCallPart] = []
667-
call_index_to_event_id: dict[int, str] = {}
580+
unknown_calls: list[_messages.ToolCallPart] = []
581+
tool_calls_by_kind: dict[ToolKind, list[_messages.ToolCallPart]] = defaultdict(list)
582+
# TODO: Make Toolset.tool_defs a dict
583+
tool_defs_by_name: dict[str, ToolDefinition] = {tool_def.name: tool_def for tool_def in toolset.tool_defs}
668584
for call in tool_calls:
669-
event = _messages.FunctionToolCallEvent(call)
670-
yield event
671-
call_index_to_event_id[len(calls_to_run)] = event.call_id
672-
calls_to_run.append(call)
585+
try:
586+
tool_def = tool_defs_by_name[call.tool_name]
587+
tool_calls_by_kind[tool_def.kind].append(call)
588+
except KeyError:
589+
unknown_calls.append(call)
590+
591+
# first, look for the output tool call
592+
for call in tool_calls_by_kind['output']:
593+
if final_result:
594+
if final_result.tool_call_id == call.tool_call_id:
595+
part = _messages.ToolReturnPart(
596+
tool_name=call.tool_name,
597+
content='Final result processed.',
598+
tool_call_id=call.tool_call_id,
599+
)
600+
else:
601+
yield _messages.FunctionToolCallEvent(call)
602+
part = _messages.ToolReturnPart(
603+
tool_name=call.tool_name,
604+
content='Output tool not used - a final result was already processed.',
605+
tool_call_id=call.tool_call_id,
606+
)
607+
yield _messages.FunctionToolResultEvent(part, tool_call_id=call.tool_call_id)
608+
609+
parts.append(part)
610+
else:
611+
try:
612+
result_data = await _call_tool(toolset, call, run_context)
613+
except _output.ToolRetryError as e:
614+
yield _messages.FunctionToolCallEvent(call)
615+
parts.append(e.tool_retry)
616+
yield _messages.FunctionToolResultEvent(e.tool_retry, tool_call_id=call.tool_call_id)
617+
else:
618+
part = _messages.ToolReturnPart(
619+
tool_name=call.tool_name,
620+
content='Final result processed.',
621+
tool_call_id=call.tool_call_id,
622+
)
623+
parts.append(part)
624+
final_result = result.FinalResult(result_data, call.tool_name, call.tool_call_id)
625+
626+
calls_to_run: list[_messages.ToolCallPart] = []
627+
# Then build the other request parts based on end strategy
628+
if final_result and ctx.deps.end_strategy == 'early':
629+
for call in tool_calls_by_kind['function']:
630+
parts.append(
631+
_messages.ToolReturnPart(
632+
tool_name=call.tool_name,
633+
content='Tool not executed - a final result was already processed.',
634+
tool_call_id=call.tool_call_id,
635+
)
636+
)
637+
else:
638+
calls_to_run.extend(tool_calls_by_kind['function'])
639+
640+
if unknown_calls:
641+
ctx.state.increment_retries(ctx.deps.max_result_retries)
642+
calls_to_run.extend(unknown_calls)
643+
644+
for call in calls_to_run:
645+
yield _messages.FunctionToolCallEvent(call)
673646

674647
user_parts: list[_messages.UserPromptPart] = []
675648

@@ -698,12 +671,12 @@ async def process_function_tools(
698671
done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
699672
for task in done:
700673
index = tasks.index(task)
701-
result = task.result()
702-
yield _messages.FunctionToolResultEvent(result, tool_call_id=call_index_to_event_id[index])
674+
tool_result = task.result()
675+
yield _messages.FunctionToolResultEvent(tool_result, tool_call_id=tool_result.tool_call_id)
703676

704-
if isinstance(result, _messages.RetryPromptPart):
705-
results_by_index[index] = result
706-
elif isinstance(result, _messages.ToolReturnPart):
677+
if isinstance(tool_result, _messages.RetryPromptPart):
678+
results_by_index[index] = tool_result
679+
elif isinstance(tool_result, _messages.ToolReturnPart):
707680

708681
def process_content(content: Any) -> Any:
709682
if isinstance(content, _messages.MultiModalContentTypes):
@@ -715,30 +688,58 @@ def process_content(content: Any) -> Any:
715688
user_parts.append(
716689
_messages.UserPromptPart(
717690
content=[f'This is file {identifier}:', content],
718-
timestamp=result.timestamp,
691+
timestamp=tool_result.timestamp,
719692
part_kind='user-prompt',
720693
)
721694
)
722695
return f'See file {identifier}'
723696
else:
724697
return content
725698

726-
if isinstance(result.content, list):
727-
contents = cast(list[Any], result.content) # type: ignore
728-
result.content = [process_content(content) for content in contents]
699+
if isinstance(tool_result.content, list):
700+
contents = cast(list[Any], tool_result.content) # type: ignore
701+
tool_result.content = [process_content(content) for content in contents]
729702
else:
730-
result.content = process_content(result.content)
703+
tool_result.content = process_content(tool_result.content)
731704

732-
results_by_index[index] = result
705+
results_by_index[index] = tool_result
733706
else:
734-
assert_never(result)
707+
assert_never(tool_result)
735708

736709
# We append the results at the end, rather than as they are received, to retain a consistent ordering
737710
# This is mostly just to simplify testing
738711
for k in sorted(results_by_index):
739-
output_parts.append(results_by_index[k])
712+
parts.append(results_by_index[k])
713+
714+
deferred_calls: list[_messages.ToolCallPart] = []
715+
for call in tool_calls_by_kind['deferred']:
716+
if final_result:
717+
parts.append(
718+
_messages.ToolReturnPart(
719+
tool_name=call.tool_name,
720+
content='Tool not executed - a final result was already processed.',
721+
tool_call_id=call.tool_call_id,
722+
)
723+
)
724+
else:
725+
yield _messages.FunctionToolCallEvent(call)
726+
deferred_calls.append(call)
727+
728+
if deferred_calls:
729+
if not ctx.deps.output_schema.deferred_tool_calls:
730+
raise exceptions.UserError('There are pending tool calls but DeferredToolCalls is not among output types.')
731+
732+
deferred_tool_names = [call.tool_name for call in deferred_calls]
733+
deferred_tool_defs = {
734+
tool_def.name: tool_def for tool_def in toolset.tool_defs if tool_def.name in deferred_tool_names
735+
}
736+
output_data = cast(NodeRunEndT, DeferredToolCalls(deferred_calls, deferred_tool_defs))
737+
final_result = result.FinalResult(output_data)
738+
739+
parts.extend(user_parts)
740740

741-
output_parts.extend(user_parts)
741+
if final_result:
742+
final_result_holder.append(final_result)
742743

743744

744745
async def _call_function_tool(

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1046,10 +1046,11 @@ async def stream_to_final(
10461046
): # pragma: no branch
10471047
for call, _ in output_schema.find_tool([new_part]):
10481048
return FinalResult(s, call.tool_name, call.tool_call_id)
1049+
# TODO: Handle DeferredToolCalls
10491050
return None
10501051

1051-
final_result_details = await stream_to_final(streamed_response)
1052-
if final_result_details is not None:
1052+
final_result = await stream_to_final(streamed_response)
1053+
if final_result is not None:
10531054
if yielded:
10541055
raise exceptions.AgentRunError('Agent run produced final results') # pragma: no cover
10551056
yielded = True
@@ -1068,19 +1069,16 @@ async def on_complete() -> None:
10681069
part for part in last_message.parts if isinstance(part, _messages.ToolCallPart)
10691070
]
10701071

1072+
# TODO: Should we move on to the CallToolsNode here, instead of doing this ourselves?
10711073
parts: list[_messages.ModelRequestPart] = []
1072-
# TODO: Make this work again. We may have pulled too much out of process_function_tools :)
10731074
async for _event in _agent_graph.process_function_tools(
10741075
graph_ctx.deps.toolset,
10751076
tool_calls,
1077+
final_result,
10761078
graph_ctx,
10771079
parts,
10781080
):
10791081
pass
1080-
# TODO: Should we do something here related to the retry count?
1081-
# Maybe we should move the incrementing of the retry count to where we actually make a request?
1082-
# if any(isinstance(part, _messages.RetryPromptPart) for part in parts):
1083-
# ctx.state.increment_retries(ctx.deps.max_result_retries)
10841082
if parts:
10851083
messages.append(_messages.ModelRequest(parts))
10861084

@@ -1092,10 +1090,11 @@ async def on_complete() -> None:
10921090
graph_ctx.deps.output_schema,
10931091
_agent_graph.build_run_context(graph_ctx),
10941092
graph_ctx.deps.output_validators,
1095-
final_result_details.tool_name,
1093+
final_result.tool_name,
10961094
on_complete,
10971095
)
10981096
break
1097+
# TODO: There may be deferred tool calls, process those.
10991098
next_node = await agent_run.next(node)
11001099
if not isinstance(next_node, _agent_graph.AgentNode):
11011100
raise exceptions.AgentRunError( # pragma: no cover

pydantic_ai_slim/pydantic_ai/result.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,8 @@ async def _validate_response(
113113
'Invalid response, unable to process text output'
114114
)
115115

116+
# TODO: Possibly return DeferredToolCalls here?
117+
116118
for validator in self._output_validators:
117119
result_data = await validator.validate(result_data, call, self._run_ctx)
118120
return result_data
@@ -380,6 +382,7 @@ async def get_output(self) -> OutputDataT:
380382
pass
381383
message = self._stream_response.get()
382384
await self._marked_completed(message)
385+
# TODO: Possibly return DeferredToolCalls here?
383386
return await self.validate_structured_output(message)
384387

385388
@deprecated('`get_data` is deprecated, use `get_output` instead.')

0 commit comments

Comments
 (0)