Skip to content

Commit c5ef5f6

Browse files
committed
Address some feedback
1 parent 5ca305e commit c5ef5f6

25 files changed

+480
-86
lines changed

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from pydantic_ai._function_schema import _takes_ctx as is_takes_ctx # type: ignore
1818
from pydantic_ai._utils import is_async_callable, run_in_executor
1919
from pydantic_ai.toolsets import AbstractToolset
20-
from pydantic_ai.toolsets.run import RunToolset
20+
from pydantic_ai.toolsets._run import RunToolset
2121
from pydantic_graph import BaseNode, Graph, GraphRunContext
2222
from pydantic_graph.nodes import End, NodeRunEndT
2323

@@ -505,7 +505,7 @@ async def _handle_tool_calls(
505505
elif deferred_tool_calls := ctx.deps.toolset.get_deferred_tool_calls(tool_calls):
506506
if not ctx.deps.output_schema.allows_deferred_tool_calls:
507507
raise exceptions.UserError(
508-
'There are deferred tool calls but DeferredToolCalls is not among output types.'
508+
'A deferred tool call was present, but `DeferredToolCalls` is not among output types. To resolve this, add `DeferredToolCalls` to the list of output types for this agent.'
509509
)
510510
final_result = result.FinalResult(cast(NodeRunEndT, deferred_tool_calls), None, None)
511511
self._next_node = self._handle_final_result(ctx, final_result, output_parts)
@@ -586,7 +586,7 @@ async def process_function_tools( # noqa: C901
586586
587587
Also add stub return parts for any other tools that need it.
588588
589-
Because async iterators can't have return values, we use `parts` as an output argument.
589+
Because async iterators can't have return values, we use `output_parts` and `output_final_result` as output arguments.
590590
"""
591591
run_context = build_run_context(ctx)
592592

pydantic_ai_slim/pydantic_ai/_output.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
)
3030
from .tools import GenerateToolJsonSchema, ObjectJsonSchema, ToolDefinition
3131
from .toolsets import AbstractToolset
32-
from .toolsets.run import RunToolset
32+
from .toolsets._run import RunToolset
3333

3434
if TYPE_CHECKING:
3535
from .profiles import ModelProfile

pydantic_ai_slim/pydantic_ai/_run_context.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations as _annotations
22

33
import dataclasses
4+
from collections import defaultdict
45
from collections.abc import Sequence
56
from dataclasses import field
67
from typing import TYPE_CHECKING, Generic
@@ -31,8 +32,8 @@ class RunContext(Generic[AgentDepsT]):
3132
"""The original user prompt passed to the run."""
3233
messages: list[_messages.ModelMessage] = field(default_factory=list)
3334
"""Messages exchanged in the conversation so far."""
34-
retries: dict[str, int] = field(default_factory=dict)
35-
"""Number of retries for each tool."""
35+
retries: defaultdict[str, int] = field(default_factory=lambda: defaultdict(int))
36+
"""Number of retries for each tool so far."""
3637
tool_call_id: str | None = None
3738
"""The ID of the tool call."""
3839
tool_name: str | None = None

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1783,18 +1783,18 @@ def is_end_node(
17831783

17841784
@asynccontextmanager
17851785
async def run_toolsets(
1786-
self, model: models.Model | models.KnownModelName | str | None = None
1786+
self, sampling_model: models.Model | models.KnownModelName | str | None = None
17871787
) -> AsyncIterator[None]:
1788-
"""Run [`MCPServerStdio`s][pydantic_ai.mcp.MCPServerStdio] so they can be used by the agent.
1788+
"""Run [`MCPServerStdio`s][pydantic_ai.mcp.MCPServerStdio] among toolsets so they can be used by the agent.
17891789
17901790
Returns: a context manager to start and shutdown the servers.
17911791
"""
17921792
try:
1793-
sampling_model: models.Model | None = self._get_model(model)
1793+
model: models.Model | None = self._get_model(sampling_model)
17941794
except exceptions.UserError: # pragma: no cover
1795-
sampling_model = None
1796-
if sampling_model is not None: # pragma: no branch
1797-
self._toolset.set_mcp_sampling_model(sampling_model)
1795+
model = None
1796+
if model is not None: # pragma: no branch
1797+
self._toolset._set_mcp_sampling_model(model) # type: ignore[reportPrivateUsage]
17981798

17991799
async with self._toolset:
18001800
yield

pydantic_ai_slim/pydantic_ai/mcp.py

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,10 @@
1919
from pydantic_ai._run_context import RunContext
2020
from pydantic_ai.tools import ToolDefinition
2121

22-
from .exceptions import UserError
2322
from .toolsets import AbstractToolset
23+
from .toolsets._run import RunToolset
2424
from .toolsets.prefixed import PrefixedToolset
2525
from .toolsets.processed import ProcessedToolset, ToolProcessFunc
26-
from .toolsets.run import RunToolset
2726

2827
try:
2928
from mcp import types as mcp_types
@@ -104,9 +103,8 @@ async def list_tools(self) -> list[mcp_types.Tool]:
104103
- We don't cache tools as they might change.
105104
- We also don't subscribe to the server to avoid complexity.
106105
"""
107-
if not self.is_running: # pragma: no cover
108-
raise UserError(f'MCP server is not running: {self}')
109-
result = await self._client.list_tools()
106+
async with self:
107+
result = await self._client.list_tools()
110108
return result.tools
111109

112110
async def call_tool(
@@ -134,25 +132,24 @@ async def call_tool(
134132
Raises:
135133
ModelRetry: If the tool call fails.
136134
"""
137-
if not self.is_running: # pragma: no cover
138-
raise UserError(f'MCP server is not running: {self}')
139-
try:
140-
# meta param is not provided by session yet, so build and can send_request directly.
141-
result = await self._client.send_request(
142-
mcp_types.ClientRequest(
143-
mcp_types.CallToolRequest(
144-
method='tools/call',
145-
params=mcp_types.CallToolRequestParams(
146-
name=name,
147-
arguments=tool_args,
148-
_meta=mcp_types.RequestParams.Meta(**metadata) if metadata else None,
149-
),
150-
)
151-
),
152-
mcp_types.CallToolResult,
153-
)
154-
except McpError as e:
155-
raise exceptions.ModelRetry(e.error.message)
135+
async with self:
136+
try:
137+
# meta param is not provided by session yet, so build and can send_request directly.
138+
result = await self._client.send_request(
139+
mcp_types.ClientRequest(
140+
mcp_types.CallToolRequest(
141+
method='tools/call',
142+
params=mcp_types.CallToolRequestParams(
143+
name=name,
144+
arguments=tool_args,
145+
_meta=mcp_types.RequestParams.Meta(**metadata) if metadata else None,
146+
),
147+
)
148+
),
149+
mcp_types.CallToolResult,
150+
)
151+
except McpError as e:
152+
raise exceptions.ModelRetry(e.error.message)
156153

157154
content = [self._map_tool_result_part(part) for part in result.content]
158155

@@ -196,7 +193,7 @@ def _get_tool_args_validator(self, ctx: RunContext[Any], name: str) -> pydantic_
196193
def _max_retries_for_tool(self, name: str) -> int:
197194
return self.max_retries
198195

199-
def set_mcp_sampling_model(self, model: models.Model) -> None:
196+
def _set_mcp_sampling_model(self, model: models.Model) -> None:
200197
self.sampling_model = model
201198

202199
async def __aenter__(self) -> Self:

pydantic_ai_slim/pydantic_ai/result.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from pydantic import ValidationError
1111
from typing_extensions import TypeVar, deprecated, overload
1212

13-
from pydantic_ai.toolsets.run import RunToolset
13+
from pydantic_ai.toolsets._run import RunToolset
1414

1515
from . import _utils, exceptions, messages as _messages, models
1616
from ._output import (

pydantic_ai_slim/pydantic_ai/tools.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -366,8 +366,8 @@ class ToolDefinition:
366366
"""The kind of tool:
367367
- `'function'`: a tool that can be executed by Pydantic AI and has its result returned to the model
368368
- `'output'`: a tool that passes through an output value that ends the run
369-
- `'deferred'`: a tool that cannot be executed by Pydantic AI and needs to get a result from the outside.
370-
When the model calls a deferred tool, the agent run ends with a `DeferredToolCalls` object and a new run is expected to be started at a later point with the message history and new `ToolReturnPart`s for each deferred call.
369+
- `'deferred'`: a tool that will be executed not by Pydantic AI, but by the upstream service that called the agent, such as a web application that supports frontend-defined tools provided to Pydantic AI via e.g. [AG-UI](https://docs.ag-ui.com/concepts/tools#frontend-defined-tools).
370+
When the model calls a deferred tool, the agent run ends with a `DeferredToolCalls` object and a new run is expected to be started at a later point with the message history and new `ToolReturnPart`s corresponding to each deferred call.
371371
"""
372372

373373
__repr__ = _utils.dataclasses_no_defaults_repr

pydantic_ai_slim/pydantic_ai/toolsets/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
if TYPE_CHECKING:
1414
from ..models import Model
15-
from .run import RunToolset
15+
from ._run import RunToolset
1616

1717

1818
class AbstractToolset(ABC, Generic[AgentDepsT]):
@@ -80,5 +80,5 @@ async def call_tool(
8080
) -> Any:
8181
raise NotImplementedError()
8282

83-
def set_mcp_sampling_model(self, model: Model) -> None:
83+
def _set_mcp_sampling_model(self, model: Model) -> None:
8484
pass

pydantic_ai_slim/pydantic_ai/toolsets/individually_prepared.py renamed to pydantic_ai_slim/pydantic_ai/toolsets/_individually_prepared.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
ToolDefinition,
99
ToolPrepareFunc,
1010
)
11-
from .mapped import MappedToolset
12-
from .run import RunToolset
11+
from ._mapped import MappedToolset
12+
from ._run import RunToolset
1313
from .wrapper import WrapperToolset
1414

1515

pydantic_ai_slim/pydantic_ai/toolsets/mapped.py renamed to pydantic_ai_slim/pydantic_ai/toolsets/_mapped.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,13 @@
1010
ToolDefinition,
1111
)
1212
from . import AbstractToolset
13-
from .run import RunToolset
13+
from ._run import RunToolset
1414
from .wrapper import WrapperToolset
1515

1616

1717
@dataclass(init=False)
1818
class MappedToolset(WrapperToolset[AgentDepsT]):
19-
"""A toolset that maps the names of the tools it contains."""
19+
"""A toolset that maps renamed tool names to original tool names. Used by `IndividuallyPreparedToolset` as the prepare function may rename a tool."""
2020

2121
name_map: dict[str, str]
2222
_tool_defs: list[ToolDefinition]

0 commit comments

Comments
 (0)