Skip to content

Commit b2a2208

Browse files
authored
feat(toolbox-core): add support for bound parameters (#120)
1 parent 10087a1 commit b2a2208

File tree

3 files changed

+205
-37
lines changed

3 files changed

+205
-37
lines changed

packages/toolbox-core/src/toolbox_core/client.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
import re
1515
import types
16-
from typing import Any, Callable, Optional
16+
from typing import Any, Callable, Mapping, Optional, Union
1717

1818
from aiohttp import ClientSession
1919

@@ -59,18 +59,22 @@ def __parse_tool(
5959
name: str,
6060
schema: ToolSchema,
6161
auth_token_getters: dict[str, Callable[[], str]],
62+
all_bound_params: Mapping[str, Union[Callable[[], Any], Any]],
6263
) -> ToolboxTool:
6364
"""Internal helper to create a callable tool from its schema."""
64-
# sort into authenticated and reg params
65+
# sort into reg, authn, and bound params
6566
params = []
6667
authn_params: dict[str, list[str]] = {}
68+
bound_params: dict[str, Callable[[], str]] = {}
6769
auth_sources: set[str] = set()
6870
for p in schema.parameters:
69-
if not p.authSources:
70-
params.append(p)
71-
else:
71+
if p.authSources: # authn parameter
7272
authn_params[p.name] = p.authSources
7373
auth_sources.update(p.authSources)
74+
elif p.name in all_bound_params: # bound parameter
75+
bound_params[p.name] = all_bound_params[p.name]
76+
else: # regular parameter
77+
params.append(p)
7478

7579
authn_params = identify_required_authn_params(
7680
authn_params, auth_token_getters.keys()
@@ -85,6 +89,7 @@ def __parse_tool(
8589
# create a read-only values for the maps to prevent mutation
8690
required_authn_params=types.MappingProxyType(authn_params),
8791
auth_service_token_getters=types.MappingProxyType(auth_token_getters),
92+
bound_params=types.MappingProxyType(bound_params),
8893
)
8994
return tool
9095

@@ -124,6 +129,7 @@ async def load_tool(
124129
self,
125130
name: str,
126131
auth_token_getters: dict[str, Callable[[], str]] = {},
132+
bound_params: Mapping[str, Union[Callable[[], Any], Any]] = {},
127133
) -> ToolboxTool:
128134
"""
129135
Asynchronously loads a tool from the server.
@@ -136,6 +142,10 @@ async def load_tool(
136142
name: The unique name or identifier of the tool to load.
137143
auth_token_getters: A mapping of authentication service names to
138144
callables that return the corresponding authentication token.
145+
bound_params: A mapping of parameter names to bind to specific values or
146+
callables that are called to produce values as needed.
147+
148+
139149
140150
Returns:
141151
ToolboxTool: A callable object representing the loaded tool, ready
@@ -154,14 +164,17 @@ async def load_tool(
154164
if name not in manifest.tools:
155165
# TODO: Better exception
156166
raise Exception(f"Tool '{name}' not found!")
157-
tool = self.__parse_tool(name, manifest.tools[name], auth_token_getters)
167+
tool = self.__parse_tool(
168+
name, manifest.tools[name], auth_token_getters, bound_params
169+
)
158170

159171
return tool
160172

161173
async def load_toolset(
162174
self,
163175
name: str,
164176
auth_token_getters: dict[str, Callable[[], str]] = {},
177+
bound_params: Mapping[str, Union[Callable[[], Any], Any]] = {},
165178
) -> list[ToolboxTool]:
166179
"""
167180
Asynchronously fetches a toolset and loads all tools defined within it.
@@ -170,6 +183,9 @@ async def load_toolset(
170183
name: Name of the toolset to load tools.
171184
auth_token_getters: A mapping of authentication service names to
172185
callables that return the corresponding authentication token.
186+
bound_params: A mapping of parameter names to bind to specific values or
187+
callables that are called to produce values as needed.
188+
173189
174190
175191
Returns:
@@ -184,7 +200,7 @@ async def load_toolset(
184200

185201
# parse each tools name and schema into a list of ToolboxTools
186202
tools = [
187-
self.__parse_tool(n, s, auth_token_getters)
203+
self.__parse_tool(n, s, auth_token_getters, bound_params)
188204
for n, s in manifest.tools.items()
189205
]
190206
return tools

packages/toolbox-core/src/toolbox_core/tool.py

Lines changed: 69 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,20 @@
1313
# limitations under the License.
1414

1515

16+
import asyncio
1617
import types
1718
from collections import defaultdict
1819
from inspect import Parameter, Signature
19-
from typing import Any, Callable, DefaultDict, Iterable, Mapping, Optional, Sequence
20+
from typing import (
21+
Any,
22+
Callable,
23+
DefaultDict,
24+
Iterable,
25+
Mapping,
26+
Optional,
27+
Sequence,
28+
Union,
29+
)
2030

2131
from aiohttp import ClientSession
2232
from pytest import Session
@@ -44,6 +54,7 @@ def __init__(
4454
params: Sequence[Parameter],
4555
required_authn_params: Mapping[str, list[str]],
4656
auth_service_token_getters: Mapping[str, Callable[[], str]],
57+
bound_params: Mapping[str, Union[Callable[[], Any], Any]],
4758
):
4859
"""
4960
Initializes a callable that will trigger the tool invocation through the
@@ -60,6 +71,9 @@ def __init__(
6071
of services that provide values for them.
6172
auth_service_token_getters: A dict of authService -> token (or callables that
6273
produce a token)
74+
bound_params: A mapping of parameter names to bind to specific values or
75+
callables that are called to produce values as needed.
76+
6377
"""
6478

6579
# used to invoke the toolbox API
@@ -81,6 +95,8 @@ def __init__(
8195
self.__required_authn_params = required_authn_params
8296
# map of authService -> token_getter
8397
self.__auth_service_token_getters = auth_service_token_getters
98+
# map of parameter name to value (or callable that produces that value)
99+
self.__bound_parameters = bound_params
84100

85101
def __copy(
86102
self,
@@ -91,6 +107,7 @@ def __copy(
91107
params: Optional[list[Parameter]] = None,
92108
required_authn_params: Optional[Mapping[str, list[str]]] = None,
93109
auth_service_token_getters: Optional[Mapping[str, Callable[[], str]]] = None,
110+
bound_params: Optional[Mapping[str, Union[Callable[[], Any], Any]]] = None,
94111
) -> "ToolboxTool":
95112
"""
96113
Creates a copy of the ToolboxTool, overriding specific fields.
@@ -106,6 +123,8 @@ def __copy(
106123
a auth_service_token_getter set for them yet.
107124
auth_service_token_getters: A dict of authService -> token (or callables
108125
that produce a token)
126+
bound_params: A mapping of parameter names to bind to specific values or
127+
callables that are called to produce values as needed.
109128
110129
"""
111130
check = lambda val, default: val if val is not None else default
@@ -121,6 +140,7 @@ def __copy(
121140
auth_service_token_getters=check(
122141
auth_service_token_getters, self.__auth_service_token_getters
123142
),
143+
bound_params=check(bound_params, self.__bound_parameters),
124144
)
125145

126146
async def __call__(self, *args: Any, **kwargs: Any) -> str:
@@ -153,6 +173,14 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:
153173
all_args.apply_defaults() # Include default values if not provided
154174
payload = all_args.arguments
155175

176+
# apply bounded parameters
177+
for param, value in self.__bound_parameters.items():
178+
if asyncio.iscoroutinefunction(value):
179+
value = await value()
180+
elif callable(value):
181+
value = value()
182+
payload[param] = value
183+
156184
# create headers for auth services
157185
headers = {}
158186
for auth_service, token_getter in self.__auth_service_token_getters.items():
@@ -211,23 +239,54 @@ def add_auth_token_getters(
211239
required_authn_params=new_req_authn_params,
212240
)
213241

242+
def bind_parameters(
243+
self, bound_params: Mapping[str, Union[Callable[[], Any], Any]]
244+
) -> "ToolboxTool":
245+
"""
246+
Binds parameters to values or callables that produce values.
247+
248+
Args:
249+
bound_params: A mapping of parameter names to values or callables that
250+
produce values.
251+
252+
Returns:
253+
A new ToolboxTool instance with the specified parameters bound.
254+
"""
255+
param_names = set(p.name for p in self.__params)
256+
for name in bound_params.keys():
257+
if name not in param_names:
258+
raise Exception(f"unable to bind parameters: no parameter named {name}")
259+
260+
new_params = []
261+
for p in self.__params:
262+
if p.name not in bound_params:
263+
new_params.append(p)
264+
265+
all_bound_params = dict(self.__bound_parameters)
266+
all_bound_params.update(bound_params)
267+
268+
return self.__copy(
269+
params=new_params,
270+
bound_params=types.MappingProxyType(all_bound_params),
271+
)
272+
214273

215274
def identify_required_authn_params(
216275
req_authn_params: Mapping[str, list[str]], auth_service_names: Iterable[str]
217276
) -> dict[str, list[str]]:
218277
"""
219-
Identifies authentication parameters that are still required; or not covered by
220-
the provided `auth_service_names`.
278+
Identifies authentication parameters that are still required; because they
279+
not covered by the provided `auth_service_names`.
221280
222-
Args:
223-
req_authn_params: A mapping of parameter names to sets of required
224-
authentication services.
225-
auth_service_names: An iterable of authentication service names for which
226-
token getters are available.
281+
Args:
282+
req_authn_params: A mapping of parameter names to sets of required
283+
authentication services.
284+
auth_service_names: An iterable of authentication service names for which
285+
token getters are available.
227286
228287
Returns:
229-
A new dictionary representing the subset of required authentication
230-
parameters that are not covered by the provided `auth_service_names`.
288+
A new dictionary representing the subset of required authentication parameters
289+
that are not covered by the provided `auth_services`.
231290
"""
232291
required_params = {} # params that are still required with provided auth_services
233292
for param, services in req_authn_params.items():

0 commit comments

Comments
 (0)