Skip to content

feat(toolbox-core): add support for bound parameters #120

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 23 additions & 7 deletions packages/toolbox-core/src/toolbox_core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
import re
import types
from typing import Any, Callable, Optional
from typing import Any, Callable, Mapping, Optional, Union

from aiohttp import ClientSession

Expand Down Expand Up @@ -59,18 +59,22 @@ def __parse_tool(
name: str,
schema: ToolSchema,
auth_token_getters: dict[str, Callable[[], str]],
all_bound_params: Mapping[str, Union[Callable[[], Any], Any]],
) -> ToolboxTool:
"""Internal helper to create a callable tool from its schema."""
# sort into authenticated and reg params
# sort into reg, authn, and bound params
params = []
authn_params: dict[str, list[str]] = {}
bound_params: dict[str, Callable[[], str]] = {}
auth_sources: set[str] = set()
for p in schema.parameters:
if not p.authSources:
params.append(p)
else:
if p.authSources: # authn parameter
authn_params[p.name] = p.authSources
auth_sources.update(p.authSources)
elif p.name in all_bound_params: # bound parameter
bound_params[p.name] = all_bound_params[p.name]
else: # regular parameter
params.append(p)

authn_params = identify_required_authn_params(
authn_params, auth_token_getters.keys()
Expand All @@ -85,6 +89,7 @@ def __parse_tool(
# create a read-only values for the maps to prevent mutation
required_authn_params=types.MappingProxyType(authn_params),
auth_service_token_getters=types.MappingProxyType(auth_token_getters),
bound_params=types.MappingProxyType(bound_params),
)
return tool

Expand Down Expand Up @@ -124,6 +129,7 @@ async def load_tool(
self,
name: str,
auth_token_getters: dict[str, Callable[[], str]] = {},
bound_params: Mapping[str, Union[Callable[[], Any], Any]] = {},
) -> ToolboxTool:
"""
Asynchronously loads a tool from the server.
Expand All @@ -136,6 +142,10 @@ async def load_tool(
name: The unique name or identifier of the tool to load.
auth_token_getters: A mapping of authentication service names to
callables that return the corresponding authentication token.
bound_params: A mapping of parameter names to bind to specific values or
callables that are called to produce values as needed.



Returns:
ToolboxTool: A callable object representing the loaded tool, ready
Expand All @@ -154,14 +164,17 @@ async def load_tool(
if name not in manifest.tools:
# TODO: Better exception
raise Exception(f"Tool '{name}' not found!")
tool = self.__parse_tool(name, manifest.tools[name], auth_token_getters)
tool = self.__parse_tool(
name, manifest.tools[name], auth_token_getters, bound_params
)

return tool

async def load_toolset(
self,
name: str,
auth_token_getters: dict[str, Callable[[], str]] = {},
bound_params: Mapping[str, Union[Callable[[], Any], Any]] = {},
) -> list[ToolboxTool]:
"""
Asynchronously fetches a toolset and loads all tools defined within it.
Expand All @@ -170,6 +183,9 @@ async def load_toolset(
name: Name of the toolset to load tools.
auth_token_getters: A mapping of authentication service names to
callables that return the corresponding authentication token.
bound_params: A mapping of parameter names to bind to specific values or
callables that are called to produce values as needed.



Returns:
Expand All @@ -184,7 +200,7 @@ async def load_toolset(

# parse each tools name and schema into a list of ToolboxTools
tools = [
self.__parse_tool(n, s, auth_token_getters)
self.__parse_tool(n, s, auth_token_getters, bound_params)
for n, s in manifest.tools.items()
]
return tools
79 changes: 69 additions & 10 deletions packages/toolbox-core/src/toolbox_core/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,20 @@
# limitations under the License.


import asyncio
import types
from collections import defaultdict
from inspect import Parameter, Signature
from typing import Any, Callable, DefaultDict, Iterable, Mapping, Optional, Sequence
from typing import (
Any,
Callable,
DefaultDict,
Iterable,
Mapping,
Optional,
Sequence,
Union,
)

from aiohttp import ClientSession
from pytest import Session
Expand Down Expand Up @@ -44,6 +54,7 @@ def __init__(
params: Sequence[Parameter],
required_authn_params: Mapping[str, list[str]],
auth_service_token_getters: Mapping[str, Callable[[], str]],
bound_params: Mapping[str, Union[Callable[[], Any], Any]],
):
"""
Initializes a callable that will trigger the tool invocation through the
Expand All @@ -60,6 +71,9 @@ def __init__(
of services that provide values for them.
auth_service_token_getters: A dict of authService -> token (or callables that
produce a token)
bound_params: A mapping of parameter names to bind to specific values or
callables that are called to produce values as needed.

"""

# used to invoke the toolbox API
Expand All @@ -81,6 +95,8 @@ def __init__(
self.__required_authn_params = required_authn_params
# map of authService -> token_getter
self.__auth_service_token_getters = auth_service_token_getters
# map of parameter name to value (or callable that produces that value)
self.__bound_parameters = bound_params

def __copy(
self,
Expand All @@ -91,6 +107,7 @@ def __copy(
params: Optional[list[Parameter]] = None,
required_authn_params: Optional[Mapping[str, list[str]]] = None,
auth_service_token_getters: Optional[Mapping[str, Callable[[], str]]] = None,
bound_params: Optional[Mapping[str, Union[Callable[[], Any], Any]]] = None,
) -> "ToolboxTool":
"""
Creates a copy of the ToolboxTool, overriding specific fields.
Expand All @@ -106,6 +123,8 @@ def __copy(
a auth_service_token_getter set for them yet.
auth_service_token_getters: A dict of authService -> token (or callables
that produce a token)
bound_params: A mapping of parameter names to bind to specific values or
callables that are called to produce values as needed.

"""
check = lambda val, default: val if val is not None else default
Expand All @@ -121,6 +140,7 @@ def __copy(
auth_service_token_getters=check(
auth_service_token_getters, self.__auth_service_token_getters
),
bound_params=check(bound_params, self.__bound_parameters),
)

async def __call__(self, *args: Any, **kwargs: Any) -> str:
Expand Down Expand Up @@ -153,6 +173,14 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:
all_args.apply_defaults() # Include default values if not provided
payload = all_args.arguments

# apply bounded parameters
for param, value in self.__bound_parameters.items():
if asyncio.iscoroutinefunction(value):
value = await value()
elif callable(value):
value = value()
payload[param] = value

# create headers for auth services
headers = {}
for auth_service, token_getter in self.__auth_service_token_getters.items():
Expand Down Expand Up @@ -211,23 +239,54 @@ def add_auth_token_getters(
required_authn_params=new_req_authn_params,
)

def bind_parameters(
self, bound_params: Mapping[str, Union[Callable[[], Any], Any]]
) -> "ToolboxTool":
"""
Binds parameters to values or callables that produce values.

Args:
bound_params: A mapping of parameter names to values or callables that
produce values.

Returns:
A new ToolboxTool instance with the specified parameters bound.
"""
param_names = set(p.name for p in self.__params)
for name in bound_params.keys():
if name not in param_names:
raise Exception(f"unable to bind parameters: no parameter named {name}")

new_params = []
for p in self.__params:
if p.name not in bound_params:
new_params.append(p)

all_bound_params = dict(self.__bound_parameters)
all_bound_params.update(bound_params)

return self.__copy(
params=new_params,
bound_params=types.MappingProxyType(all_bound_params),
)


def identify_required_authn_params(
req_authn_params: Mapping[str, list[str]], auth_service_names: Iterable[str]
) -> dict[str, list[str]]:
"""
Identifies authentication parameters that are still required; or not covered by
the provided `auth_service_names`.
Identifies authentication parameters that are still required; because they
not covered by the provided `auth_service_names`.

Args:
req_authn_params: A mapping of parameter names to sets of required
authentication services.
auth_service_names: An iterable of authentication service names for which
token getters are available.
Args:
req_authn_params: A mapping of parameter names to sets of required
authentication services.
auth_service_names: An iterable of authentication service names for which
token getters are available.

Returns:
A new dictionary representing the subset of required authentication
parameters that are not covered by the provided `auth_service_names`.
A new dictionary representing the subset of required authentication parameters
that are not covered by the provided `auth_services`.
"""
required_params = {} # params that are still required with provided auth_services
for param, services in req_authn_params.items():
Expand Down
Loading