15
15
16
16
import asyncio
17
17
import types
18
- from inspect import Parameter , Signature
18
+ from inspect import Signature
19
19
from typing import (
20
20
Any ,
21
21
Callable ,
28
28
29
29
from aiohttp import ClientSession
30
30
31
+ from toolbox_core .protocol import ParameterSchema
32
+
31
33
32
34
class ToolboxTool :
33
35
"""
@@ -47,8 +49,8 @@ def __init__(
47
49
session : ClientSession ,
48
50
base_url : str ,
49
51
name : str ,
50
- desc : str ,
51
- params : Sequence [Parameter ],
52
+ description : str ,
53
+ params : Sequence [ParameterSchema ],
52
54
required_authn_params : Mapping [str , list [str ]],
53
55
auth_service_token_getters : Mapping [str , Callable [[], str ]],
54
56
bound_params : Mapping [str , Union [Callable [[], Any ], Any ]],
@@ -61,31 +63,30 @@ def __init__(
61
63
session: The `aiohttp.ClientSession` used for making API requests.
62
64
base_url: The base URL of the Toolbox server API.
63
65
name: The name of the remote tool.
64
- desc: The description of the remote tool (used as its docstring).
65
- params: A list of `inspect.Parameter` objects defining the tool's
66
- arguments and their types/defaults.
66
+ description: The description of the remote tool.
67
+ params: The args of the tool.
67
68
required_authn_params: A dict of required authenticated parameters to a list
68
69
of services that provide values for them.
69
70
auth_service_token_getters: A dict of authService -> token (or callables that
70
71
produce a token)
71
72
bound_params: A mapping of parameter names to bind to specific values or
72
73
callables that are called to produce values as needed.
73
-
74
74
"""
75
-
76
75
# used to invoke the toolbox API
77
76
self .__session : ClientSession = session
78
77
self .__base_url : str = base_url
79
78
self .__url = f"{ base_url } /api/tool/{ name } /invoke"
80
-
81
- self .__desc = desc
79
+ self .__description = description
82
80
self .__params = params
81
+ inspect_type_params = [param .to_param () for param in self .__params ]
83
82
84
83
# the following properties are set to help anyone that might inspect it determine usage
85
84
self .__name__ = name
86
- self .__doc__ = desc
87
- self .__signature__ = Signature (parameters = params , return_annotation = str )
88
- self .__annotations__ = {p .name : p .annotation for p in params }
85
+ self .__doc__ = create_docstring (self .__description , self .__params )
86
+ self .__signature__ = Signature (
87
+ parameters = inspect_type_params , return_annotation = str
88
+ )
89
+ self .__annotations__ = {p .name : p .annotation for p in inspect_type_params }
89
90
# TODO: self.__qualname__ ??
90
91
91
92
# map of parameter name to auth service required by it
@@ -100,8 +101,8 @@ def __copy(
100
101
session : Optional [ClientSession ] = None ,
101
102
base_url : Optional [str ] = None ,
102
103
name : Optional [str ] = None ,
103
- desc : Optional [str ] = None ,
104
- params : Optional [list [ Parameter ]] = None ,
104
+ description : Optional [str ] = None ,
105
+ params : Optional [Sequence [ ParameterSchema ]] = None ,
105
106
required_authn_params : Optional [Mapping [str , list [str ]]] = None ,
106
107
auth_service_token_getters : Optional [Mapping [str , Callable [[], str ]]] = None ,
107
108
bound_params : Optional [Mapping [str , Union [Callable [[], Any ], Any ]]] = None ,
@@ -113,9 +114,8 @@ def __copy(
113
114
session: The `aiohttp.ClientSession` used for making API requests.
114
115
base_url: The base URL of the Toolbox server API.
115
116
name: The name of the remote tool.
116
- desc: The description of the remote tool (used as its docstring).
117
- params: A list of `inspect.Parameter` objects defining the tool's
118
- arguments and their types/defaults.
117
+ description: The description of the remote tool.
118
+ params: The args of the tool.
119
119
required_authn_params: A dict of required authenticated parameters that need
120
120
a auth_service_token_getter set for them yet.
121
121
auth_service_token_getters: A dict of authService -> token (or callables
@@ -129,7 +129,7 @@ def __copy(
129
129
session = check (session , self .__session ),
130
130
base_url = check (base_url , self .__base_url ),
131
131
name = check (name , self .__name__ ),
132
- desc = check (desc , self .__desc ),
132
+ description = check (description , self .__description ),
133
133
params = check (params , self .__params ),
134
134
required_authn_params = check (
135
135
required_authn_params , self .__required_authn_params
@@ -258,7 +258,6 @@ def bind_parameters(
258
258
for p in self .__params :
259
259
if p .name not in bound_params :
260
260
new_params .append (p )
261
-
262
261
all_bound_params = dict (self .__bound_parameters )
263
262
all_bound_params .update (bound_params )
264
263
@@ -268,6 +267,19 @@ def bind_parameters(
268
267
)
269
268
270
269
270
+ def create_docstring (description : str , params : Sequence [ParameterSchema ]) -> str :
271
+ """Convert tool description and params into its function docstring"""
272
+ docstring = description
273
+ if not params :
274
+ return docstring
275
+ docstring += "\n \n Args:"
276
+ for p in params :
277
+ docstring += (
278
+ f"\n { p .name } ({ p .to_param ().annotation .__name__ } ): { p .description } "
279
+ )
280
+ return docstring
281
+
282
+
271
283
def identify_required_authn_params (
272
284
req_authn_params : Mapping [str , list [str ]], auth_service_names : Iterable [str ]
273
285
) -> dict [str , list [str ]]:
0 commit comments