Skip to content

feat: Wrap toolbox-langchain's AsyncToolboxTool over toolbox-core's ToolboxTool. #181

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 41 additions & 37 deletions packages/toolbox-core/src/toolbox_core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,12 @@
# limitations under the License.


import types
from typing import Any, Callable, Mapping, Optional, Union
from typing import Any, Callable, Mapping, Optional, Sequence, Union

from aiohttp import ClientSession

from .protocol import ManifestSchema, ToolSchema
from .tool import ToolboxTool, identify_required_authn_params
from .protocol import ManifestSchema, ParameterSchema, ToolSchema
from .tool import ToolboxTool


class ToolboxClient:
Expand Down Expand Up @@ -59,24 +58,26 @@ def __parse_tool(
self,
name: str,
schema: ToolSchema,
auth_token_getters: dict[str, Callable[[], str]],
auth_token_getters: Mapping[str, Callable[[], str]],
all_bound_params: Mapping[str, Union[Callable[[], Any], Any]],
strict: bool,
) -> ToolboxTool:
"""Internal helper to create a callable tool from its schema."""
# sort into reg, authn, and bound params
params = []
authn_params: dict[str, list[str]] = {}
bound_params: dict[str, Callable[[], str]] = {}
for p in schema.parameters:
if p.authSources: # authn parameter
authn_params[p.name] = p.authSources
elif p.name in all_bound_params: # bound parameter
bound_params[p.name] = all_bound_params[p.name]
else: # regular parameter
params.append(p)

authn_params = identify_required_authn_params(
authn_params, auth_token_getters.keys()
"""
Internal helper to create a callable ToolboxTool from its schema.

Args:
name: The name of the tool.
schema: The ToolSchema defining the tool.
auth_token_getters: Mapping of auth service names to token getters.
all_bound_params: Mapping of all initially bound parameter names to values/callables.
strict: The strictness setting for the created ToolboxTool instance.

Returns:
An initialized ToolboxTool instance.
"""

params: Sequence[ParameterSchema] = (
schema.parameters if schema.parameters is not None else []
)

tool = ToolboxTool(
Expand All @@ -85,10 +86,9 @@ def __parse_tool(
name=name,
description=schema.description,
params=params,
# create a read-only values for the maps to prevent mutation
required_authn_params=types.MappingProxyType(authn_params),
auth_service_token_getters=types.MappingProxyType(auth_token_getters),
bound_params=types.MappingProxyType(bound_params),
auth_service_token_getters=auth_token_getters,
bound_params=all_bound_params,
strict=strict,
)
return tool

Expand Down Expand Up @@ -127,8 +127,9 @@ async def close(self):
async def load_tool(
self,
name: str,
auth_token_getters: dict[str, Callable[[], str]] = {},
auth_token_getters: Mapping[str, Callable[[], str]] = {},
bound_params: Mapping[str, Union[Callable[[], Any], Any]] = {},
strict: bool = True,
) -> ToolboxTool:
"""
Asynchronously loads a tool from the server.
Expand All @@ -143,8 +144,8 @@ async def load_tool(
callables that return the corresponding authentication token.
bound_params: A mapping of parameter names to bind to specific values or
callables that are called to produce values as needed.


strict: If True (default), the loaded tool instance will operate in
strict validation mode. If False, it will be non-strict.

Returns:
ToolboxTool: A callable object representing the loaded tool, ready
Expand All @@ -161,35 +162,38 @@ async def load_tool(

# parse the provided definition to a tool
if name not in manifest.tools:
# TODO: Better exception
raise Exception(f"Tool '{name}' not found!")
raise Exception(
f"Tool '{name}' not found in the manifest received from {url}"
)
tool = self.__parse_tool(
name, manifest.tools[name], auth_token_getters, bound_params
name, manifest.tools[name], auth_token_getters, bound_params, strict
)

return tool

async def load_toolset(
self,
name: Optional[str] = None,
auth_token_getters: dict[str, Callable[[], str]] = {},
auth_token_getters: Mapping[str, Callable[[], str]] = {},
bound_params: Mapping[str, Union[Callable[[], Any], Any]] = {},
strict: bool = True,
) -> list[ToolboxTool]:
"""
Asynchronously fetches a toolset and loads all tools defined within it.

Args:
name: Name of the toolset to load tools.
name: Optional name of the toolset to load. If None, attempts to load
the default toolset.
auth_token_getters: A mapping of authentication service names to
callables that return the corresponding authentication token.
bound_params: A mapping of parameter names to bind to specific values or
callables that are called to produce values as needed.


strict: If True (default), all loaded tool instances will operate in
strict validation mode. If False, they will be non-strict.

Returns:
list[ToolboxTool]: A list of callables, one for each tool defined
in the toolset.
list[ToolboxTool]: A list of callables, one for each tool defined in
the toolset.
"""
# Request the definition of the tool from the server
url = f"{self.__base_url}/api/toolset/{name or ''}"
Expand All @@ -199,7 +203,7 @@ async def load_toolset(

# parse each tools name and schema into a list of ToolboxTools
tools = [
self.__parse_tool(n, s, auth_token_getters, bound_params)
self.__parse_tool(n, s, auth_token_getters, bound_params, strict)
for n, s in manifest.tools.items()
]
return tools
Loading
Loading