Skip to content

Commit a0c3abb

Browse files
Update cohere and MCP, add support for MCP ResourceLink returned from tools (#2094)
Co-authored-by: Douwe Maan <douwe@pydantic.dev>
1 parent e295e5e commit a0c3abb

File tree

11 files changed

+1786
-44
lines changed

11 files changed

+1786
-44
lines changed

docs/mcp/server.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,10 @@ async def sampling_callback(
117117
SamplingMessage(
118118
role='user',
119119
content=TextContent(
120-
type='text', text='write a poem about socks', annotations=None
120+
type='text',
121+
text='write a poem about socks',
122+
annotations=None,
123+
meta=None,
121124
),
122125
)
123126
]

pydantic_ai_slim/pydantic_ai/mcp.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ async def direct_call_tool(
151151
except McpError as e:
152152
raise exceptions.ModelRetry(e.error.message)
153153

154-
content = [self._map_tool_result_part(part) for part in result.content]
154+
content = [await self._map_tool_result_part(part) for part in result.content]
155155

156156
if result.isError:
157157
text = '\n'.join(str(part) for part in content)
@@ -262,8 +262,8 @@ async def _sampling_callback(
262262
model=self.sampling_model.model_name,
263263
)
264264

265-
def _map_tool_result_part(
266-
self, part: mcp_types.Content
265+
async def _map_tool_result_part(
266+
self, part: mcp_types.ContentBlock
267267
) -> str | messages.BinaryContent | dict[str, Any] | list[Any]:
268268
# See https://github.com/jlowin/fastmcp/blob/main/docs/servers/tools.mdx#return-values
269269

@@ -285,18 +285,29 @@ def _map_tool_result_part(
285285
) # pragma: no cover
286286
elif isinstance(part, mcp_types.EmbeddedResource):
287287
resource = part.resource
288-
if isinstance(resource, mcp_types.TextResourceContents):
289-
return resource.text
290-
elif isinstance(resource, mcp_types.BlobResourceContents):
291-
return messages.BinaryContent(
292-
data=base64.b64decode(resource.blob),
293-
media_type=resource.mimeType or 'application/octet-stream',
294-
)
295-
else:
296-
assert_never(resource)
288+
return self._get_content(resource)
289+
elif isinstance(part, mcp_types.ResourceLink):
290+
resource_result: mcp_types.ReadResourceResult = await self._client.read_resource(part.uri)
291+
return (
292+
self._get_content(resource_result.contents[0])
293+
if len(resource_result.contents) == 1
294+
else [self._get_content(resource) for resource in resource_result.contents]
295+
)
297296
else:
298297
assert_never(part)
299298

299+
def _get_content(
300+
self, resource: mcp_types.TextResourceContents | mcp_types.BlobResourceContents
301+
) -> str | messages.BinaryContent:
302+
if isinstance(resource, mcp_types.TextResourceContents):
303+
return resource.text
304+
elif isinstance(resource, mcp_types.BlobResourceContents):
305+
return messages.BinaryContent(
306+
data=base64.b64decode(resource.blob), media_type=resource.mimeType or 'application/octet-stream'
307+
)
308+
else:
309+
assert_never(resource)
310+
300311

301312
@dataclass
302313
class MCPServerStdio(MCPServer):

pydantic_ai_slim/pydantic_ai/models/cohere.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,15 @@
3838
AssistantChatMessageV2,
3939
AsyncClientV2,
4040
ChatMessageV2,
41-
ChatResponse,
4241
SystemChatMessageV2,
43-
TextAssistantMessageContentItem,
42+
TextAssistantMessageV2ContentItem,
4443
ToolCallV2,
4544
ToolCallV2Function,
4645
ToolChatMessageV2,
4746
ToolV2,
4847
ToolV2Function,
4948
UserChatMessageV2,
49+
V2ChatResponse,
5050
)
5151
from cohere.core.api_error import ApiError
5252
from cohere.v2.client import OMIT
@@ -164,7 +164,7 @@ async def _chat(
164164
messages: list[ModelMessage],
165165
model_settings: CohereModelSettings,
166166
model_request_parameters: ModelRequestParameters,
167-
) -> ChatResponse:
167+
) -> V2ChatResponse:
168168
tools = self._get_tools(model_request_parameters)
169169
cohere_messages = self._map_messages(messages)
170170
try:
@@ -185,7 +185,7 @@ async def _chat(
185185
raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
186186
raise # pragma: no cover
187187

188-
def _process_response(self, response: ChatResponse) -> ModelResponse:
188+
def _process_response(self, response: V2ChatResponse) -> ModelResponse:
189189
"""Process a non-streamed response, and prepare a message to return."""
190190
parts: list[ModelResponsePart] = []
191191
if response.message.content is not None and len(response.message.content) > 0:
@@ -227,7 +227,7 @@ def _map_messages(self, messages: list[ModelMessage]) -> list[ChatMessageV2]:
227227
assert_never(item)
228228
message_param = AssistantChatMessageV2(role='assistant')
229229
if texts:
230-
message_param.content = [TextAssistantMessageContentItem(text='\n\n'.join(texts))]
230+
message_param.content = [TextAssistantMessageV2ContentItem(text='\n\n'.join(texts))]
231231
if tool_calls:
232232
message_param.tool_calls = tool_calls
233233
cohere_messages.append(message_param)
@@ -294,7 +294,7 @@ def _map_user_message(cls, message: ModelRequest) -> Iterable[ChatMessageV2]:
294294
assert_never(part)
295295

296296

297-
def _map_usage(response: ChatResponse) -> usage.Usage:
297+
def _map_usage(response: V2ChatResponse) -> usage.Usage:
298298
u = response.usage
299299
if u is None:
300300
return usage.Usage()

pydantic_ai_slim/pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ dependencies = [
6363
logfire = ["logfire>=3.11.0"]
6464
# Models
6565
openai = ["openai>=1.92.0"]
66-
cohere = ["cohere>=5.13.11; platform_system != 'Emscripten'"]
66+
cohere = ["cohere>=5.16.0; platform_system != 'Emscripten'"]
6767
vertexai = ["google-auth>=2.36.0", "requests>=2.32.2"]
6868
google = ["google-genai>=1.24.0"]
6969
anthropic = ["anthropic>=0.52.0"]
@@ -77,7 +77,7 @@ tavily = ["tavily-python>=0.5.0"]
7777
# CLI
7878
cli = ["rich>=13", "prompt-toolkit>=3", "argcomplete>=3.5.0"]
7979
# MCP
80-
mcp = ["mcp>=1.9.4; python_version >= '3.10'"]
80+
mcp = ["mcp>=1.10.0; python_version >= '3.10'"]
8181
# Evals
8282
evals = ["pydantic-evals=={{ version }}"]
8383
# A2A

tests/assets/product_name.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Pydantic AI

tests/cassettes/test_mcp/test_tool_returning_audio_resource_link.yaml

Lines changed: 321 additions & 0 deletions
Large diffs are not rendered by default.

tests/cassettes/test_mcp/test_tool_returning_image_resource_link.yaml

Lines changed: 447 additions & 0 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)