Skip to content

Commit e3dda9d

Browse files
committed
Add more work on it
1 parent 97ab44b commit e3dda9d

16 files changed

+1111
-15
lines changed

pydantic_ai_slim/pydantic_ai/builtin_tools.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ class AbstractBuiltinTool(ABC):
1414
"""A builtin tool that can be used by an agent.
1515
1616
This class is abstract and cannot be instantiated directly.
17+
18+
The builtin tools are are passed to the model as part of the `ModelRequestParameters`.
1719
"""
1820

1921

@@ -73,3 +75,10 @@ class WebSearchTool(AbstractBuiltinTool):
7375
* Anthropic (https://docs.anthropic.com/en/docs/build-with-claude/tool-use/web-search-tool#domain-filtering)
7476
* Groq (https://console.groq.com/docs/agentic-tooling#search-settings)
7577
"""
78+
79+
max_uses: int | None = None
80+
"""If provided, the tool will stop searching the web after the given number of uses.
81+
82+
Supported by:
83+
* Anthropic
84+
"""

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,12 @@
88
from json import JSONDecodeError, loads as json_loads
99
from typing import Any, Literal, Union, cast, overload
1010

11+
from anthropic.types import ServerToolUseBlock, ToolUnionParam, WebSearchTool20250305Param, WebSearchToolResultBlock
12+
from anthropic.types.web_search_tool_20250305_param import UserLocation
1113
from typing_extensions import assert_never
1214

15+
from pydantic_ai.builtin_tools import WebSearchTool
16+
1317
from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
1418
from .._utils import guard_tool_call_id as _guard_tool_call_id
1519
from ..messages import (
@@ -207,6 +211,7 @@ async def _messages_create(
207211
) -> AnthropicMessage | AsyncStream[RawMessageStreamEvent]:
208212
# standalone function to make it easier to override
209213
tools = self._get_tools(model_request_parameters)
214+
tools += self._get_builtin_tools(model_request_parameters)
210215
tool_choice: ToolChoiceParam | None
211216

212217
if not tools:
@@ -252,8 +257,11 @@ def _process_response(self, response: AnthropicMessage) -> ModelResponse:
252257
for item in response.content:
253258
if isinstance(item, TextBlock):
254259
items.append(TextPart(content=item.text))
260+
if isinstance(item, WebSearchToolResultBlock):
261+
# TODO(Marcelo): We should send something back to the user, because we need to send it back on the next request.
262+
...
255263
else:
256-
assert isinstance(item, ToolUseBlock), 'unexpected item type'
264+
assert isinstance(item, (ToolUseBlock, ServerToolUseBlock)), f'unexpected item type {type(item)}'
257265
items.append(
258266
ToolCallPart(
259267
tool_name=item.name,
@@ -282,6 +290,22 @@ def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[T
282290
tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
283291
return tools
284292

293+
def _get_builtin_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolUnionParam]:
294+
tools: list[ToolUnionParam] = []
295+
for tool in model_request_parameters.builtin_tools:
296+
if isinstance(tool, WebSearchTool):
297+
user_location = UserLocation(type='approximate', **tool.user_location) if tool.user_location else None
298+
tools.append(
299+
WebSearchTool20250305Param(
300+
name='web_search',
301+
type='web_search_20250305',
302+
allowed_domains=tool.allowed_domains,
303+
blocked_domains=tool.blocked_domains,
304+
user_location=user_location,
305+
)
306+
)
307+
return tools
308+
285309
async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[MessageParam]]:
286310
"""Just maps a `pydantic_ai.Message` to a `anthropic.types.MessageParam`."""
287311
system_prompt_parts: list[str] = []

pydantic_ai_slim/pydantic_ai/models/gemini.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ def _customize_tool_def(t: ToolDefinition):
173173

174174
return ModelRequestParameters(
175175
function_tools=[_customize_tool_def(tool) for tool in model_request_parameters.function_tools],
176+
builtin_tools=model_request_parameters.builtin_tools,
176177
allow_text_output=model_request_parameters.allow_text_output,
177178
output_tools=[_customize_tool_def(tool) for tool in model_request_parameters.output_tools],
178179
)

pydantic_ai_slim/pydantic_ai/models/groq.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@ def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
234234
timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
235235
choice = response.choices[0]
236236
items: list[ModelResponsePart] = []
237+
# TODO(Marcelo): The `choice.message.executed_tools` has the server tools executed by Groq.
237238
if choice.message.content is not None:
238239
items.append(TextPart(content=choice.message.content))
239240
if choice.message.tool_calls is not None:

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,6 @@ async def _completions_create(
278278
return await self.client.chat.completions.create(
279279
model=self._model_name,
280280
messages=openai_messages,
281-
n=1,
282281
parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
283282
tools=tools or NOT_GIVEN,
284283
tool_choice=tool_choice or NOT_GIVEN,
@@ -338,13 +337,15 @@ def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[c
338337
def _get_web_search_options(self, model_request_parameters: ModelRequestParameters) -> WebSearchOptions | None:
339338
for tool in model_request_parameters.builtin_tools:
340339
if isinstance(tool, WebSearchTool):
341-
return WebSearchOptions(
342-
search_context_size=tool.search_context_size,
343-
user_location=WebSearchOptionsUserLocation(
344-
type='approximate',
345-
approximate=WebSearchOptionsUserLocationApproximate(**tool.user_location),
346-
),
347-
)
340+
if tool.user_location:
341+
return WebSearchOptions(
342+
search_context_size=tool.search_context_size,
343+
user_location=WebSearchOptionsUserLocation(
344+
type='approximate',
345+
approximate=WebSearchOptionsUserLocationApproximate(**tool.user_location),
346+
),
347+
)
348+
return WebSearchOptions(search_context_size=tool.search_context_size)
348349

349350
async def _map_messages(self, messages: list[ModelMessage]) -> list[chat.ChatCompletionMessageParam]:
350351
"""Just maps a `pydantic_ai.Message` to a `openai.types.ChatCompletionMessageParam`."""
@@ -677,13 +678,14 @@ def _get_builtin_tools(self, model_request_parameters: ModelRequestParameters) -
677678
tools: list[responses.ToolParam] = []
678679
for tool in model_request_parameters.builtin_tools:
679680
if isinstance(tool, WebSearchTool):
680-
tools.append(
681-
responses.WebSearchToolParam(
682-
type='web_search_preview',
683-
search_context_size=tool.search_context_size,
684-
user_location={'type': 'approximate', **tool.user_location},
685-
)
681+
web_search_tool = responses.WebSearchToolParam(
682+
type='web_search_preview', search_context_size=tool.search_context_size
686683
)
684+
if tool.user_location:
685+
web_search_tool['user_location'] = responses.web_search_tool_param.UserLocation(
686+
type='approximate', **tool.user_location
687+
)
688+
tools.append(web_search_tool)
687689
return tools
688690

689691
@staticmethod
@@ -1112,6 +1114,7 @@ def _customize_tool_def(t: ToolDefinition):
11121114

11131115
return ModelRequestParameters(
11141116
function_tools=[_customize_tool_def(tool) for tool in model_request_parameters.function_tools],
1117+
builtin_tools=model_request_parameters.builtin_tools,
11151118
allow_text_output=model_request_parameters.allow_text_output,
11161119
output_tools=[_customize_tool_def(tool) for tool in model_request_parameters.output_tools],
11171120
)

0 commit comments

Comments
 (0)