diff --git a/packages/toolbox-core/src/toolbox_core/client.py b/packages/toolbox-core/src/toolbox_core/client.py index 9b69f6d1..ecf67f0c 100644 --- a/packages/toolbox-core/src/toolbox_core/client.py +++ b/packages/toolbox-core/src/toolbox_core/client.py @@ -13,7 +13,7 @@ # limitations under the License. import re import types -from typing import Any, Callable, Optional +from typing import Any, Callable, Mapping, Optional, Union from aiohttp import ClientSession @@ -59,18 +59,22 @@ def __parse_tool( name: str, schema: ToolSchema, auth_token_getters: dict[str, Callable[[], str]], + all_bound_params: Mapping[str, Union[Callable[[], Any], Any]], ) -> ToolboxTool: """Internal helper to create a callable tool from its schema.""" - # sort into authenticated and reg params + # sort into reg, authn, and bound params params = [] authn_params: dict[str, list[str]] = {} + bound_params: dict[str, Callable[[], str]] = {} auth_sources: set[str] = set() for p in schema.parameters: - if not p.authSources: - params.append(p) - else: + if p.authSources: # authn parameter authn_params[p.name] = p.authSources auth_sources.update(p.authSources) + elif p.name in all_bound_params: # bound parameter + bound_params[p.name] = all_bound_params[p.name] + else: # regular parameter + params.append(p) authn_params = identify_required_authn_params( authn_params, auth_token_getters.keys() @@ -85,6 +89,7 @@ def __parse_tool( # create a read-only values for the maps to prevent mutation required_authn_params=types.MappingProxyType(authn_params), auth_service_token_getters=types.MappingProxyType(auth_token_getters), + bound_params=types.MappingProxyType(bound_params), ) return tool @@ -124,6 +129,7 @@ async def load_tool( self, name: str, auth_token_getters: dict[str, Callable[[], str]] = {}, + bound_params: Mapping[str, Union[Callable[[], Any], Any]] = {}, ) -> ToolboxTool: """ Asynchronously loads a tool from the server. @@ -136,6 +142,10 @@ async def load_tool( name: The unique name or identifier of the tool to load. auth_token_getters: A mapping of authentication service names to callables that return the corresponding authentication token. + bound_params: A mapping of parameter names to bind to specific values or + callables that are called to produce values as needed. + + Returns: ToolboxTool: A callable object representing the loaded tool, ready @@ -154,7 +164,9 @@ async def load_tool( if name not in manifest.tools: # TODO: Better exception raise Exception(f"Tool '{name}' not found!") - tool = self.__parse_tool(name, manifest.tools[name], auth_token_getters) + tool = self.__parse_tool( + name, manifest.tools[name], auth_token_getters, bound_params + ) return tool @@ -162,6 +174,7 @@ async def load_toolset( self, name: str, auth_token_getters: dict[str, Callable[[], str]] = {}, + bound_params: Mapping[str, Union[Callable[[], Any], Any]] = {}, ) -> list[ToolboxTool]: """ Asynchronously fetches a toolset and loads all tools defined within it. @@ -170,6 +183,9 @@ async def load_toolset( name: Name of the toolset to load tools. auth_token_getters: A mapping of authentication service names to callables that return the corresponding authentication token. + bound_params: A mapping of parameter names to bind to specific values or + callables that are called to produce values as needed. + Returns: @@ -184,7 +200,7 @@ async def load_toolset( # parse each tools name and schema into a list of ToolboxTools tools = [ - self.__parse_tool(n, s, auth_token_getters) + self.__parse_tool(n, s, auth_token_getters, bound_params) for n, s in manifest.tools.items() ] return tools diff --git a/packages/toolbox-core/src/toolbox_core/tool.py b/packages/toolbox-core/src/toolbox_core/tool.py index 6421cd99..49a2a70a 100644 --- a/packages/toolbox-core/src/toolbox_core/tool.py +++ b/packages/toolbox-core/src/toolbox_core/tool.py @@ -13,10 +13,20 @@ # limitations under the License. +import asyncio import types from collections import defaultdict from inspect import Parameter, Signature -from typing import Any, Callable, DefaultDict, Iterable, Mapping, Optional, Sequence +from typing import ( + Any, + Callable, + DefaultDict, + Iterable, + Mapping, + Optional, + Sequence, + Union, +) from aiohttp import ClientSession from pytest import Session @@ -44,6 +54,7 @@ def __init__( params: Sequence[Parameter], required_authn_params: Mapping[str, list[str]], auth_service_token_getters: Mapping[str, Callable[[], str]], + bound_params: Mapping[str, Union[Callable[[], Any], Any]], ): """ Initializes a callable that will trigger the tool invocation through the @@ -60,6 +71,9 @@ def __init__( of services that provide values for them. auth_service_token_getters: A dict of authService -> token (or callables that produce a token) + bound_params: A mapping of parameter names to bind to specific values or + callables that are called to produce values as needed. + """ # used to invoke the toolbox API @@ -81,6 +95,8 @@ def __init__( self.__required_authn_params = required_authn_params # map of authService -> token_getter self.__auth_service_token_getters = auth_service_token_getters + # map of parameter name to value (or callable that produces that value) + self.__bound_parameters = bound_params def __copy( self, @@ -91,6 +107,7 @@ def __copy( params: Optional[list[Parameter]] = None, required_authn_params: Optional[Mapping[str, list[str]]] = None, auth_service_token_getters: Optional[Mapping[str, Callable[[], str]]] = None, + bound_params: Optional[Mapping[str, Union[Callable[[], Any], Any]]] = None, ) -> "ToolboxTool": """ Creates a copy of the ToolboxTool, overriding specific fields. @@ -106,6 +123,8 @@ def __copy( a auth_service_token_getter set for them yet. auth_service_token_getters: A dict of authService -> token (or callables that produce a token) + bound_params: A mapping of parameter names to bind to specific values or + callables that are called to produce values as needed. """ check = lambda val, default: val if val is not None else default @@ -121,6 +140,7 @@ def __copy( auth_service_token_getters=check( auth_service_token_getters, self.__auth_service_token_getters ), + bound_params=check(bound_params, self.__bound_parameters), ) async def __call__(self, *args: Any, **kwargs: Any) -> str: @@ -153,6 +173,14 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str: all_args.apply_defaults() # Include default values if not provided payload = all_args.arguments + # apply bounded parameters + for param, value in self.__bound_parameters.items(): + if asyncio.iscoroutinefunction(value): + value = await value() + elif callable(value): + value = value() + payload[param] = value + # create headers for auth services headers = {} for auth_service, token_getter in self.__auth_service_token_getters.items(): @@ -211,23 +239,54 @@ def add_auth_token_getters( required_authn_params=new_req_authn_params, ) + def bind_parameters( + self, bound_params: Mapping[str, Union[Callable[[], Any], Any]] + ) -> "ToolboxTool": + """ + Binds parameters to values or callables that produce values. + + Args: + bound_params: A mapping of parameter names to values or callables that + produce values. + + Returns: + A new ToolboxTool instance with the specified parameters bound. + """ + param_names = set(p.name for p in self.__params) + for name in bound_params.keys(): + if name not in param_names: + raise Exception(f"unable to bind parameters: no parameter named {name}") + + new_params = [] + for p in self.__params: + if p.name not in bound_params: + new_params.append(p) + + all_bound_params = dict(self.__bound_parameters) + all_bound_params.update(bound_params) + + return self.__copy( + params=new_params, + bound_params=types.MappingProxyType(all_bound_params), + ) + def identify_required_authn_params( req_authn_params: Mapping[str, list[str]], auth_service_names: Iterable[str] ) -> dict[str, list[str]]: """ - Identifies authentication parameters that are still required; or not covered by - the provided `auth_service_names`. + Identifies authentication parameters that are still required; because they + not covered by the provided `auth_service_names`. - Args: - req_authn_params: A mapping of parameter names to sets of required - authentication services. - auth_service_names: An iterable of authentication service names for which - token getters are available. + Args: + req_authn_params: A mapping of parameter names to sets of required + authentication services. + auth_service_names: An iterable of authentication service names for which + token getters are available. Returns: - A new dictionary representing the subset of required authentication - parameters that are not covered by the provided `auth_service_names`. + A new dictionary representing the subset of required authentication parameters + that are not covered by the provided `auth_services`. """ required_params = {} # params that are still required with provided auth_services for param, services in req_authn_params.items(): diff --git a/packages/toolbox-core/tests/test_client.py b/packages/toolbox-core/tests/test_client.py index 91ea2259..e25573e2 100644 --- a/packages/toolbox-core/tests/test_client.py +++ b/packages/toolbox-core/tests/test_client.py @@ -14,6 +14,7 @@ import inspect +import json import pytest import pytest_asyncio @@ -100,6 +101,31 @@ async def test_load_tool_success(aioresponses, test_tool_str): assert await loaded_tool("some value") == "ok" +@pytest.mark.asyncio +async def test_load_toolset_success(aioresponses, test_tool_str, test_tool_int_bool): + """Tests successfully loading a toolset with multiple tools.""" + TOOLSET_NAME = "my_toolset" + TOOL1 = "tool1" + TOOL2 = "tool2" + manifest = ManifestSchema( + serverVersion="0.0.0", tools={TOOL1: test_tool_str, TOOL2: test_tool_int_bool} + ) + aioresponses.get( + f"{TEST_BASE_URL}/api/toolset/{TOOLSET_NAME}", + payload=manifest.model_dump(), + status=200, + ) + + async with ToolboxClient(TEST_BASE_URL) as client: + tools = await client.load_toolset(TOOLSET_NAME) + + assert isinstance(tools, list) + assert len(tools) == len(manifest.tools) + + # Check if tools were created correctly + assert {t.__name__ for t in tools} == manifest.tools.keys() + + class TestAuth: @pytest.fixture @@ -181,26 +207,93 @@ def token_handler(): res = await tool(5) -@pytest.mark.asyncio -async def test_load_toolset_success(aioresponses, test_tool_str, test_tool_int_bool): - """Tests successfully loading a toolset with multiple tools.""" - TOOLSET_NAME = "my_toolset" - TOOL1 = "tool1" - TOOL2 = "tool2" - manifest = ManifestSchema( - serverVersion="0.0.0", tools={TOOL1: test_tool_str, TOOL2: test_tool_int_bool} - ) - aioresponses.get( - f"{TEST_BASE_URL}/api/toolset/{TOOLSET_NAME}", - payload=manifest.model_dump(), - status=200, - ) +class TestBoundParameter: - async with ToolboxClient(TEST_BASE_URL) as client: - tools = await client.load_toolset(TOOLSET_NAME) + @pytest.fixture + def tool_name(self): + return "tool1" - assert isinstance(tools, list) - assert len(tools) == len(manifest.tools) + @pytest_asyncio.fixture + async def client(self, aioresponses, test_tool_int_bool, tool_name): + manifest = ManifestSchema( + serverVersion="0.0.0", tools={tool_name: test_tool_int_bool} + ) - # Check if tools were created correctly - assert {t.__name__ for t in tools} == manifest.tools.keys() + # mock toolset GET call + aioresponses.get( + f"{TEST_BASE_URL}/api/toolset/", + payload=manifest.model_dump(), + status=200, + ) + + # mock tool GET call + aioresponses.get( + f"{TEST_BASE_URL}/api/tool/{tool_name}", + payload=manifest.model_dump(), + status=200, + ) + + # mock tool INVOKE call + def reflect_parameters(url, **kwargs): + body = {"result": kwargs["json"]} + return CallbackResult(status=200, body=json.dumps(body)) + + aioresponses.post( + f"{TEST_BASE_URL}/api/tool/{tool_name}/invoke", + payload=manifest.model_dump(), + callback=reflect_parameters, + status=200, + ) + + async with ToolboxClient(TEST_BASE_URL) as client: + yield client + + @pytest.mark.asyncio + async def test_load_tool_success(self, tool_name, client): + """Tests 'load_tool' with a bound parameter specified.""" + tool = await client.load_tool(tool_name, bound_params={"argA": lambda: 5}) + + assert len(tool.__signature__.parameters) == 1 + assert "argA" not in tool.__signature__.parameters + + res = await tool(True) + assert "argA" in res + + @pytest.mark.asyncio + async def test_load_toolset_success(self, tool_name, client): + """Tests 'load_toolset' with a bound parameter specified.""" + tools = await client.load_toolset("", bound_params={"argB": lambda: "hello"}) + tool = tools[0] + + assert len(tool.__signature__.parameters) == 1 + assert "argB" not in tool.__signature__.parameters + + res = await tool(True) + assert "argB" in res + + @pytest.mark.asyncio + async def test_bind_param_success(self, tool_name, client): + """Tests 'bind_param' with a bound parameter specified.""" + tool = await client.load_tool(tool_name) + + assert len(tool.__signature__.parameters) == 2 + assert "argA" in tool.__signature__.parameters + + tool = tool.bind_parameters({"argA": lambda: 5}) + + assert len(tool.__signature__.parameters) == 1 + assert "argA" not in tool.__signature__.parameters + + res = await tool(True) + assert "argA" in res + + @pytest.mark.asyncio + async def test_bind_param_fail(self, tool_name, client): + """Tests 'bind_param' with a bound parameter that doesn't exist.""" + tool = await client.load_tool(tool_name) + + assert len(tool.__signature__.parameters) == 2 + assert "argA" in tool.__signature__.parameters + + with pytest.raises(Exception): + tool = tool.bind_parameters({"argC": lambda: 5})