@@ -480,109 +480,21 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:
480
480
async for event in self ._events_iterator :
481
481
yield event
482
482
483
- async def _handle_tool_calls ( # noqa: C901
483
+ async def _handle_tool_calls (
484
484
self ,
485
485
ctx : GraphRunContext [GraphAgentState , GraphAgentDeps [DepsT , NodeRunEndT ]],
486
486
tool_calls : list [_messages .ToolCallPart ],
487
487
) -> AsyncIterator [_messages .HandleResponseEvent ]:
488
488
run_context = build_run_context (ctx )
489
489
490
- final_result : result .FinalResult [NodeRunEndT ] | None = None
491
490
parts : list [_messages .ModelRequestPart ] = []
491
+ final_result_holder : list [result .FinalResult [NodeRunEndT ]] = []
492
492
493
- toolset = ctx .deps .toolset
494
-
495
- unknown_calls : list [_messages .ToolCallPart ] = []
496
- tool_calls_by_kind : dict [ToolKind , list [_messages .ToolCallPart ]] = defaultdict (list )
497
- # TODO: Make Toolset.tool_defs a dict
498
- tool_defs_by_name : dict [str , ToolDefinition ] = {tool_def .name : tool_def for tool_def in toolset .tool_defs }
499
- for call in tool_calls :
500
- try :
501
- tool_def = tool_defs_by_name [call .tool_name ]
502
- tool_calls_by_kind [tool_def .kind ].append (call )
503
- except KeyError :
504
- unknown_calls .append (call )
505
-
506
- # first, look for the output tool call
507
- for call in tool_calls_by_kind ['output' ]:
508
- if final_result :
509
- part = _messages .ToolReturnPart (
510
- tool_name = call .tool_name ,
511
- content = 'Output tool not used - a final result was already processed.' ,
512
- tool_call_id = call .tool_call_id ,
513
- )
514
- parts .append (part )
515
- else :
516
- try :
517
- result_data = await _call_tool (toolset , call , run_context )
518
- except _output .ToolRetryError as e :
519
- parts .append (e .tool_retry )
520
- else :
521
- part = _messages .ToolReturnPart (
522
- tool_name = call .tool_name ,
523
- content = 'Final result processed.' ,
524
- tool_call_id = call .tool_call_id ,
525
- )
526
- parts .append (part )
527
- final_result = result .FinalResult (result_data , call .tool_name , call .tool_call_id )
528
-
529
- # Then build the other request parts based on end strategy
530
- if final_result and ctx .deps .end_strategy == 'early' :
531
- for call in tool_calls_by_kind ['function' ]:
532
- parts .append (
533
- _messages .ToolReturnPart (
534
- tool_name = call .tool_name ,
535
- content = 'Tool not executed - a final result was already processed.' ,
536
- tool_call_id = call .tool_call_id ,
537
- )
538
- )
539
- else :
540
- async for event in process_function_tools (
541
- toolset ,
542
- tool_calls_by_kind ['function' ],
543
- ctx ,
544
- parts ,
545
- ):
546
- yield event
547
-
548
- if unknown_calls :
549
- ctx .state .increment_retries (ctx .deps .max_result_retries )
550
- async for event in process_function_tools (
551
- toolset ,
552
- unknown_calls ,
553
- ctx ,
554
- parts ,
555
- ):
556
- yield event
557
-
558
- deferred_calls : list [_messages .ToolCallPart ] = []
559
- for call in tool_calls_by_kind ['deferred' ]:
560
- if final_result :
561
- parts .append (
562
- _messages .ToolReturnPart (
563
- tool_name = call .tool_name ,
564
- content = 'Tool not executed - a final result was already processed.' ,
565
- tool_call_id = call .tool_call_id ,
566
- )
567
- )
568
- else :
569
- yield _messages .FunctionToolCallEvent (call )
570
- deferred_calls .append (call )
571
-
572
- if deferred_calls :
573
- if not ctx .deps .output_schema .deferred_tool_calls :
574
- raise exceptions .UserError (
575
- 'There are pending tool calls but DeferredToolCalls is not among output types.'
576
- )
577
-
578
- deferred_tool_names = [call .tool_name for call in deferred_calls ]
579
- deferred_tool_defs = {
580
- tool_def .name : tool_def for tool_def in toolset .tool_defs if tool_def .name in deferred_tool_names
581
- }
582
- output_data = cast (NodeRunEndT , DeferredToolCalls (deferred_calls , deferred_tool_defs ))
583
- final_result = result .FinalResult (output_data )
493
+ async for event in process_function_tools (ctx .deps .toolset , tool_calls , None , ctx , parts , final_result_holder ):
494
+ yield event
584
495
585
- if final_result :
496
+ if final_result_holder :
497
+ final_result = final_result_holder [0 ]
586
498
self ._next_node = self ._handle_final_result (ctx , final_result , parts )
587
499
else :
588
500
instructions = await ctx .deps .get_instructions (run_context )
@@ -652,24 +564,85 @@ def multi_modal_content_identifier(identifier: str | bytes) -> str:
652
564
async def process_function_tools (
653
565
toolset : AbstractToolset [DepsT ],
654
566
tool_calls : list [_messages .ToolCallPart ],
567
+ final_result : result .FinalResult [NodeRunEndT ] | None ,
655
568
ctx : GraphRunContext [GraphAgentState , GraphAgentDeps [DepsT , NodeRunEndT ]],
656
- output_parts : list [_messages .ModelRequestPart ],
569
+ parts : list [_messages .ModelRequestPart ],
570
+ final_result_holder : list [result .FinalResult [NodeRunEndT ]] = [],
657
571
) -> AsyncIterator [_messages .HandleResponseEvent ]:
658
572
"""Process function (i.e., non-result) tool calls in parallel.
659
573
660
574
Also add stub return parts for any other tools that need it.
661
575
662
- Because async iterators can't have return values, we use `output_parts ` as an output argument.
576
+ Because async iterators can't have return values, we use `parts ` as an output argument.
663
577
"""
664
578
run_context = build_run_context (ctx )
665
579
666
- calls_to_run : list [_messages .ToolCallPart ] = []
667
- call_index_to_event_id : dict [int , str ] = {}
580
+ unknown_calls : list [_messages .ToolCallPart ] = []
581
+ tool_calls_by_kind : dict [ToolKind , list [_messages .ToolCallPart ]] = defaultdict (list )
582
+ # TODO: Make Toolset.tool_defs a dict
583
+ tool_defs_by_name : dict [str , ToolDefinition ] = {tool_def .name : tool_def for tool_def in toolset .tool_defs }
668
584
for call in tool_calls :
669
- event = _messages .FunctionToolCallEvent (call )
670
- yield event
671
- call_index_to_event_id [len (calls_to_run )] = event .call_id
672
- calls_to_run .append (call )
585
+ try :
586
+ tool_def = tool_defs_by_name [call .tool_name ]
587
+ tool_calls_by_kind [tool_def .kind ].append (call )
588
+ except KeyError :
589
+ unknown_calls .append (call )
590
+
591
+ # first, look for the output tool call
592
+ for call in tool_calls_by_kind ['output' ]:
593
+ if final_result :
594
+ if final_result .tool_call_id == call .tool_call_id :
595
+ part = _messages .ToolReturnPart (
596
+ tool_name = call .tool_name ,
597
+ content = 'Final result processed.' ,
598
+ tool_call_id = call .tool_call_id ,
599
+ )
600
+ else :
601
+ yield _messages .FunctionToolCallEvent (call )
602
+ part = _messages .ToolReturnPart (
603
+ tool_name = call .tool_name ,
604
+ content = 'Output tool not used - a final result was already processed.' ,
605
+ tool_call_id = call .tool_call_id ,
606
+ )
607
+ yield _messages .FunctionToolResultEvent (part , tool_call_id = call .tool_call_id )
608
+
609
+ parts .append (part )
610
+ else :
611
+ try :
612
+ result_data = await _call_tool (toolset , call , run_context )
613
+ except _output .ToolRetryError as e :
614
+ yield _messages .FunctionToolCallEvent (call )
615
+ parts .append (e .tool_retry )
616
+ yield _messages .FunctionToolResultEvent (e .tool_retry , tool_call_id = call .tool_call_id )
617
+ else :
618
+ part = _messages .ToolReturnPart (
619
+ tool_name = call .tool_name ,
620
+ content = 'Final result processed.' ,
621
+ tool_call_id = call .tool_call_id ,
622
+ )
623
+ parts .append (part )
624
+ final_result = result .FinalResult (result_data , call .tool_name , call .tool_call_id )
625
+
626
+ calls_to_run : list [_messages .ToolCallPart ] = []
627
+ # Then build the other request parts based on end strategy
628
+ if final_result and ctx .deps .end_strategy == 'early' :
629
+ for call in tool_calls_by_kind ['function' ]:
630
+ parts .append (
631
+ _messages .ToolReturnPart (
632
+ tool_name = call .tool_name ,
633
+ content = 'Tool not executed - a final result was already processed.' ,
634
+ tool_call_id = call .tool_call_id ,
635
+ )
636
+ )
637
+ else :
638
+ calls_to_run .extend (tool_calls_by_kind ['function' ])
639
+
640
+ if unknown_calls :
641
+ ctx .state .increment_retries (ctx .deps .max_result_retries )
642
+ calls_to_run .extend (unknown_calls )
643
+
644
+ for call in calls_to_run :
645
+ yield _messages .FunctionToolCallEvent (call )
673
646
674
647
user_parts : list [_messages .UserPromptPart ] = []
675
648
@@ -698,12 +671,12 @@ async def process_function_tools(
698
671
done , pending = await asyncio .wait (pending , return_when = asyncio .FIRST_COMPLETED )
699
672
for task in done :
700
673
index = tasks .index (task )
701
- result = task .result ()
702
- yield _messages .FunctionToolResultEvent (result , tool_call_id = call_index_to_event_id [ index ] )
674
+ tool_result = task .result ()
675
+ yield _messages .FunctionToolResultEvent (tool_result , tool_call_id = tool_result . tool_call_id )
703
676
704
- if isinstance (result , _messages .RetryPromptPart ):
705
- results_by_index [index ] = result
706
- elif isinstance (result , _messages .ToolReturnPart ):
677
+ if isinstance (tool_result , _messages .RetryPromptPart ):
678
+ results_by_index [index ] = tool_result
679
+ elif isinstance (tool_result , _messages .ToolReturnPart ):
707
680
708
681
def process_content (content : Any ) -> Any :
709
682
if isinstance (content , _messages .MultiModalContentTypes ):
@@ -715,30 +688,58 @@ def process_content(content: Any) -> Any:
715
688
user_parts .append (
716
689
_messages .UserPromptPart (
717
690
content = [f'This is file { identifier } :' , content ],
718
- timestamp = result .timestamp ,
691
+ timestamp = tool_result .timestamp ,
719
692
part_kind = 'user-prompt' ,
720
693
)
721
694
)
722
695
return f'See file { identifier } '
723
696
else :
724
697
return content
725
698
726
- if isinstance (result .content , list ):
727
- contents = cast (list [Any ], result .content ) # type: ignore
728
- result .content = [process_content (content ) for content in contents ]
699
+ if isinstance (tool_result .content , list ):
700
+ contents = cast (list [Any ], tool_result .content ) # type: ignore
701
+ tool_result .content = [process_content (content ) for content in contents ]
729
702
else :
730
- result .content = process_content (result .content )
703
+ tool_result .content = process_content (tool_result .content )
731
704
732
- results_by_index [index ] = result
705
+ results_by_index [index ] = tool_result
733
706
else :
734
- assert_never (result )
707
+ assert_never (tool_result )
735
708
736
709
# We append the results at the end, rather than as they are received, to retain a consistent ordering
737
710
# This is mostly just to simplify testing
738
711
for k in sorted (results_by_index ):
739
- output_parts .append (results_by_index [k ])
712
+ parts .append (results_by_index [k ])
713
+
714
+ deferred_calls : list [_messages .ToolCallPart ] = []
715
+ for call in tool_calls_by_kind ['deferred' ]:
716
+ if final_result :
717
+ parts .append (
718
+ _messages .ToolReturnPart (
719
+ tool_name = call .tool_name ,
720
+ content = 'Tool not executed - a final result was already processed.' ,
721
+ tool_call_id = call .tool_call_id ,
722
+ )
723
+ )
724
+ else :
725
+ yield _messages .FunctionToolCallEvent (call )
726
+ deferred_calls .append (call )
727
+
728
+ if deferred_calls :
729
+ if not ctx .deps .output_schema .deferred_tool_calls :
730
+ raise exceptions .UserError ('There are pending tool calls but DeferredToolCalls is not among output types.' )
731
+
732
+ deferred_tool_names = [call .tool_name for call in deferred_calls ]
733
+ deferred_tool_defs = {
734
+ tool_def .name : tool_def for tool_def in toolset .tool_defs if tool_def .name in deferred_tool_names
735
+ }
736
+ output_data = cast (NodeRunEndT , DeferredToolCalls (deferred_calls , deferred_tool_defs ))
737
+ final_result = result .FinalResult (output_data )
738
+
739
+ parts .extend (user_parts )
740
740
741
- output_parts .extend (user_parts )
741
+ if final_result :
742
+ final_result_holder .append (final_result )
742
743
743
744
744
745
async def _call_function_tool (
0 commit comments