13
13
# limitations under the License.
14
14
15
15
16
+ import asyncio
16
17
import types
17
18
from collections import defaultdict
18
19
from inspect import Parameter , Signature
19
- from typing import Any , Callable , DefaultDict , Iterable , Mapping , Optional , Sequence
20
+ from typing import (
21
+ Any ,
22
+ Callable ,
23
+ DefaultDict ,
24
+ Iterable ,
25
+ Mapping ,
26
+ Optional ,
27
+ Sequence ,
28
+ Union ,
29
+ )
20
30
21
31
from aiohttp import ClientSession
22
32
from pytest import Session
@@ -44,6 +54,7 @@ def __init__(
44
54
params : Sequence [Parameter ],
45
55
required_authn_params : Mapping [str , list [str ]],
46
56
auth_service_token_getters : Mapping [str , Callable [[], str ]],
57
+ bound_params : Mapping [str , Union [Callable [[], Any ], Any ]],
47
58
):
48
59
"""
49
60
Initializes a callable that will trigger the tool invocation through the
@@ -60,6 +71,9 @@ def __init__(
60
71
of services that provide values for them.
61
72
auth_service_token_getters: A dict of authService -> token (or callables that
62
73
produce a token)
74
+ bound_params: A mapping of parameter names to bind to specific values or
75
+ callables that are called to produce values as needed.
76
+
63
77
"""
64
78
65
79
# used to invoke the toolbox API
@@ -81,6 +95,8 @@ def __init__(
81
95
self .__required_authn_params = required_authn_params
82
96
# map of authService -> token_getter
83
97
self .__auth_service_token_getters = auth_service_token_getters
98
+ # map of parameter name to value (or callable that produces that value)
99
+ self .__bound_parameters = bound_params
84
100
85
101
def __copy (
86
102
self ,
@@ -91,6 +107,7 @@ def __copy(
91
107
params : Optional [list [Parameter ]] = None ,
92
108
required_authn_params : Optional [Mapping [str , list [str ]]] = None ,
93
109
auth_service_token_getters : Optional [Mapping [str , Callable [[], str ]]] = None ,
110
+ bound_params : Optional [Mapping [str , Union [Callable [[], Any ], Any ]]] = None ,
94
111
) -> "ToolboxTool" :
95
112
"""
96
113
Creates a copy of the ToolboxTool, overriding specific fields.
@@ -106,6 +123,8 @@ def __copy(
106
123
a auth_service_token_getter set for them yet.
107
124
auth_service_token_getters: A dict of authService -> token (or callables
108
125
that produce a token)
126
+ bound_params: A mapping of parameter names to bind to specific values or
127
+ callables that are called to produce values as needed.
109
128
110
129
"""
111
130
check = lambda val , default : val if val is not None else default
@@ -121,6 +140,7 @@ def __copy(
121
140
auth_service_token_getters = check (
122
141
auth_service_token_getters , self .__auth_service_token_getters
123
142
),
143
+ bound_params = check (bound_params , self .__bound_parameters ),
124
144
)
125
145
126
146
async def __call__ (self , * args : Any , ** kwargs : Any ) -> str :
@@ -153,6 +173,14 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:
153
173
all_args .apply_defaults () # Include default values if not provided
154
174
payload = all_args .arguments
155
175
176
+ # apply bounded parameters
177
+ for param , value in self .__bound_parameters .items ():
178
+ if asyncio .iscoroutinefunction (value ):
179
+ value = await value ()
180
+ elif callable (value ):
181
+ value = value ()
182
+ payload [param ] = value
183
+
156
184
# create headers for auth services
157
185
headers = {}
158
186
for auth_service , token_getter in self .__auth_service_token_getters .items ():
@@ -211,23 +239,54 @@ def add_auth_token_getters(
211
239
required_authn_params = new_req_authn_params ,
212
240
)
213
241
242
+ def bind_parameters (
243
+ self , bound_params : Mapping [str , Union [Callable [[], Any ], Any ]]
244
+ ) -> "ToolboxTool" :
245
+ """
246
+ Binds parameters to values or callables that produce values.
247
+
248
+ Args:
249
+ bound_params: A mapping of parameter names to values or callables that
250
+ produce values.
251
+
252
+ Returns:
253
+ A new ToolboxTool instance with the specified parameters bound.
254
+ """
255
+ param_names = set (p .name for p in self .__params )
256
+ for name in bound_params .keys ():
257
+ if name not in param_names :
258
+ raise Exception (f"unable to bind parameters: no parameter named { name } " )
259
+
260
+ new_params = []
261
+ for p in self .__params :
262
+ if p .name not in bound_params :
263
+ new_params .append (p )
264
+
265
+ all_bound_params = dict (self .__bound_parameters )
266
+ all_bound_params .update (bound_params )
267
+
268
+ return self .__copy (
269
+ params = new_params ,
270
+ bound_params = types .MappingProxyType (all_bound_params ),
271
+ )
272
+
214
273
215
274
def identify_required_authn_params (
216
275
req_authn_params : Mapping [str , list [str ]], auth_service_names : Iterable [str ]
217
276
) -> dict [str , list [str ]]:
218
277
"""
219
- Identifies authentication parameters that are still required; or not covered by
220
- the provided `auth_service_names`.
278
+ Identifies authentication parameters that are still required; because they
279
+ not covered by the provided `auth_service_names`.
221
280
222
- Args:
223
- req_authn_params: A mapping of parameter names to sets of required
224
- authentication services.
225
- auth_service_names: An iterable of authentication service names for which
226
- token getters are available.
281
+ Args:
282
+ req_authn_params: A mapping of parameter names to sets of required
283
+ authentication services.
284
+ auth_service_names: An iterable of authentication service names for which
285
+ token getters are available.
227
286
228
287
Returns:
229
- A new dictionary representing the subset of required authentication
230
- parameters that are not covered by the provided `auth_service_names `.
288
+ A new dictionary representing the subset of required authentication parameters
289
+ that are not covered by the provided `auth_services `.
231
290
"""
232
291
required_params = {} # params that are still required with provided auth_services
233
292
for param , services in req_authn_params .items ():
0 commit comments