Skip to content

Commit 1e1ec05

Browse files
committed
Improve test coverage
1 parent ecc1394 commit 1e1ec05

File tree

8 files changed

+685
-53
lines changed

8 files changed

+685
-53
lines changed

pydantic_ai_slim/pydantic_ai/_agent_graph.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import asyncio
44
import dataclasses
5+
import hashlib
56
from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
67
from contextlib import asynccontextmanager, contextmanager
78
from contextvars import ContextVar
@@ -546,6 +547,13 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT
546547
)
547548

548549

550+
def multi_modal_content_identifier(identifier: str | bytes) -> str:
551+
"""Generate stable identifier for multi-modal content to help LLM in finding a specific file in tool call responses."""
552+
if isinstance(identifier, str):
553+
identifier = identifier.encode('utf-8')
554+
return hashlib.sha1(identifier).hexdigest()[:6]
555+
556+
549557
async def process_function_tools( # noqa C901
550558
tool_calls: list[_messages.ToolCallPart],
551559
output_tool_name: str | None,
@@ -671,7 +679,11 @@ async def process_function_tools( # noqa C901
671679
processed_contents: list[Any] = []
672680
for content in contents:
673681
if isinstance(content, _messages.MultiModalContentTypes):
674-
identifier = content.identifier
682+
if isinstance(content, _messages.BinaryContent):
683+
identifier = multi_modal_content_identifier(content.data)
684+
else:
685+
identifier = multi_modal_content_identifier(content.url)
686+
675687
user_parts.append(
676688
_messages.UserPromptPart(
677689
content=[f'This is file {identifier}:', content],

pydantic_ai_slim/pydantic_ai/mcp.py

+14-18
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,9 @@ async def call_tool(
9696
9797
Returns:
9898
The result of the tool call.
99+
100+
Raises:
101+
ModelRetry: If the tool call fails.
99102
"""
100103
result = await self._client.call_tool(tool_name, arguments)
101104

@@ -131,27 +134,20 @@ async def call_tool(
131134
else:
132135
assert_never(part)
133136

134-
if result.isError:
135-
raise ModelRetry('\n'.join(text_parts) or 'Unknown error')
136-
137-
if text_parts and not binary_parts and not json_parts:
138-
return '\n'.join(text_parts)
137+
text = '\n'.join(text_parts)
139138

140-
if json_parts and not text_parts and not binary_parts:
141-
if len(json_parts) == 1:
142-
return json_parts[0]
143-
return json_parts
139+
if result.isError:
140+
raise ModelRetry(text or 'Unknown error')
144141

145-
if binary_parts and not text_parts and not json_parts:
146-
if len(binary_parts) == 1:
147-
return binary_parts[0]
148-
return binary_parts
142+
parts: list[str | BinaryContent | dict[str, Any] | list[Any]] = []
143+
if text:
144+
parts.append(text)
145+
parts.extend(json_parts)
146+
parts.extend(binary_parts)
149147

150-
return [
151-
*text_parts,
152-
*json_parts,
153-
*binary_parts,
154-
]
148+
if len(parts) == 1:
149+
return parts[0]
150+
return parts
155151

156152
async def __aenter__(self) -> Self:
157153
self._exit_stack = AsyncExitStack()

pydantic_ai_slim/pydantic_ai/messages.py

-27
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations as _annotations
22

3-
import hashlib
43
import uuid
54
from collections.abc import Sequence
65
from dataclasses import dataclass, field, replace
@@ -45,12 +44,6 @@
4544
VideoFormat: TypeAlias = Literal['mkv', 'mov', 'mp4', 'webm', 'flv', 'mpeg', 'mpg', 'wmv', 'three_gp']
4645

4746

48-
def _multi_modal_content_identifier(identifier: str | bytes) -> str:
49-
if isinstance(identifier, str):
50-
identifier = identifier.encode('utf-8')
51-
return hashlib.sha1(identifier).hexdigest()[:6]
52-
53-
5447
@dataclass
5548
class SystemPromptPart:
5649
"""A system prompt, generally written by the application developer.
@@ -87,10 +80,6 @@ class VideoUrl:
8780
kind: Literal['video-url'] = 'video-url'
8881
"""Type identifier, this is available on all parts as a discriminator."""
8982

90-
@property
91-
def identifier(self) -> str:
92-
return _multi_modal_content_identifier(self.url)
93-
9483
@property
9584
def media_type(self) -> VideoMediaType: # pragma: no cover
9685
"""Return the media type of the video, based on the url."""
@@ -132,10 +121,6 @@ class AudioUrl:
132121
kind: Literal['audio-url'] = 'audio-url'
133122
"""Type identifier, this is available on all parts as a discriminator."""
134123

135-
@property
136-
def identifier(self) -> str:
137-
return _multi_modal_content_identifier(self.url)
138-
139124
@property
140125
def media_type(self) -> AudioMediaType:
141126
"""Return the media type of the audio file, based on the url."""
@@ -157,10 +142,6 @@ class ImageUrl:
157142
kind: Literal['image-url'] = 'image-url'
158143
"""Type identifier, this is available on all parts as a discriminator."""
159144

160-
@property
161-
def identifier(self) -> str:
162-
return _multi_modal_content_identifier(self.url)
163-
164145
@property
165146
def media_type(self) -> ImageMediaType:
166147
"""Return the media type of the image, based on the url."""
@@ -194,10 +175,6 @@ class DocumentUrl:
194175
kind: Literal['document-url'] = 'document-url'
195176
"""Type identifier, this is available on all parts as a discriminator."""
196177

197-
@property
198-
def identifier(self) -> str:
199-
return _multi_modal_content_identifier(self.url)
200-
201178
@property
202179
def media_type(self) -> str:
203180
"""Return the media type of the document, based on the url."""
@@ -228,10 +205,6 @@ class BinaryContent:
228205
kind: Literal['binary'] = 'binary'
229206
"""Type identifier, this is available on all parts as a discriminator."""
230207

231-
@property
232-
def identifier(self) -> str:
233-
return _multi_modal_content_identifier(self.data)
234-
235208
@property
236209
def is_audio(self) -> bool:
237210
"""Return `True` if the media type is an audio type."""

0 commit comments

Comments
 (0)