Skip to content

Commit 57e568b

Browse files
committed
Add builtin_tools to Agent
1 parent d1615df commit 57e568b

File tree

5 files changed

+128
-0
lines changed

5 files changed

+128
-0
lines changed

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from opentelemetry.trace import Tracer
1313
from typing_extensions import TypeGuard, TypeVar, assert_never
1414

15+
from pydantic_ai.builtin_tools import AbstractBuiltinTool
1516
from pydantic_graph import BaseNode, Graph, GraphRunContext
1617
from pydantic_graph.nodes import End, NodeRunEndT
1718

@@ -92,6 +93,7 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]):
9293
output_validators: list[_output.OutputValidator[DepsT, OutputDataT]]
9394

9495
function_tools: dict[str, Tool[DepsT]] = dataclasses.field(repr=False)
96+
builtin_tools: list[AbstractBuiltinTool] = dataclasses.field(repr=False)
9597
mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False)
9698
default_retries: int
9799

@@ -242,6 +244,7 @@ async def add_mcp_server_tools(server: MCPServer) -> None:
242244
output_schema = ctx.deps.output_schema
243245
return models.ModelRequestParameters(
244246
function_tools=function_tool_defs,
247+
builtin_tools=ctx.deps.builtin_tools,
245248
allow_text_output=allow_text_output(output_schema),
246249
output_tools=output_schema.tool_defs() if output_schema is not None else [],
247250
)

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from pydantic.json_schema import GenerateJsonSchema
1515
from typing_extensions import Literal, Never, Self, TypeIs, TypeVar, deprecated
1616

17+
from pydantic_ai.builtin_tools import AbstractBuiltinTool, WebSearchTool
1718
from pydantic_graph import End, Graph, GraphRun, GraphRunContext
1819
from pydantic_graph._utils import get_event_loop
1920

