Skip to content

feat: Add client headers to Toolbox #178

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 39 commits into from
Apr 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
cd59f54
iter1: poc
twishabansal Apr 15, 2025
297f5a9
remove client headers from tool
twishabansal Apr 17, 2025
7a30444
merge correction
twishabansal Apr 17, 2025
c78b74e
cleanup
twishabansal Apr 17, 2025
4b92ace
client headers functionality
twishabansal Apr 17, 2025
0b739ce
Merge branch 'main' into add-headers-poc
twishabansal Apr 17, 2025
37a3984
small diff
twishabansal Apr 17, 2025
d7ab445
Merge remote-tracking branch 'origin/add-headers-poc' into add-header…
twishabansal Apr 17, 2025
bc6ca96
mypy
twishabansal Apr 17, 2025
f5fc526
raise error on duplicate headers
twishabansal Apr 21, 2025
5e91f15
docs
twishabansal Apr 21, 2025
154edc1
add client headers to tool
twishabansal Apr 21, 2025
b83eaef
lint
twishabansal Apr 21, 2025
6965ba1
lint
twishabansal Apr 21, 2025
e8a53c0
fix
twishabansal Apr 21, 2025
ea319e5
add client tests
twishabansal Apr 21, 2025
abeb8de
add client tests
twishabansal Apr 21, 2025
45c0fc6
fix tests
twishabansal Apr 21, 2025
7693011
fix
twishabansal Apr 21, 2025
a43ffad
lint
twishabansal Apr 21, 2025
f564e60
fix tests
twishabansal Apr 21, 2025
2361cdf
cleanup
twishabansal Apr 21, 2025
d46bfd0
cleanup
twishabansal Apr 21, 2025
f2f5cd2
lint
twishabansal Apr 21, 2025
0fd1ca2
fix
twishabansal Apr 21, 2025
516d151
cleanup
twishabansal Apr 21, 2025
16b6664
Merge branch 'main' into add-headers-poc
twishabansal Apr 21, 2025
80b40d7
lint
twishabansal Apr 21, 2025
70d55b6
lint
twishabansal Apr 21, 2025
551635b
lint
twishabansal Apr 21, 2025
2880b4d
lint
twishabansal Apr 21, 2025
7990b99
Update packages/toolbox-core/src/toolbox_core/client.py
twishabansal Apr 22, 2025
60572b4
lint
twishabansal Apr 22, 2025
1826ad9
fix
twishabansal Apr 22, 2025
a4eafe6
cleanup
twishabansal Apr 22, 2025
8502db5
use mock_tool_load in test
twishabansal Apr 22, 2025
d0731ec
test cleanup
twishabansal Apr 22, 2025
560dfec
test cleanup
twishabansal Apr 22, 2025
1ac9a14
lint
twishabansal Apr 22, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 53 additions & 12 deletions packages/toolbox-core/src/toolbox_core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import types
from typing import Any, Callable, Mapping, Optional, Union
from typing import Any, Callable, Coroutine, Mapping, Optional, Union

from aiohttp import ClientSession

from .protocol import ManifestSchema, ToolSchema
from .tool import ToolboxTool, identify_required_authn_params
from .tool import ToolboxTool
from .utils import identify_required_authn_params, resolve_value


class ToolboxClient:
Expand All @@ -37,6 +36,7 @@ def __init__(
self,
url: str,
session: Optional[ClientSession] = None,
client_headers: Optional[Mapping[str, Union[Callable, Coroutine, str]]] = None,
):
"""
Initializes the ToolboxClient.
Expand All @@ -47,6 +47,7 @@ def __init__(
If None (default), a new session is created internally. Note that
if a session is provided, its lifecycle (including closing)
should typically be managed externally.
client_headers: Headers to include in each request sent through this client.
"""
self.__base_url = url

