Skip to content

Commit e73f6b6

Browse files
feat: Add client headers to Toolbox (#178)
* iter1: poc # Conflicts: # packages/toolbox-core/src/toolbox_core/client.py # packages/toolbox-core/src/toolbox_core/tool.py * remove client headers from tool * merge correction * cleanup * client headers functionality * small diff * mypy * raise error on duplicate headers * docs * add client headers to tool * lint * lint * fix * add client tests * add client tests * fix tests * fix * lint * fix tests * cleanup * cleanup * lint * fix * cleanup * lint * lint * lint * lint * Update packages/toolbox-core/src/toolbox_core/client.py Co-authored-by: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com> * lint * fix * cleanup * use mock_tool_load in test * test cleanup * test cleanup * lint --------- Co-authored-by: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com>
1 parent 6693407 commit e73f6b6

File tree

4 files changed

+592
-64
lines changed

4 files changed

+592
-64
lines changed

packages/toolbox-core/src/toolbox_core/client.py

Lines changed: 53 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,14 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
15-
1614
import types
17-
from typing import Any, Callable, Mapping, Optional, Union
15+
from typing import Any, Callable, Coroutine, Mapping, Optional, Union
1816

1917
from aiohttp import ClientSession
2018

2119
from .protocol import ManifestSchema, ToolSchema
22-
from .tool import ToolboxTool, identify_required_authn_params
20+
from .tool import ToolboxTool
21+
from .utils import identify_required_authn_params, resolve_value
2322

2423

2524
class ToolboxClient:
@@ -37,6 +36,7 @@ def __init__(
3736
self,
3837
url: str,
3938
session: Optional[ClientSession] = None,
39+
client_headers: Optional[Mapping[str, Union[Callable, Coroutine, str]]] = None,
4040
):
4141
"""
4242
Initializes the ToolboxClient.
@@ -47,6 +47,7 @@ def __init__(
4747
If None (default), a new session is created internally. Note that
4848
if a session is provided, its lifecycle (including closing)
4949
should typically be managed externally.
50+
client_headers: Headers to include in each request sent through this client.
5051
"""
5152
self.__base_url = url
5253

@@ -55,12 +56,15 @@ def __init__(
5556
session = ClientSession()
5657
self.__session = session
5758

59+
self.__client_headers = client_headers if client_headers is not None else {}
60+
5861
def __parse_tool(
5962
self,
6063
name: str,
6164
schema: ToolSchema,
6265
auth_token_getters: dict[str, Callable[[], str]],
6366
all_bound_params: Mapping[str, Union[Callable[[], Any], Any]],
67+
client_headers: Mapping[str, Union[Callable, Coroutine, str]],
6468
) -> ToolboxTool:
6569
"""Internal helper to create a callable tool from its schema."""
6670
# sort into reg, authn, and bound params
@@ -89,6 +93,7 @@ def __parse_tool(
8993
required_authn_params=types.MappingProxyType(authn_params),
9094
auth_service_token_getters=types.MappingProxyType(auth_token_getters),
9195
bound_params=types.MappingProxyType(bound_params),
96+
client_headers=types.MappingProxyType(client_headers),
9297
)
9398
return tool
9499

@@ -144,18 +149,21 @@ async def load_tool(
144149
bound_params: A mapping of parameter names to bind to specific values or
145150
callables that are called to produce values as needed.
146151
147-
148-
149152
Returns:
150153
ToolboxTool: A callable object representing the loaded tool, ready
151154
for execution. The specific arguments and behavior of the callable
152155
depend on the tool itself.
153156
154157
"""
158+
# Resolve client headers
159+
resolved_headers = {
160+
name: await resolve_value(val)
161+
for name, val in self.__client_headers.items()
162+
}
155163

156164
# request the definition of the tool from the server
157165
url = f"{self.__base_url}/api/tool/{name}"
158-
async with self.__session.get(url) as response:
166+
async with self.__session.get(url, headers=resolved_headers) as response:
159167
json = await response.json()
160168
manifest: ManifestSchema = ManifestSchema(**json)
161169

@@ -164,7 +172,11 @@ async def load_tool(
164172
# TODO: Better exception
165173
raise Exception(f"Tool '{name}' not found!")
166174
tool = self.__parse_tool(
167-
name, manifest.tools[name], auth_token_getters, bound_params
175+
name,
176+
manifest.tools[name],
177+
auth_token_getters,
178+
bound_params,
179+
self.__client_headers,
168180
)
169181

170182
return tool
@@ -185,21 +197,50 @@ async def load_toolset(
185197
bound_params: A mapping of parameter names to bind to specific values or
186198
callables that are called to produce values as needed.
187199
188-
189-
190200
Returns:
191201
list[ToolboxTool]: A list of callables, one for each tool defined
192202
in the toolset.
193203
"""
204+
# Resolve client headers
205+
original_headers = self.__client_headers
206+
resolved_headers = {
207+
header_name: await resolve_value(original_headers[header_name])
208+
for header_name in original_headers
209+
}
194210
# Request the definition of the tool from the server
195211
url = f"{self.__base_url}/api/toolset/{name or ''}"
196-
async with self.__session.get(url) as response:
212+
async with self.__session.get(url, headers=resolved_headers) as response:
197213
json = await response.json()
198214
manifest: ManifestSchema = ManifestSchema(**json)
199215

200216
# parse each tools name and schema into a list of ToolboxTools
201217
tools = [
202-
self.__parse_tool(n, s, auth_token_getters, bound_params)
218+
self.__parse_tool(
219+
n, s, auth_token_getters, bound_params, self.__client_headers
220+
)
203221
for n, s in manifest.tools.items()
204222
]
205223
return tools
224+
225+
async def add_headers(
226+
self, headers: Mapping[str, Union[Callable, Coroutine, str]]
227+
) -> None:
228+
"""
229+
Asynchronously Add headers to be included in each request sent through this client.
230+
231+
Args:
232+
headers: Headers to include in each request sent through this client.
233+
234+
Raises:
235+
ValueError: If any of the headers are already registered in the client.
236+
"""
237+
existing_headers = self.__client_headers.keys()
238+
incoming_headers = headers.keys()
239+
duplicates = existing_headers & incoming_headers
240+
if duplicates:
241+
raise ValueError(
242+
f"Client header(s) `{', '.join(duplicates)}` already registered in the client."
243+
)
244+
245+
merged_headers = {**self.__client_headers, **headers}
246+
self.__client_headers = merged_headers

packages/toolbox-core/src/toolbox_core/tool.py

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,7 @@
1515

1616
import types
1717
from inspect import Signature
18-
from typing import (
19-
Any,
20-
Callable,
21-
Mapping,
22-
Optional,
23-
Sequence,
24-
Union,
25-
)
18+
from typing import Any, Callable, Coroutine, Mapping, Optional, Sequence, Union
2619

2720
from aiohttp import ClientSession
2821

@@ -58,6 +51,7 @@ def __init__(
5851
required_authn_params: Mapping[str, list[str]],
5952
auth_service_token_getters: Mapping[str, Callable[[], str]],
6053
bound_params: Mapping[str, Union[Callable[[], Any], Any]],
54+
client_headers: Mapping[str, Union[Callable, Coroutine, str]],
6155
):
6256
"""
6357
Initializes a callable that will trigger the tool invocation through the
@@ -75,6 +69,7 @@ def __init__(
7569
produce a token)
7670
bound_params: A mapping of parameter names to bind to specific values or
7771
callables that are called to produce values as needed.
72+
client_headers: Client specific headers bound to the tool.
7873
"""
7974
# used to invoke the toolbox API
8075
self.__session: ClientSession = session
@@ -96,12 +91,27 @@ def __init__(
9691
self.__annotations__ = {p.name: p.annotation for p in inspect_type_params}
9792
self.__qualname__ = f"{self.__class__.__qualname__}.{self.__name__}"
9893

94+
# Validate conflicting Headers/Auth Tokens
95+
request_header_names = client_headers.keys()
96+
auth_token_names = [
97+
auth_token_name + "_token"
98+
for auth_token_name in auth_service_token_getters.keys()
99+
]
100+
duplicates = request_header_names & auth_token_names
101+
if duplicates:
102+
raise ValueError(
103+
f"Client header(s) `{', '.join(duplicates)}` already registered in client. "
104+
f"Cannot register client the same headers in the client as well as tool."
105+
)
106+
99107
# map of parameter name to auth service required by it
100108
self.__required_authn_params = required_authn_params
101109
# map of authService -> token_getter
102110
self.__auth_service_token_getters = auth_service_token_getters
103111
# map of parameter name to value (or callable that produces that value)
104112
self.__bound_parameters = bound_params
113+
# map of client headers to their value/callable/coroutine
114+
self.__client_headers = client_headers
105115

106116
def __copy(
107117
self,
@@ -113,6 +123,7 @@ def __copy(
113123
required_authn_params: Optional[Mapping[str, list[str]]] = None,
114124
auth_service_token_getters: Optional[Mapping[str, Callable[[], str]]] = None,
115125
bound_params: Optional[Mapping[str, Union[Callable[[], Any], Any]]] = None,
126+
client_headers: Optional[Mapping[str, Union[Callable, Coroutine, str]]] = None,
116127
) -> "ToolboxTool":
117128
"""
118129
Creates a copy of the ToolboxTool, overriding specific fields.
@@ -129,7 +140,7 @@ def __copy(
129140
that produce a token)
130141
bound_params: A mapping of parameter names to bind to specific values or
131142
callables that are called to produce values as needed.
132-
143+
client_headers: Client specific headers bound to the tool.
133144
"""
134145
check = lambda val, default: val if val is not None else default
135146
return ToolboxTool(
@@ -145,6 +156,7 @@ def __copy(
145156
auth_service_token_getters, self.__auth_service_token_getters
146157
),
147158
bound_params=check(bound_params, self.__bound_parameters),
159+
client_headers=check(client_headers, self.__client_headers),
148160
)
149161

150162
async def __call__(self, *args: Any, **kwargs: Any) -> str:
@@ -169,7 +181,8 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:
169181
for s in self.__required_authn_params.values():
170182
req_auth_services.update(s)
171183
raise Exception(
172-
f"One or more of the following authn services are required to invoke this tool: {','.join(req_auth_services)}"
184+
f"One or more of the following authn services are required to invoke this tool"
185+
f": {','.join(req_auth_services)}"
173186
)
174187

175188
# validate inputs to this call using the signature
@@ -188,6 +201,8 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:
188201
headers = {}
189202
for auth_service, token_getter in self.__auth_service_token_getters.items():
190203
headers[f"{auth_service}_token"] = await resolve_value(token_getter)
204+
for client_header_name, client_header_val in self.__client_headers.items():
205+
headers[client_header_name] = await resolve_value(client_header_val)
191206

192207
async with self.__session.post(
193208
self.__url,
@@ -215,6 +230,10 @@ def add_auth_token_getters(
215230
Returns:
216231
A new ToolboxTool instance with the specified authentication token
217232
getters registered.
233+
234+
Raises
235+
ValueError: If the auth source has already been registered either
236+
to the tool or to the corresponding client.
218237
"""
219238

220239
# throw an error if the authentication source is already registered
@@ -226,6 +245,18 @@ def add_auth_token_getters(
226245
f"Authentication source(s) `{', '.join(duplicates)}` already registered in tool `{self.__name__}`."
227246
)
228247

248+
# Validate duplicates with client headers
249+
request_header_names = self.__client_headers.keys()
250+
auth_token_names = [
251+
auth_token_name + "_token" for auth_token_name in incoming_services
252+
]
253+
duplicates = request_header_names & auth_token_names
254+
if duplicates:
255+
raise ValueError(
256+
f"Client header(s) `{', '.join(duplicates)}` already registered in client. "
257+
f"Cannot register client the same headers in the client as well as tool."
258+
)
259+
229260
# create a read-only updated value for new_getters
230261
new_getters = types.MappingProxyType(
231262
dict(self.__auth_service_token_getters, **auth_token_getters)

0 commit comments

Comments
 (0)