15
15
16
16
import types
17
17
from inspect import Signature
18
- from typing import (
19
- Any ,
20
- Callable ,
21
- Mapping ,
22
- Optional ,
23
- Sequence ,
24
- Union ,
25
- )
18
+ from typing import Any , Callable , Coroutine , Mapping , Optional , Sequence , Union
26
19
27
20
from aiohttp import ClientSession
28
21
@@ -58,6 +51,7 @@ def __init__(
58
51
required_authn_params : Mapping [str , list [str ]],
59
52
auth_service_token_getters : Mapping [str , Callable [[], str ]],
60
53
bound_params : Mapping [str , Union [Callable [[], Any ], Any ]],
54
+ client_headers : Mapping [str , Union [Callable , Coroutine , str ]],
61
55
):
62
56
"""
63
57
Initializes a callable that will trigger the tool invocation through the
@@ -75,6 +69,7 @@ def __init__(
75
69
produce a token)
76
70
bound_params: A mapping of parameter names to bind to specific values or
77
71
callables that are called to produce values as needed.
72
+ client_headers: Client specific headers bound to the tool.
78
73
"""
79
74
# used to invoke the toolbox API
80
75
self .__session : ClientSession = session
@@ -96,12 +91,27 @@ def __init__(
96
91
self .__annotations__ = {p .name : p .annotation for p in inspect_type_params }
97
92
self .__qualname__ = f"{ self .__class__ .__qualname__ } .{ self .__name__ } "
98
93
94
+ # Validate conflicting Headers/Auth Tokens
95
+ request_header_names = client_headers .keys ()
96
+ auth_token_names = [
97
+ auth_token_name + "_token"
98
+ for auth_token_name in auth_service_token_getters .keys ()
99
+ ]
100
+ duplicates = request_header_names & auth_token_names
101
+ if duplicates :
102
+ raise ValueError (
103
+ f"Client header(s) `{ ', ' .join (duplicates )} ` already registered in client. "
104
+ f"Cannot register client the same headers in the client as well as tool."
105
+ )
106
+
99
107
# map of parameter name to auth service required by it
100
108
self .__required_authn_params = required_authn_params
101
109
# map of authService -> token_getter
102
110
self .__auth_service_token_getters = auth_service_token_getters
103
111
# map of parameter name to value (or callable that produces that value)
104
112
self .__bound_parameters = bound_params
113
+ # map of client headers to their value/callable/coroutine
114
+ self .__client_headers = client_headers
105
115
106
116
def __copy (
107
117
self ,
@@ -113,6 +123,7 @@ def __copy(
113
123
required_authn_params : Optional [Mapping [str , list [str ]]] = None ,
114
124
auth_service_token_getters : Optional [Mapping [str , Callable [[], str ]]] = None ,
115
125
bound_params : Optional [Mapping [str , Union [Callable [[], Any ], Any ]]] = None ,
126
+ client_headers : Optional [Mapping [str , Union [Callable , Coroutine , str ]]] = None ,
116
127
) -> "ToolboxTool" :
117
128
"""
118
129
Creates a copy of the ToolboxTool, overriding specific fields.
@@ -129,7 +140,7 @@ def __copy(
129
140
that produce a token)
130
141
bound_params: A mapping of parameter names to bind to specific values or
131
142
callables that are called to produce values as needed.
132
-
143
+ client_headers: Client specific headers bound to the tool.
133
144
"""
134
145
check = lambda val , default : val if val is not None else default
135
146
return ToolboxTool (
@@ -145,6 +156,7 @@ def __copy(
145
156
auth_service_token_getters , self .__auth_service_token_getters
146
157
),
147
158
bound_params = check (bound_params , self .__bound_parameters ),
159
+ client_headers = check (client_headers , self .__client_headers ),
148
160
)
149
161
150
162
async def __call__ (self , * args : Any , ** kwargs : Any ) -> str :
@@ -169,7 +181,8 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:
169
181
for s in self .__required_authn_params .values ():
170
182
req_auth_services .update (s )
171
183
raise Exception (
172
- f"One or more of the following authn services are required to invoke this tool: { ',' .join (req_auth_services )} "
184
+ f"One or more of the following authn services are required to invoke this tool"
185
+ f": { ',' .join (req_auth_services )} "
173
186
)
174
187
175
188
# validate inputs to this call using the signature
@@ -188,6 +201,8 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:
188
201
headers = {}
189
202
for auth_service , token_getter in self .__auth_service_token_getters .items ():
190
203
headers [f"{ auth_service } _token" ] = await resolve_value (token_getter )
204
+ for client_header_name , client_header_val in self .__client_headers .items ():
205
+ headers [client_header_name ] = await resolve_value (client_header_val )
191
206
192
207
async with self .__session .post (
193
208
self .__url ,
@@ -215,6 +230,10 @@ def add_auth_token_getters(
215
230
Returns:
216
231
A new ToolboxTool instance with the specified authentication token
217
232
getters registered.
233
+
234
+ Raises
235
+ ValueError: If the auth source has already been registered either
236
+ to the tool or to the corresponding client.
218
237
"""
219
238
220
239
# throw an error if the authentication source is already registered
@@ -226,6 +245,18 @@ def add_auth_token_getters(
226
245
f"Authentication source(s) `{ ', ' .join (duplicates )} ` already registered in tool `{ self .__name__ } `."
227
246
)
228
247
248
+ # Validate duplicates with client headers
249
+ request_header_names = self .__client_headers .keys ()
250
+ auth_token_names = [
251
+ auth_token_name + "_token" for auth_token_name in incoming_services
252
+ ]
253
+ duplicates = request_header_names & auth_token_names
254
+ if duplicates :
255
+ raise ValueError (
256
+ f"Client header(s) `{ ', ' .join (duplicates )} ` already registered in client. "
257
+ f"Cannot register client the same headers in the client as well as tool."
258
+ )
259
+
229
260
# create a read-only updated value for new_getters
230
261
new_getters = types .MappingProxyType (
231
262
dict (self .__auth_service_token_getters , ** auth_token_getters )
0 commit comments