Expand All @@ -55,12 +56,15 @@ def __init__(
session = ClientSession()
self.__session = session

self.__client_headers = client_headers if client_headers is not None else {}

def __parse_tool(
self,
name: str,
schema: ToolSchema,
auth_token_getters: dict[str, Callable[[], str]],
all_bound_params: Mapping[str, Union[Callable[[], Any], Any]],
client_headers: Mapping[str, Union[Callable, Coroutine, str]],
) -> ToolboxTool:
"""Internal helper to create a callable tool from its schema."""
# sort into reg, authn, and bound params
Expand Down Expand Up @@ -89,6 +93,7 @@ def __parse_tool(
required_authn_params=types.MappingProxyType(authn_params),
auth_service_token_getters=types.MappingProxyType(auth_token_getters),
bound_params=types.MappingProxyType(bound_params),
client_headers=types.MappingProxyType(client_headers),
)
return tool

Expand Down Expand Up @@ -144,18 +149,21 @@ async def load_tool(
bound_params: A mapping of parameter names to bind to specific values or
callables that are called to produce values as needed.



Returns:
ToolboxTool: A callable object representing the loaded tool, ready
for execution. The specific arguments and behavior of the callable
depend on the tool itself.

"""
# Resolve client headers
resolved_headers = {
name: await resolve_value(val)
for name, val in self.__client_headers.items()
}

# request the definition of the tool from the server
url = f"{self.__base_url}/api/tool/{name}"
async with self.__session.get(url) as response:
async with self.__session.get(url, headers=resolved_headers) as response:
json = await response.json()
manifest: ManifestSchema = ManifestSchema(**json)

Expand All @@ -164,7 +172,11 @@ async def load_tool(
# TODO: Better exception
raise Exception(f"Tool '{name}' not found!")
tool = self.__parse_tool(
name, manifest.tools[name], auth_token_getters, bound_params
name,
manifest.tools[name],
auth_token_getters,
bound_params,
self.__client_headers,
)

return tool
Expand All @@ -185,21 +197,50 @@ async def load_toolset(
bound_params: A mapping of parameter names to bind to specific values or
callables that are called to produce values as needed.



Returns:
list[ToolboxTool]: A list of callables, one for each tool defined
in the toolset.
"""
# Resolve client headers
original_headers = self.__client_headers
resolved_headers = {
header_name: await resolve_value(original_headers[header_name])
for header_name in original_headers
}
# Request the definition of the tool from the server
url = f"{self.__base_url}/api/toolset/{name or ''}"
async with self.__session.get(url) as response:
async with self.__session.get(url, headers=resolved_headers) as response:
json = await response.json()
manifest: ManifestSchema = ManifestSchema(**json)

# parse each tools name and schema into a list of ToolboxTools
tools = [
self.__parse_tool(n, s, auth_token_getters, bound_params)
self.__parse_tool(
n, s, auth_token_getters, bound_params, self.__client_headers
)
for n, s in manifest.tools.items()
]
return tools

async def add_headers(
self, headers: Mapping[str, Union[Callable, Coroutine, str]]
) -> None:
"""
Asynchronously Add headers to be included in each request sent through this client.

Args:
headers: Headers to include in each request sent through this client.

Raises:
ValueError: If any of the headers are already registered in the client.
"""
existing_headers = self.__client_headers.keys()
incoming_headers = headers.keys()
duplicates = existing_headers & incoming_headers
if duplicates:
raise ValueError(
f"Client header(s) `{', '.join(duplicates)}` already registered in the client."
)

merged_headers = {**self.__client_headers, **headers}
self.__client_headers = merged_headers
51 changes: 41 additions & 10 deletions packages/toolbox-core/src/toolbox_core/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,7 @@

import types
from inspect import Signature
from typing import (
Any,
Callable,
Mapping,
Optional,
Sequence,
Union,
)
from typing import Any, Callable, Coroutine, Mapping, Optional, Sequence, Union

from aiohttp import ClientSession

Expand Down Expand Up @@ -58,6 +51,7 @@ def __init__(
required_authn_params: Mapping[str, list[str]],
auth_service_token_getters: Mapping[str, Callable[[], str]],
bound_params: Mapping[str, Union[Callable[[], Any], Any]],
client_headers: Mapping[str, Union[Callable, Coroutine, str]],
):
"""
Initializes a callable that will trigger the tool invocation through the
Expand All @@ -75,6 +69,7 @@ def __init__(
produce a token)
bound_params: A mapping of parameter names to bind to specific values or
callables that are called to produce values as needed.
client_headers: Client specific headers bound to the tool.
"""
# used to invoke the toolbox API
self.__session: ClientSession = session
Expand All @@ -96,12 +91,27 @@ def __init__(
self.__annotations__ = {p.name: p.annotation for p in inspect_type_params}
self.__qualname__ = f"{self.__class__.__qualname__}.{self.__name__}"

# Validate conflicting Headers/Auth Tokens
request_header_names = client_headers.keys()
auth_token_names = [
auth_token_name + "_token"
for auth_token_name in auth_service_token_getters.keys()
]
duplicates = request_header_names & auth_token_names
if duplicates:
raise ValueError(
f"Client header(s) `{', '.join(duplicates)}` already registered in client. "
f"Cannot register client the same headers in the client as well as tool."
)

# map of parameter name to auth service required by it
self.__required_authn_params = required_authn_params
# map of authService -> token_getter
self.__auth_service_token_getters = auth_service_token_getters
# map of parameter name to value (or callable that produces that value)
self.__bound_parameters = bound_params
# map of client headers to their value/callable/coroutine
self.__client_headers = client_headers

def __copy(
self,
Expand All @@ -113,6 +123,7 @@ def __copy(
required_authn_params: Optional[Mapping[str, list[str]]] = None,
auth_service_token_getters: Optional[Mapping[str, Callable[[], str]]] = None,
bound_params: Optional[Mapping[str, Union[Callable[[], Any], Any]]] = None,
client_headers: Optional[Mapping[str, Union[Callable, Coroutine, str]]] = None,
) -> "ToolboxTool":
"""
Creates a copy of the ToolboxTool, overriding specific fields.
Expand All @@ -129,7 +140,7 @@ def __copy(
that produce a token)
bound_params: A mapping of parameter names to bind to specific values or
callables that are called to produce values as needed.

client_headers: Client specific headers bound to the tool.
"""
check = lambda val, default: val if val is not None else default
return ToolboxTool(
Expand All @@ -145,6 +156,7 @@ def __copy(
auth_service_token_getters, self.__auth_service_token_getters
),
bound_params=check(bound_params, self.__bound_parameters),
client_headers=check(client_headers, self.__client_headers),
)

async def __call__(self, *args: Any, **kwargs: Any) -> str:
Expand All @@ -169,7 +181,8 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:
for s in self.__required_authn_params.values():
req_auth_services.update(s)
raise Exception(
f"One or more of the following authn services are required to invoke this tool: {','.join(req_auth_services)}"
f"One or more of the following authn services are required to invoke this tool"
f": {','.join(req_auth_services)}"
)

# validate inputs to this call using the signature
Expand All @@ -188,6 +201,8 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:
headers = {}
for auth_service, token_getter in self.__auth_service_token_getters.items():
headers[f"{auth_service}_token"] = await resolve_value(token_getter)
for client_header_name, client_header_val in self.__client_headers.items():
headers[client_header_name] = await resolve_value(client_header_val)

async with self.__session.post(
self.__url,
Expand Down Expand Up @@ -215,6 +230,10 @@ def add_auth_token_getters(
Returns:
A new ToolboxTool instance with the specified authentication token
getters registered.

Raises
ValueError: If the auth source has already been registered either
to the tool or to the corresponding client.
"""

# throw an error if the authentication source is already registered
Expand All @@ -226,6 +245,18 @@ def add_auth_token_getters(
f"Authentication source(s) `{', '.join(duplicates)}` already registered in tool `{self.__name__}`."
)

# Validate duplicates with client headers
request_header_names = self.__client_headers.keys()
auth_token_names = [
auth_token_name + "_token" for auth_token_name in incoming_services
]
duplicates = request_header_names & auth_token_names
if duplicates:
raise ValueError(
f"Client header(s) `{', '.join(duplicates)}` already registered in client. "
f"Cannot register client the same headers in the client as well as tool."
)

# create a read-only updated value for new_getters
new_getters = types.MappingProxyType(
dict(self.__auth_service_token_getters, **auth_token_getters)
Expand Down
Loading