@@ -172,6 +173,7 @@ def __init__(
172173
retries: int = 1,
173174
output_retries: int | None = None,
174175
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
176+
builtin_tools: Sequence[Literal['web-search'] | AbstractBuiltinTool] = (),
175177
mcp_servers: Sequence[MCPServer] = (),
176178
defer_model_check: bool = False,
177179
end_strategy: EndStrategy = 'early',
@@ -200,6 +202,7 @@ def __init__(
200202
result_tool_description: str | None = None,
201203
result_retries: int | None = None,
202204
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
205+
builtin_tools: Sequence[Literal['web-search'] | AbstractBuiltinTool] = (),
203206
mcp_servers: Sequence[MCPServer] = (),
204207
defer_model_check: bool = False,
205208
end_strategy: EndStrategy = 'early',
@@ -223,6 +226,7 @@ def __init__(
223226
retries: int = 1,
224227
output_retries: int | None = None,
225228
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
229+
builtin_tools: Sequence[Literal['web-search'] | AbstractBuiltinTool] = (),
226230
mcp_servers: Sequence[MCPServer] = (),
227231
defer_model_check: bool = False,
228232
end_strategy: EndStrategy = 'early',
@@ -251,6 +255,8 @@ def __init__(
251255
output_retries: The maximum number of retries to allow for result validation, defaults to `retries`.
252256
tools: Tools to register with the agent, you can also register tools via the decorators
253257
[`@agent.tool`][pydantic_ai.Agent.tool] and [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain].
258+
builtin_tools: The builtin tools that the agent will use. This depends on the model, as some models may not
259+
support certain tools. On models that don't support certain tools, the tool will be ignored.
254260
mcp_servers: MCP servers to register with the agent. You should register a [`MCPServer`][pydantic_ai.mcp.MCPServer]
255261
for each server you want the agent to connect to.
256262
defer_model_check: by default, if you provide a [named][pydantic_ai.models.KnownModelName] model,
@@ -334,6 +340,14 @@ def __init__(
334340
self._default_retries = retries
335341
self._max_result_retries = output_retries if output_retries is not None else retries
336342
self._mcp_servers = mcp_servers
343+
self._builtin_tools: list[AbstractBuiltinTool] = []
344+
345+
for tool in builtin_tools:
346+
if tool == 'web-search':
347+
self._builtin_tools.append(WebSearchTool())
348+
else:
349+
self._builtin_tools.append(tool)
350+
337351
for tool in tools:
338352
if isinstance(tool, Tool):
339353
self._register_tool(tool)
@@ -688,6 +702,7 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
688702
output_schema=output_schema,
689703
output_validators=output_validators,
690704
function_tools=self._function_tools,
705+
builtin_tools=self._builtin_tools,
691706
mcp_servers=self._mcp_servers,
692707
default_retries=self._default_retries,
693708
tracer=tracer,
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
from __future__ import annotations as _annotations
2+
3+
from abc import ABC
4+
from dataclasses import dataclass, field
5+
from typing import Literal
6+
7+
from typing_extensions import TypedDict
8+
9+
__all__ = ('AbstractBuiltinTool', 'WebSearchTool', 'UserLocation')
10+
11+
12+
class AbstractBuiltinTool(ABC):
13+
"""A builtin tool that can be used by an agent.
14+
15+
This class is abstract and cannot be instantiated directly.
16+
"""
17+
18+
19+
class UserLocation(TypedDict, total=False):
20+
"""Allows you to localize search results based on a user's location.
21+
22+
Supported by:
23+
* Anthropic
24+
* OpenAI
25+
"""
26+
27+
city: str
28+
country: str
29+
region: str
30+
timezone: str
31+
32+
33+
@dataclass
34+
class WebSearchTool(AbstractBuiltinTool):
35+
"""A builtin tool that allows your agent to search the web for information.
36+
37+
The parameters that PydanticAI passes depend on the model, as some parameters may not be supported by certain models.
38+
"""
39+
40+
search_context_size: Literal['low', 'medium', 'high'] = 'medium'
41+
"""The `search_context_size` parameter controls how much context is retrieved from the web to help the tool formulate a response.
42+
43+
Supported by:
44+
* OpenAI
45+
"""
46+
47+
user_location: UserLocation = field(default_factory=UserLocation)
48+
"""The `user_location` parameter allows you to localize search results based on a user's location.
49+
50+
Supported by:
51+
* Anthropic
52+
* OpenAI
53+
"""
54+
55+
blocked_domains: list[str] | None = None
56+
"""If provided, these domains will never appear in results.
57+
58+
With Anthropic, you can only use one of `blocked_domains` or `allowed_domains`, not both.
59+
60+
Supported by:
61+
* Anthropic (https://docs.anthropic.com/en/docs/build-with-claude/tool-use/web-search-tool#domain-filtering)
62+
* Groq (https://console.groq.com/docs/agentic-tooling#search-settings)
63+
* MistralAI
64+
"""
65+
66+
allowed_domains: list[str] | None = None
67+
"""If provided, only these domains will be included in results.
68+
69+
With Anthropic, you can only use one of `blocked_domains` or `allowed_domains`, not both.
70+
71+
Supported by:
72+
* Anthropic (https://docs.anthropic.com/en/docs/build-with-claude/tool-use/web-search-tool#domain-filtering)
73+
* Groq (https://console.groq.com/docs/agentic-tooling#search-settings)
74+
"""

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
import httpx
1717
from typing_extensions import Literal, TypeAliasType
1818

19+
from pydantic_ai.builtin_tools import AbstractBuiltinTool
20+
1921
from .._parts_manager import ModelResponsePartsManager
2022
from ..exceptions import UserError
2123
from ..messages import ModelMessage, ModelRequest, ModelResponse, ModelResponseStreamEvent
@@ -261,6 +263,7 @@ class ModelRequestParameters:
261263
"""Configuration for an agent's request to a model, specifically related to tools and output handling."""
262264

263265
function_tools: list[ToolDefinition] = field(default_factory=list)
266+
builtin_tools: list[AbstractBuiltinTool] = field(default_factory=list)
264267
allow_text_output: bool = True
265268
output_tools: list[ToolDefinition] = field(default_factory=list)
266269

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from typing_extensions import assert_never
1313

14+
from pydantic_ai.builtin_tools import WebSearchTool
1415
from pydantic_ai.providers import Provider, infer_provider
1516

1617
from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
@@ -58,6 +59,11 @@
5859
from openai.types.chat.chat_completion_content_part_image_param import ImageURL
5960
from openai.types.chat.chat_completion_content_part_input_audio_param import InputAudio
6061
from openai.types.chat.chat_completion_content_part_param import File, FileFile
62+
from openai.types.chat.completion_create_params import (
63+
WebSearchOptions,
64+
WebSearchOptionsUserLocation,
65+
WebSearchOptionsUserLocationApproximate,
66+
)
6167
from openai.types.responses import ComputerToolParam, FileSearchToolParam, WebSearchToolParam
6268
from openai.types.responses.response_input_param import FunctionCallOutput, Message
6369
from openai.types.shared import ReasoningEffort
@@ -254,6 +260,7 @@ async def _completions_create(
254260
model_request_parameters: ModelRequestParameters,
255261
) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
256262
tools = self._get_tools(model_request_parameters)
263+
web_search_options = self._get_web_search_options(model_request_parameters)
257264

258265
# standalone function to make it easier to override
259266
if not tools:
@@ -288,6 +295,7 @@ async def _completions_create(
288295
logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
289296
reasoning_effort=model_settings.get('openai_reasoning_effort', NOT_GIVEN),
290297
user=model_settings.get('openai_user', NOT_GIVEN),
298+
web_search_options=web_search_options or NOT_GIVEN,
291299
extra_headers=extra_headers,
292300
extra_body=model_settings.get('extra_body'),
293301
)
@@ -327,6 +335,17 @@ def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[c
327335
tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
328336
return tools
329337

338+
def _get_web_search_options(self, model_request_parameters: ModelRequestParameters) -> WebSearchOptions | None:
339+
for tool in model_request_parameters.builtin_tools:
340+
if isinstance(tool, WebSearchTool):
341+
return WebSearchOptions(
342+
search_context_size=tool.search_context_size,
343+
user_location=WebSearchOptionsUserLocation(
344+
type='approximate',
345+
approximate=WebSearchOptionsUserLocationApproximate(**tool.user_location),
346+
),
347+
)
348+
330349
async def _map_messages(self, messages: list[ModelMessage]) -> list[chat.ChatCompletionMessageParam]:
331350
"""Just maps a `pydantic_ai.Message` to a `openai.types.ChatCompletionMessageParam`."""
332351
openai_messages: list[chat.ChatCompletionMessageParam] = []
@@ -601,6 +620,7 @@ async def _responses_create(
601620
) -> responses.Response | AsyncStream[responses.ResponseStreamEvent]:
602621
tools = self._get_tools(model_request_parameters)
603622
tools = list(model_settings.get('openai_builtin_tools', [])) + tools
623+
tools = self._get_builtin_tools(model_request_parameters) + tools
604624

605625
# standalone function to make it easier to override
606626
if not tools:
@@ -653,6 +673,19 @@ def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[r
653673
tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
654674
return tools
655675

676+
def _get_builtin_tools(self, model_request_parameters: ModelRequestParameters) -> list[responses.ToolParam]:
677+
tools: list[responses.ToolParam] = []
678+
for tool in model_request_parameters.builtin_tools:
679+
if isinstance(tool, WebSearchTool):
680+
tools.append(
681+
responses.WebSearchToolParam(
682+
type='web_search_preview',
683+
search_context_size=tool.search_context_size,
684+
user_location={'type': 'approximate', **tool.user_location},
685+
)
686+
)
687+
return tools
688+
656689
@staticmethod
657690
def _map_tool_definition(f: ToolDefinition) -> responses.FunctionToolParam:
658691
return {

0 commit comments

Comments
 (0)