Skip to content

Commit a2263b1

Browse files
seanzhougooglecopybara-github
authored andcommitted
refactor: refactor toolset to extract tool_filter logic to base class
PiperOrigin-RevId: 761828251
1 parent e0851a1 commit a2263b1

File tree

6 files changed

+34
-39
lines changed

6 files changed

+34
-39
lines changed

src/google/adk/tools/apihub_tool/apihub_toolset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ def __init__(
131131
be either a tool predicate or a list of tool names of the tools to
132132
expose.
133133
"""
134+
super().__init__(tool_filter=tool_filter)
134135
self.name = name
135136
self.description = description
136137
self._apihub_resource_name = apihub_resource_name
@@ -143,7 +144,6 @@ def __init__(
143144
self._openapi_toolset = None
144145
self._auth_scheme = auth_scheme
145146
self._auth_credential = auth_credential
146-
self.tool_filter = tool_filter
147147

148148
if not self._lazy_load_spec:
149149
self._prepare_toolset()

src/google/adk/tools/application_integration_tool/application_integration_toolset.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ def __init__(
128128
Exception: If there is an error during the initialization of the
129129
integration or connection client.
130130
"""
131+
super().__init__(tool_filter=tool_filter)
131132
self.project = project
132133
self.location = location
133134
self._integration = integration
@@ -140,7 +141,6 @@ def __init__(
140141
self._service_account_json = service_account_json
141142
self._auth_scheme = auth_scheme
142143
self._auth_credential = auth_credential
143-
self.tool_filter = tool_filter
144144

145145
integration_client = IntegrationClient(
146146
project,
@@ -263,7 +263,11 @@ async def get_tools(
263263
readonly_context: Optional[ReadonlyContext] = None,
264264
) -> List[RestApiTool]:
265265
return (
266-
self._tools
266+
[
267+
tool
268+
for tool in self._tools
269+
if self._is_tool_selected(tool, readonly_context)
270+
]
267271
if self._openapi_toolset is None
268272
else await self._openapi_toolset.get_tools(readonly_context)
269273
)

src/google/adk/tools/base_toolset.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from abc import ABC
22
from abc import abstractmethod
3+
from typing import List
34
from typing import Optional
45
from typing import Protocol
56
from typing import runtime_checkable
7+
from typing import Union
68

79
from ..agents.readonly_context import ReadonlyContext
810
from .base_tool import BaseTool
@@ -34,9 +36,15 @@ class BaseToolset(ABC):
3436
A toolset is a collection of tools that can be used by an agent.
3537
"""
3638

39+
def __init__(
40+
self, *, tool_filter: Optional[Union[ToolPredicate, List[str]]] = None
41+
):
42+
self.tool_filter = tool_filter
43+
3744
@abstractmethod
3845
async def get_tools(
39-
self, readonly_context: Optional[ReadonlyContext] = None
46+
self,
47+
readonly_context: Optional[ReadonlyContext] = None,
4048
) -> list[BaseTool]:
4149
"""Return all tools in the toolset based on the provided context.
4250
@@ -57,3 +65,17 @@ async def close(self) -> None:
5765
should ensure that any open connections, files, or other managed
5866
resources are properly released to prevent leaks.
5967
"""
68+
69+
def _is_tool_selected(
70+
self, tool: BaseTool, readonly_context: ReadonlyContext
71+
) -> bool:
72+
if not self.tool_filter:
73+
return True
74+
75+
if isinstance(self.tool_filter, ToolPredicate):
76+
return self.tool_filter(tool, readonly_context)
77+
78+
if isinstance(self.tool_filter, list):
79+
return tool.name in self.tool_filter
80+
81+
return False

src/google/adk/tools/google_api_tool/google_api_toolset.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -56,20 +56,6 @@ def __init__(
5656
self._openapi_toolset = self._load_toolset_with_oidc_auth()
5757
self.tool_filter = tool_filter
5858

59-
def _is_tool_selected(
60-
self, tool: GoogleApiTool, readonly_context: ReadonlyContext
61-
) -> bool:
62-
if not self.tool_filter:
63-
return True
64-
65-
if isinstance(self.tool_filter, ToolPredicate):
66-
return self.tool_filter(tool, readonly_context)
67-
68-
if isinstance(self.tool_filter, list):
69-
return tool.name in self.tool_filter
70-
71-
return False
72-
7359
@override
7460
async def get_tools(
7561
self, readonly_context: Optional[ReadonlyContext] = None

src/google/adk/tools/mcp_tool/mcp_toolset.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ def __init__(
9090

9191
if not connection_params:
9292
raise ValueError("Missing connection params in MCPToolset.")
93+
super().__init__(tool_filter=tool_filter)
9394
self._connection_params = connection_params
9495
self._errlog = errlog
9596
self._exit_stack = AsyncExitStack()
@@ -102,7 +103,6 @@ def __init__(
102103
errlog=self._errlog,
103104
)
104105
self._session = None
105-
self.tool_filter = tool_filter
106106
self._initialized = False
107107

108108
async def _initialize(self) -> ClientSession:
@@ -116,18 +116,6 @@ async def _initialize(self) -> ClientSession:
116116
self._initialized = True
117117
return self._session
118118

119-
def _is_selected(
120-
self, tool: BaseTool, readonly_context: Optional[ReadonlyContext]
121-
) -> bool:
122-
"""Checks if a tool should be selected based on the tool filter."""
123-
if self.tool_filter is None:
124-
return True
125-
if isinstance(self.tool_filter, ToolPredicate):
126-
return self.tool_filter(tool, readonly_context)
127-
if isinstance(self.tool_filter, list):
128-
return tool.name in self.tool_filter
129-
return False
130-
131119
@override
132120
async def close(self):
133121
"""Safely closes the connection to MCP Server with guaranteed resource cleanup."""
@@ -221,6 +209,6 @@ async def get_tools(
221209
mcp_session_manager=self._session_manager,
222210
)
223211

224-
if self._is_selected(mcp_tool, readonly_context):
212+
if self._is_tool_selected(mcp_tool, readonly_context):
225213
tools.append(mcp_tool)
226214
return tools

src/google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -103,12 +103,12 @@ def __init__(
103103
tool_filter: The filter used to filter the tools in the toolset. It can be
104104
either a tool predicate or a list of tool names of the tools to expose.
105105
"""
106+
super().__init__(tool_filter=tool_filter)
106107
if not spec_dict:
107108
spec_dict = self._load_spec(spec_str, spec_str_type)
108109
self._tools: Final[List[RestApiTool]] = list(self._parse(spec_dict))
109110
if auth_scheme or auth_credential:
110111
self._configure_auth_all(auth_scheme, auth_credential)
111-
self.tool_filter = tool_filter
112112

113113
def _configure_auth_all(
114114
self, auth_scheme: AuthScheme, auth_credential: AuthCredential
@@ -129,12 +129,7 @@ async def get_tools(
129129
return [
130130
tool
131131
for tool in self._tools
132-
if self.tool_filter is None
133-
or (
134-
self.tool_filter(tool, readonly_context)
135-
if isinstance(self.tool_filter, ToolPredicate)
136-
else tool.name in self.tool_filter
137-
)
132+
if self._is_tool_selected(tool, readonly_context)
138133
]
139134

140135
def get_tool(self, tool_name: str) -> Optional[RestApiTool]:

0 commit comments

Comments
 (0)