@@ -561,7 +561,7 @@ def multi_modal_content_identifier(identifier: str | bytes) -> str:
561
561
return hashlib .sha1 (identifier ).hexdigest ()[:6 ]
562
562
563
563
564
- async def process_function_tools (
564
+ async def process_function_tools ( # noqa: C901
565
565
toolset : AbstractToolset [DepsT ],
566
566
tool_calls : list [_messages .ToolCallPart ],
567
567
final_result : result .FinalResult [NodeRunEndT ] | None ,
@@ -646,70 +646,72 @@ async def process_function_tools(
646
646
647
647
user_parts : list [_messages .UserPromptPart ] = []
648
648
649
- include_content = (
650
- ctx .deps .instrumentation_settings is not None and ctx .deps .instrumentation_settings .include_content
651
- )
649
+ if calls_to_run :
650
+ include_content = (
651
+ ctx .deps .instrumentation_settings is not None and ctx .deps .instrumentation_settings .include_content
652
+ )
652
653
653
- # Run all tool tasks in parallel
654
- results_by_index : dict [int , _messages .ModelRequestPart ] = {}
655
- with ctx .deps .tracer .start_as_current_span (
656
- 'running tools' ,
657
- attributes = {
658
- 'tools' : [call .tool_name for call in calls_to_run ],
659
- 'logfire.msg' : f'running { len (calls_to_run )} tool{ "" if len (calls_to_run ) == 1 else "s" } ' ,
660
- },
661
- ):
662
- tasks = [
663
- asyncio .create_task (
664
- _call_function_tool (toolset , call , run_context , ctx .deps .tracer , include_content ), name = call .tool_name
665
- )
666
- for call in calls_to_run
667
- ]
668
-
669
- pending = tasks
670
- while pending :
671
- done , pending = await asyncio .wait (pending , return_when = asyncio .FIRST_COMPLETED )
672
- for task in done :
673
- index = tasks .index (task )
674
- tool_result = task .result ()
675
- yield _messages .FunctionToolResultEvent (tool_result , tool_call_id = tool_result .tool_call_id )
676
-
677
- if isinstance (tool_result , _messages .RetryPromptPart ):
678
- results_by_index [index ] = tool_result
679
- elif isinstance (tool_result , _messages .ToolReturnPart ):
680
-
681
- def process_content (content : Any ) -> Any :
682
- if isinstance (content , _messages .MultiModalContentTypes ):
683
- if isinstance (content , _messages .BinaryContent ):
684
- identifier = multi_modal_content_identifier (content .data )
654
+ # Run all tool tasks in parallel
655
+ results_by_index : dict [int , _messages .ModelRequestPart ] = {}
656
+ with ctx .deps .tracer .start_as_current_span (
657
+ 'running tools' ,
658
+ attributes = {
659
+ 'tools' : [call .tool_name for call in calls_to_run ],
660
+ 'logfire.msg' : f'running { len (calls_to_run )} tool{ "" if len (calls_to_run ) == 1 else "s" } ' ,
661
+ },
662
+ ):
663
+ tasks = [
664
+ asyncio .create_task (
665
+ _call_function_tool (toolset , call , run_context , ctx .deps .tracer , include_content ),
666
+ name = call .tool_name ,
667
+ )
668
+ for call in calls_to_run
669
+ ]
670
+
671
+ pending = tasks
672
+ while pending :
673
+ done , pending = await asyncio .wait (pending , return_when = asyncio .FIRST_COMPLETED )
674
+ for task in done :
675
+ index = tasks .index (task )
676
+ tool_result = task .result ()
677
+ yield _messages .FunctionToolResultEvent (tool_result , tool_call_id = tool_result .tool_call_id )
678
+
679
+ if isinstance (tool_result , _messages .RetryPromptPart ):
680
+ results_by_index [index ] = tool_result
681
+ elif isinstance (tool_result , _messages .ToolReturnPart ):
682
+
683
+ def process_content (content : Any ) -> Any :
684
+ if isinstance (content , _messages .MultiModalContentTypes ):
685
+ if isinstance (content , _messages .BinaryContent ):
686
+ identifier = multi_modal_content_identifier (content .data )
687
+ else :
688
+ identifier = multi_modal_content_identifier (content .url )
689
+
690
+ user_parts .append (
691
+ _messages .UserPromptPart (
692
+ content = [f'This is file { identifier } :' , content ],
693
+ timestamp = tool_result .timestamp ,
694
+ part_kind = 'user-prompt' ,
695
+ )
696
+ )
697
+ return f'See file { identifier } '
685
698
else :
686
- identifier = multi_modal_content_identifier ( content . url )
699
+ return content
687
700
688
- user_parts .append (
689
- _messages .UserPromptPart (
690
- content = [f'This is file { identifier } :' , content ],
691
- timestamp = tool_result .timestamp ,
692
- part_kind = 'user-prompt' ,
693
- )
694
- )
695
- return f'See file { identifier } '
701
+ if isinstance (tool_result .content , list ):
702
+ contents = cast (list [Any ], tool_result .content ) # type: ignore
703
+ tool_result .content = [process_content (content ) for content in contents ]
696
704
else :
697
- return content
705
+ tool_result . content = process_content ( tool_result . content )
698
706
699
- if isinstance (tool_result .content , list ):
700
- contents = cast (list [Any ], tool_result .content ) # type: ignore
701
- tool_result .content = [process_content (content ) for content in contents ]
707
+ results_by_index [index ] = tool_result
702
708
else :
703
- tool_result .content = process_content (tool_result .content )
704
-
705
- results_by_index [index ] = tool_result
706
- else :
707
- assert_never (tool_result )
709
+ assert_never (tool_result )
708
710
709
- # We append the results at the end, rather than as they are received, to retain a consistent ordering
710
- # This is mostly just to simplify testing
711
- for k in sorted (results_by_index ):
712
- parts .append (results_by_index [k ])
711
+ # We append the results at the end, rather than as they are received, to retain a consistent ordering
712
+ # This is mostly just to simplify testing
713
+ for k in sorted (results_by_index ):
714
+ parts .append (results_by_index [k ])
713
715
714
716
deferred_calls : list [_messages .ToolCallPart ] = []
715
717
for call in tool_calls_by_kind ['deferred' ]:
@@ -739,6 +741,7 @@ def process_content(content: Any) -> Any:
739
741
parts .extend (user_parts )
740
742
741
743
if final_result :
744
+ # TODO: Use some better "box" object
742
745
final_result_holder .append (final_result )
743
746
744
747
0 commit comments