Skip to content

Commit 64dacbb

Browse files
committed
Merge branch 'main' into toolsets
# Conflicts: # pydantic_ai_slim/pydantic_ai/tools.py
2 parents f660cc1 + 5b94841 commit 64dacbb

File tree

2 files changed

+67
-16
lines changed

2 files changed

+67
-16
lines changed

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -737,22 +737,40 @@ async def _call_function_tool(
737737
{
738738
'type': 'object',
739739
'properties': {
740-
**({'tool_arguments': {'type': 'object'}} if include_content else {}),
740+
**(
741+
{
742+
'tool_arguments': {'type': 'object'},
743+
'tool_response': {'type': 'object'},
744+
}
745+
if include_content
746+
else {}
747+
),
741748
'gen_ai.tool.name': {},
742749
'gen_ai.tool.call.id': {},
743750
},
744751
}
745752
),
746753
}
747754

748-
with tracer.start_as_current_span('running tool', attributes=span_attributes):
755+
with tracer.start_as_current_span('running tool', attributes=span_attributes) as span:
749756
try:
750757
tool_result = await _call_tool(toolset, tool_call, run_context)
751758
except ToolRetryError as e:
759+
part = e.tool_retry
760+
if include_content and span.is_recording():
761+
span.set_attribute('tool_response', part.model_response())
752762
return (e.tool_retry, [])
753763

764+
part = _messages.ToolReturnPart(
765+
tool_name=tool_call.tool_name,
766+
content=tool_result,
767+
tool_call_id=tool_call.tool_call_id,
768+
)
769+
770+
if include_content and span.is_recording():
771+
span.set_attribute('tool_response', part.model_response_str())
772+
754773
extra_parts: list[_messages.ModelRequestPart] = []
755-
metadata = None
756774

757775
def process_content(content: Any) -> Any:
758776
if isinstance(content, _messages.ToolReturn):
@@ -790,27 +808,20 @@ def process_content(content: Any) -> Any:
790808
f'Please use `content` instead.'
791809
)
792810

793-
metadata = tool_result.metadata
811+
part.content = tool_result.return_value # type: ignore
812+
part.metadata = tool_result.metadata
794813
if tool_result.content:
795814
extra_parts.append(
796815
_messages.UserPromptPart(
797816
content=list(tool_result.content),
798817
part_kind='user-prompt',
799818
)
800819
)
801-
tool_result = tool_result.return_value # type: ignore
802820
elif isinstance(tool_result, list):
803821
contents = cast(list[Any], tool_result)
804-
tool_result = [process_content(content) for content in contents]
822+
part.content = [process_content(content) for content in contents]
805823
else:
806-
tool_result = process_content(tool_result)
807-
808-
part = _messages.ToolReturnPart(
809-
tool_name=tool_call.tool_name,
810-
content=tool_result,
811-
metadata=metadata,
812-
tool_call_id=tool_call.tool_call_id,
813-
)
824+
part.content = process_content(tool_result)
814825

815826
return (part, extra_parts)
816827

tests/test_logfire.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -555,6 +555,46 @@ async def add_numbers(x: int, y: int) -> int:
555555
]
556556

557557
if include_content:
558-
assert tool_attributes['tool_arguments'] == snapshot('{"x":42,"y":42}')
558+
assert tool_attributes == snapshot(
559+
{
560+
'gen_ai.tool.name': 'add_numbers',
561+
'gen_ai.tool.call.id': IsStr(),
562+
'tool_arguments': '{"x":42,"y":42}',
563+
'tool_response': '84',
564+
'logfire.msg': 'running tool: add_numbers',
565+
'logfire.json_schema': IsJson(
566+
snapshot(
567+
{
568+
'type': 'object',
569+
'properties': {
570+
'tool_arguments': {'type': 'object'},
571+
'tool_response': {'type': 'object'},
572+
'gen_ai.tool.name': {},
573+
'gen_ai.tool.call.id': {},
574+
},
575+
}
576+
)
577+
),
578+
'logfire.span_type': 'span',
579+
}
580+
)
559581
else:
560-
assert 'tool_arguments' not in tool_attributes
582+
assert tool_attributes == snapshot(
583+
{
584+
'gen_ai.tool.name': 'add_numbers',
585+
'gen_ai.tool.call.id': IsStr(),
586+
'logfire.msg': 'running tool: add_numbers',
587+
'logfire.json_schema': IsJson(
588+
snapshot(
589+
{
590+
'type': 'object',
591+
'properties': {
592+
'gen_ai.tool.name': {},
593+
'gen_ai.tool.call.id': {},
594+
},
595+
}
596+
)
597+
),
598+
'logfire.span_type': 'span',
599+
}
600+
)

0 commit comments

Comments
 (0)