From 5f1b2b0bc6def548856fa8763a27c9e21c65dda2 Mon Sep 17 00:00:00 2001 From: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com> Date: Sun, 30 Mar 2025 17:04:30 +0000 Subject: [PATCH 01/50] feat: add authenticated parameters support --- .../toolbox-core/src/toolbox_core/client.py | 39 ++++- .../toolbox-core/src/toolbox_core/tool.py | 160 ++++++++++++++++-- packages/toolbox-core/tests/test_client.py | 105 +++++++++++- 3 files changed, 279 insertions(+), 25 deletions(-) diff --git a/packages/toolbox-core/src/toolbox_core/client.py b/packages/toolbox-core/src/toolbox_core/client.py index dc59e440..0fcca0d2 100644 --- a/packages/toolbox-core/src/toolbox_core/client.py +++ b/packages/toolbox-core/src/toolbox_core/client.py @@ -11,13 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from typing import Optional +import types +from typing import Any, Callable, Optional from aiohttp import ClientSession from .protocol import ManifestSchema, ToolSchema -from .tool import ToolboxTool +from .tool import ToolboxTool, filter_required_authn_params class ToolboxClient: @@ -53,14 +53,34 @@ def __init__( session = ClientSession() self.__session = session - def __parse_tool(self, name: str, schema: ToolSchema) -> ToolboxTool: + def __parse_tool( + self, + name: str, + schema: ToolSchema, + auth_token_getters: dict[str, Callable[[], str]], + ) -> ToolboxTool: """Internal helper to create a callable tool from its schema.""" + # sort into authenticated and reg params + params = [] + authn_params: dict[str, list[str]] = {} + auth_sources: set[str] = set() + for p in schema.parameters: + if not p.authSources: + params.append(p) + else: + authn_params[p.name] = p.authSources + auth_sources.update(p.authSources) + + authn_params = filter_required_authn_params(authn_params, auth_sources) + tool = ToolboxTool( session=self.__session, base_url=self.__base_url, name=name, desc=schema.description, - params=[p.to_param() for p in schema.parameters], + params=[p.to_param() for p in params], + required_authn_params=types.MappingProxyType(authn_params), + auth_service_token_getters=auth_token_getters, ) return tool @@ -99,6 +119,7 @@ async def close(self): async def load_tool( self, name: str, + auth_service_tokens: dict[str, Callable[[], str]] = {}, ) -> ToolboxTool: """ Asynchronously loads a tool from the server. @@ -127,13 +148,14 @@ async def load_tool( if name not in manifest.tools: # TODO: Better exception raise Exception(f"Tool '{name}' not found!") - tool = self.__parse_tool(name, manifest.tools[name]) + tool = self.__parse_tool(name, manifest.tools[name], auth_service_tokens) return tool async def load_toolset( self, name: str, + auth_token_getters: dict[str, Callable[[], str]] = {}, ) -> list[ToolboxTool]: """ Asynchronously fetches a toolset and loads all tools defined within it. @@ -152,5 +174,8 @@ async def load_toolset( manifest: ManifestSchema = ManifestSchema(**json) # parse each tools name and schema into a list of ToolboxTools - tools = [self.__parse_tool(n, s) for n, s in manifest.tools.items()] + tools = [ + self.__parse_tool(n, s, auth_token_getters) + for n, s in manifest.tools.items() + ] return tools diff --git a/packages/toolbox-core/src/toolbox_core/tool.py b/packages/toolbox-core/src/toolbox_core/tool.py index 48e4626c..494c7c21 100644 --- a/packages/toolbox-core/src/toolbox_core/tool.py +++ b/packages/toolbox-core/src/toolbox_core/tool.py @@ -13,10 +13,13 @@ # limitations under the License. +import types +from collections import defaultdict from inspect import Parameter, Signature -from typing import Any +from typing import Any, Callable, DefaultDict, Iterable, Mapping, Optional, Sequence from aiohttp import ClientSession +from pytest import Session class ToolboxTool: @@ -32,20 +35,19 @@ class ToolboxTool: and `inspect` work as expected. """ - __url: str - __session: ClientSession - __signature__: Signature - def __init__( self, session: ClientSession, base_url: str, name: str, desc: str, - params: list[Parameter], + params: Sequence[Parameter], + required_authn_params: Mapping[str, list[str]], + auth_service_token_getters: Mapping[str, Callable[[], str]], ): """ - Initializes a callable that will trigger the tool invocation through the Toolbox server. + Initializes a callable that will trigger the tool invocation through the + Toolbox server. Args: session: The `aiohttp.ClientSession` used for making API requests. @@ -54,19 +56,69 @@ def __init__( desc: The description of the remote tool (used as its docstring). params: A list of `inspect.Parameter` objects defining the tool's arguments and their types/defaults. + required_authn_params: A dict of required authenticated parameters that + need a auth_service_token_getter set for them yet. + auth_service_tokens: A dict of authService -> token (or callables that + produce a token) """ # used to invoke the toolbox API - self.__session = session + self.__session: ClientSession = session + self.__base_url: str = base_url self.__url = f"{base_url}/api/tool/{name}/invoke" - # the following properties are set to help anyone that might inspect it determine + self.__desc = desc + self.__params = params + + # the following properties are set to help anyone that might inspect it determine usage self.__name__ = name self.__doc__ = desc self.__signature__ = Signature(parameters=params, return_annotation=str) self.__annotations__ = {p.name: p.annotation for p in params} # TODO: self.__qualname__ ?? + # map of parameter name to auth service required by it + self.__required_authn_params = required_authn_params + # map of authService -> token_getter + self.__auth_service_token_getters = auth_service_token_getters + + def __copy( + self, + session: Optional[ClientSession] = None, + base_url: Optional[str] = None, + name: Optional[str] = None, + desc: Optional[str] = None, + params: Optional[list[Parameter]] = None, + required_authn_params: Optional[Mapping[str, list[str]]] = None, + auth_service_token_getters: Optional[Mapping[str, Callable[[], str]]] = None, + ): + """ + Creates a copy of the ToolboxTool, overriding specific fields. + + Args: + session: The `aiohttp.ClientSession` used for making API requests. + base_url: The base URL of the Toolbox server API. + name: The name of the remote tool. + desc: The description of the remote tool (used as its docstring). + params: A list of `inspect.Parameter` objects defining the tool's + arguments and their types/defaults. + required_authn_params: A dict of required authenticated parameters that need + a auth_service_token_getter set for them yet. + auth_service_token_getters: A dict of authService -> token (or callables + that produce a token) + + """ + return ToolboxTool( + session=session or self.__session, + base_url=base_url or self.__base_url, + name=name or self.__name__, + desc=desc or self.__desc, + params=params or self.__params, + required_authn_params=required_authn_params or self.__required_authn_params, + auth_service_token_getters=auth_service_token_getters + or self.__auth_service_token_getters, + ) + async def __call__(self, *args: Any, **kwargs: Any) -> str: """ Asynchronously calls the remote tool with the provided arguments. @@ -81,16 +133,96 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str: Returns: The string result returned by the remote tool execution. """ + + # check if any auth services need to be specified yet + if len(self.__required_authn_params) > 0: + req_auth_services = set(l for l in self.__required_authn_params.keys()) + raise Exception( + f"One of more of the following authn services are required to invoke this tool: {','.join(req_auth_services)}" + ) + + # validate inputs to this call using the signature all_args = self.__signature__.bind(*args, **kwargs) all_args.apply_defaults() # Include default values if not provided payload = all_args.arguments + # create headers for auth services + headers = {} + for auth_service, token_getter in self.__auth_service_token_getters.items(): + headers[f"{auth_service}_token"] = token_getter() + async with self.__session.post( self.__url, json=payload, + headers=headers, ) as resp: - ret = await resp.json() - if "error" in ret: - # TODO: better error - raise Exception(ret["error"]) - return ret.get("result", ret) + body = await resp.json() + if resp.status < 200 or resp.status >= 300: + err = body.get("error", f"unexpected status from server: {resp.status}") + raise Exception(err) + return body.get("result", body) + + def add_auth_token_getters( + self, + auth_token_getters: Mapping[str, Callable[[], str]], + ) -> "ToolboxTool": + """ + Registers a auth token getter function that is used for AuthServices when tools + are invoked. + + Args: + auth_token_getters: A mapping of authentication service names to + callables that return the corresponding authentication token. + + Returns: + A new ToolboxTool instance with the specified authentication token + getters registered. + """ + + # throw an error if the authentication source is already registered + dupes = auth_token_getters.keys() & self.__auth_service_token_getters.keys() + if dupes: + raise ValueError( + f"Authentication source(s) `{', '.join(dupes)}` already registered in tool `{self.__name__}`." + ) + + # create a read-only updated value for new_getters + new_getters = types.MappingProxyType( + dict(self.__auth_service_token_getters, **auth_token_getters) + ) + # create a read-only updated for params that are still required + new_req_authn_params = types.MappingProxyType( + filter_required_authn_params( + self.__required_authn_params, auth_token_getters.keys() + ) + ) + + return self.__copy( + auth_service_token_getters=new_getters, + required_authn_params=new_req_authn_params, + ) + + +def filter_required_authn_params( + req_authn_params: Mapping[str, list[str]], auth_services: Iterable[str] +) -> dict[str, list[str]]: + """ + Utility function for reducing 'req_authn_params' to a subset of parameters that aren't supplied by a least one service in auth_services. + + Args: + req_authn_params: A mapping of parameter names to sets of required + authentication services. + auth_services: An iterable of authentication service names for which + token getters are available. + + Returns: + A new dictionary representing the subset of required authentication + parameters that are not covered by the provided `auth_services`. + """ + req_params = {} + for param, services in req_authn_params.items(): + # if we don't have a token_getter for any of the services required by the param, the param is still required + required = not any(s in services for s in auth_services) + if required: + req_params[param] = services + return req_params diff --git a/packages/toolbox-core/tests/test_client.py b/packages/toolbox-core/tests/test_client.py index b19c575b..101fdb91 100644 --- a/packages/toolbox-core/tests/test_client.py +++ b/packages/toolbox-core/tests/test_client.py @@ -16,7 +16,8 @@ import inspect import pytest - +import pytest_asyncio +from aioresponses import CallbackResult from toolbox_core import ToolboxClient from toolbox_core.protocol import ManifestSchema, ParameterSchema, ToolSchema @@ -26,7 +27,7 @@ @pytest.fixture() def test_tool_str(): return ToolSchema( - description="Test Tool 1 Description", + description="Test Tool with String input", parameters=[ ParameterSchema( name="param1", type="string", description="Description of Param1" @@ -37,9 +38,8 @@ def test_tool_str(): @pytest.fixture() def test_tool_int_bool(): - """Fixture for the second test tool schema.""" return ToolSchema( - description="Test Tool 2 Description", + description="Test Tool with Int, Bool", parameters=[ ParameterSchema(name="argA", type="integer", description="Argument A"), ParameterSchema(name="argB", type="boolean", description="Argument B"), @@ -47,6 +47,22 @@ def test_tool_int_bool(): ) +@pytest.fixture() +def test_tool_auth(): + return ToolSchema( + description="Test Tool with Int,Bool+Auth", + parameters=[ + ParameterSchema(name="argA", type="integer", description="Argument A"), + ParameterSchema( + name="argB", + type="boolean", + description="Argument B", + authSources=["my-auth-service"], + ), + ], + ) + + @pytest.mark.asyncio async def test_load_tool_success(aioresponses, test_tool_str): """ @@ -83,6 +99,87 @@ async def test_load_tool_success(aioresponses, test_tool_str): assert await loaded_tool("some value") == "ok" +class TestAuth: + + @pytest.fixture + def expected_header(self): + return "some_token_for_testing" + + @pytest.fixture + def tool_name(self): + return "tool1" + + @pytest_asyncio.fixture + async def client(self, aioresponses, test_tool_auth, tool_name, expected_header): + manifest = ManifestSchema( + serverVersion="0.0.0", tools={tool_name: test_tool_auth} + ) + + # mock tool GET call + aioresponses.get( + f"{TEST_BASE_URL}/api/tool/{tool_name}", + payload=manifest.model_dump(), + status=200, + ) + + # mock tool INVOKE call + def require_headers(url, **kwargs): + if kwargs["headers"].get("my-auth-service_token") == expected_header: + return CallbackResult(status=200, body="{}") + else: + return CallbackResult(status=400, body="{}") + + aioresponses.post( + f"{TEST_BASE_URL}/api/tool/{tool_name}/invoke", + payload=manifest.model_dump(), + callback=require_headers, + status=200, + ) + + async with ToolboxClient(TEST_BASE_URL) as client: + yield client + + @pytest.mark.asyncio + async def test_auth_with_load_tool_success( + self, tool_name, expected_header, client + ): + """Tests 'load_tool' with auth token is specified.""" + + def token_handler(): + return expected_header + + tool = await client.load_tool( + tool_name, auth_service_tokens={"my-auth-service": token_handler} + ) + res = await tool(5) + + @pytest.mark.asyncio + async def test_auth_with_add_token_success( + self, tool_name, expected_header, client + ): + """Tests 'load_tool' with auth token is specified.""" + + def token_handler(): + return expected_header + + tool = await client.load_tool(tool_name) + tool = await client.add_auth_token_getters({"my-auth-service": token_handler}) + res = await tool(5) + + @pytest.mark.asyncio + async def test_auth_with_load_tool_fail_no_token( + self, tool_name, expected_header, client + ): + """Tests 'load_tool' with auth token is specified.""" + + def token_handler(): + return expected_header + + tool = await client.load_tool(tool_name) + with pytest.raises(Exception): + res = await tool(5) + + @pytest.mark.asyncio async def test_load_toolset_success(aioresponses, test_tool_str, test_tool_int_bool): """Tests successfully loading a toolset with multiple tools.""" From 34bec3745ed915c49b9177da46bf84055e85f27f Mon Sep 17 00:00:00 2001 From: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com> Date: Sun, 30 Mar 2025 17:22:03 +0000 Subject: [PATCH 02/50] chore: add asyncio dep --- packages/toolbox-core/pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/packages/toolbox-core/pyproject.toml b/packages/toolbox-core/pyproject.toml index edc45a8a..ee8a5f73 100644 --- a/packages/toolbox-core/pyproject.toml +++ b/packages/toolbox-core/pyproject.toml @@ -44,7 +44,8 @@ test = [ "isort==6.0.1", "mypy==1.15.0", "pytest==8.3.5", - "pytest-aioresponses==0.3.0" + "pytest-aioresponses==0.3.0", + "pytest-asyncio==0.25.3", ] [build-system] requires = ["setuptools"] From f0024bf0daa8fb52b8d7b5b996ddc34507873ecd Mon Sep 17 00:00:00 2001 From: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com> Date: Sun, 30 Mar 2025 17:24:23 +0000 Subject: [PATCH 03/50] chore: run itest --- packages/toolbox-core/tests/test_client.py | 1 + 1 file changed, 1 insertion(+) diff --git a/packages/toolbox-core/tests/test_client.py b/packages/toolbox-core/tests/test_client.py index 101fdb91..b6b1ff85 100644 --- a/packages/toolbox-core/tests/test_client.py +++ b/packages/toolbox-core/tests/test_client.py @@ -18,6 +18,7 @@ import pytest import pytest_asyncio from aioresponses import CallbackResult + from toolbox_core import ToolboxClient from toolbox_core.protocol import ManifestSchema, ParameterSchema, ToolSchema From a646285a60a10ed46ade4407a11713885fb5e7e6 Mon Sep 17 00:00:00 2001 From: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com> Date: Sun, 30 Mar 2025 17:26:39 +0000 Subject: [PATCH 04/50] chore: add type hint --- packages/toolbox-core/src/toolbox_core/tool.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/toolbox-core/src/toolbox_core/tool.py b/packages/toolbox-core/src/toolbox_core/tool.py index 494c7c21..17c44f63 100644 --- a/packages/toolbox-core/src/toolbox_core/tool.py +++ b/packages/toolbox-core/src/toolbox_core/tool.py @@ -91,7 +91,7 @@ def __copy( params: Optional[list[Parameter]] = None, required_authn_params: Optional[Mapping[str, list[str]]] = None, auth_service_token_getters: Optional[Mapping[str, Callable[[], str]]] = None, - ): + ) -> "ToolboxTool": """ Creates a copy of the ToolboxTool, overriding specific fields. From 0dc7034776acae31c93d92798b6f6386386c5f32 Mon Sep 17 00:00:00 2001 From: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com> Date: Sun, 30 Mar 2025 17:32:36 +0000 Subject: [PATCH 05/50] fix: call tool instead of client --- packages/toolbox-core/tests/test_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/toolbox-core/tests/test_client.py b/packages/toolbox-core/tests/test_client.py index b6b1ff85..aebf3133 100644 --- a/packages/toolbox-core/tests/test_client.py +++ b/packages/toolbox-core/tests/test_client.py @@ -164,7 +164,7 @@ def token_handler(): return expected_header tool = await client.load_tool(tool_name) - tool = await client.add_auth_token_getters({"my-auth-service": token_handler}) + tool = tool.add_auth_token_getters({"my-auth-service": token_handler}) res = await tool(5) @pytest.mark.asyncio From 61d32aadbdd35f971b0c8de28f38cc4dbcb26153 Mon Sep 17 00:00:00 2001 From: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com> Date: Mon, 31 Mar 2025 15:33:05 +0000 Subject: [PATCH 06/50] chore: correct arg name --- packages/toolbox-core/src/toolbox_core/client.py | 9 +++++++-- packages/toolbox-core/tests/test_client.py | 2 +- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/packages/toolbox-core/src/toolbox_core/client.py b/packages/toolbox-core/src/toolbox_core/client.py index 0fcca0d2..1acd521a 100644 --- a/packages/toolbox-core/src/toolbox_core/client.py +++ b/packages/toolbox-core/src/toolbox_core/client.py @@ -119,7 +119,7 @@ async def close(self): async def load_tool( self, name: str, - auth_service_tokens: dict[str, Callable[[], str]] = {}, + auth_token_getters: dict[str, Callable[[], str]] = {}, ) -> ToolboxTool: """ Asynchronously loads a tool from the server. @@ -130,6 +130,8 @@ async def load_tool( Args: name: The unique name or identifier of the tool to load. + auth_token_getters: A mapping of authentication service names to + callables that return the corresponding authentication token. Returns: ToolboxTool: A callable object representing the loaded tool, ready @@ -148,7 +150,7 @@ async def load_tool( if name not in manifest.tools: # TODO: Better exception raise Exception(f"Tool '{name}' not found!") - tool = self.__parse_tool(name, manifest.tools[name], auth_service_tokens) + tool = self.__parse_tool(name, manifest.tools[name], auth_token_getters) return tool @@ -162,6 +164,9 @@ async def load_toolset( Args: name: Name of the toolset to load tools. + auth_token_getters: A mapping of authentication service names to + callables that return the corresponding authentication token. + Returns: list[ToolboxTool]: A list of callables, one for each tool defined diff --git a/packages/toolbox-core/tests/test_client.py b/packages/toolbox-core/tests/test_client.py index aebf3133..91ea2259 100644 --- a/packages/toolbox-core/tests/test_client.py +++ b/packages/toolbox-core/tests/test_client.py @@ -150,7 +150,7 @@ def token_handler(): return expected_header tool = await client.load_tool( - tool_name, auth_service_tokens={"my-auth-service": token_handler} + tool_name, auth_token_getters={"my-auth-service": token_handler} ) res = await tool(5) From 744ade90df706987c2c30f4bbf8ecd46f08b95c2 Mon Sep 17 00:00:00 2001 From: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com> Date: Mon, 31 Mar 2025 23:02:52 +0000 Subject: [PATCH 07/50] feat: add support for bound parameters --- .../toolbox-core/src/toolbox_core/client.py | 24 +++++--- .../toolbox-core/src/toolbox_core/tool.py | 60 +++++++++++++++++-- 2 files changed, 73 insertions(+), 11 deletions(-) diff --git a/packages/toolbox-core/src/toolbox_core/client.py b/packages/toolbox-core/src/toolbox_core/client.py index 1acd521a..5afaa4db 100644 --- a/packages/toolbox-core/src/toolbox_core/client.py +++ b/packages/toolbox-core/src/toolbox_core/client.py @@ -58,18 +58,22 @@ def __parse_tool( name: str, schema: ToolSchema, auth_token_getters: dict[str, Callable[[], str]], + all_bound_params: dict[str, Callable[[], str]], ) -> ToolboxTool: """Internal helper to create a callable tool from its schema.""" - # sort into authenticated and reg params + # sort into reg, authn, and bound params params = [] authn_params: dict[str, list[str]] = {} + bound_params: dict[str, Callable[[], str]] = {} auth_sources: set[str] = set() for p in schema.parameters: - if not p.authSources: - params.append(p) - else: + if p.authSources: # authn parameter authn_params[p.name] = p.authSources auth_sources.update(p.authSources) + elif p.name in all_bound_params: # bound parameter + bound_params[p.name] = all_bound_params[p.name] + else: # regular parameter + params.append(p) authn_params = filter_required_authn_params(authn_params, auth_sources) @@ -80,7 +84,8 @@ def __parse_tool( desc=schema.description, params=[p.to_param() for p in params], required_authn_params=types.MappingProxyType(authn_params), - auth_service_token_getters=auth_token_getters, + auth_service_token_getters=types.MappingProxyType(auth_token_getters), + bound_params=types.MappingProxyType(bound_params), ) return tool @@ -120,6 +125,7 @@ async def load_tool( self, name: str, auth_token_getters: dict[str, Callable[[], str]] = {}, + bound_params: dict[str, Callable[[], str]] = {}, ) -> ToolboxTool: """ Asynchronously loads a tool from the server. @@ -150,7 +156,9 @@ async def load_tool( if name not in manifest.tools: # TODO: Better exception raise Exception(f"Tool '{name}' not found!") - tool = self.__parse_tool(name, manifest.tools[name], auth_token_getters) + tool = self.__parse_tool( + name, manifest.tools[name], auth_token_getters, bound_params + ) return tool @@ -158,6 +166,7 @@ async def load_toolset( self, name: str, auth_token_getters: dict[str, Callable[[], str]] = {}, + bound_params: dict[str, Callable[[], str]] = {}, ) -> list[ToolboxTool]: """ Asynchronously fetches a toolset and loads all tools defined within it. @@ -168,6 +177,7 @@ async def load_toolset( callables that return the corresponding authentication token. + Returns: list[ToolboxTool]: A list of callables, one for each tool defined in the toolset. @@ -180,7 +190,7 @@ async def load_toolset( # parse each tools name and schema into a list of ToolboxTools tools = [ - self.__parse_tool(n, s, auth_token_getters) + self.__parse_tool(n, s, auth_token_getters, bound_params) for n, s in manifest.tools.items() ] return tools diff --git a/packages/toolbox-core/src/toolbox_core/tool.py b/packages/toolbox-core/src/toolbox_core/tool.py index 17c44f63..2e88c179 100644 --- a/packages/toolbox-core/src/toolbox_core/tool.py +++ b/packages/toolbox-core/src/toolbox_core/tool.py @@ -13,10 +13,20 @@ # limitations under the License. +import asyncio import types from collections import defaultdict from inspect import Parameter, Signature -from typing import Any, Callable, DefaultDict, Iterable, Mapping, Optional, Sequence +from typing import ( + Any, + Callable, + DefaultDict, + Iterable, + Mapping, + Optional, + Sequence, + Union, +) from aiohttp import ClientSession from pytest import Session @@ -44,6 +54,7 @@ def __init__( params: Sequence[Parameter], required_authn_params: Mapping[str, list[str]], auth_service_token_getters: Mapping[str, Callable[[], str]], + bound_params: Mapping[str, Union[Callable[[], Any], Any]] = {}, ): """ Initializes a callable that will trigger the tool invocation through the @@ -81,6 +92,8 @@ def __init__( self.__required_authn_params = required_authn_params # map of authService -> token_getter self.__auth_service_token_getters = auth_service_token_getters + # map of parameter name to value or Callable + self.__bound_parameters = bound_params def __copy( self, @@ -91,6 +104,7 @@ def __copy( params: Optional[list[Parameter]] = None, required_authn_params: Optional[Mapping[str, list[str]]] = None, auth_service_token_getters: Optional[Mapping[str, Callable[[], str]]] = None, + bound_params: Optional[Mapping[str, Union[Callable[[], Any], Any]]] = None, ) -> "ToolboxTool": """ Creates a copy of the ToolboxTool, overriding specific fields. @@ -117,6 +131,7 @@ def __copy( required_authn_params=required_authn_params or self.__required_authn_params, auth_service_token_getters=auth_service_token_getters or self.__auth_service_token_getters, + bound_params=bound_params or self.__bound_parameters, ) async def __call__(self, *args: Any, **kwargs: Any) -> str: @@ -146,6 +161,14 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str: all_args.apply_defaults() # Include default values if not provided payload = all_args.arguments + # apply bounded parameters + for param, value in self.__bound_parameters.items(): + if asyncio.iscoroutinefunction(value): + value = await value() + elif callable(value): + value = value() + payload[param] = value + # create headers for auth services headers = {} for auth_service, token_getter in self.__auth_service_token_getters.items(): @@ -202,12 +225,41 @@ def add_auth_token_getters( required_authn_params=new_req_authn_params, ) + def bind_parameters( + self, bound_params: Mapping[str, Callable[[], str]] + ) -> "ToolboxTool": + """ + Binds parameters to values or callables that produce values. + + Args: + bound_params: A mapping of parameter names to values or callables that + produce values. + + Returns: + A new ToolboxTool instance with the specified parameters bound. + """ + all_params = set(p.name for p in self.__params) + for name in bound_params.keys(): + if name not in all_params: + raise Exception(f"unable to bind parameters: no parameter named {name}") + + new_params = [] + for p in self.__params: + if p.name not in bound_params: + new_params.append(p) + + return self.__copy( + params=new_params, + bound_params=bound_params, + ) + def filter_required_authn_params( req_authn_params: Mapping[str, list[str]], auth_services: Iterable[str] ) -> dict[str, list[str]]: """ - Utility function for reducing 'req_authn_params' to a subset of parameters that aren't supplied by a least one service in auth_services. + Utility function for reducing 'req_authn_params' to a subset of parameters that + aren't supplied by a least one service in auth_services. Args: req_authn_params: A mapping of parameter names to sets of required @@ -216,8 +268,8 @@ def filter_required_authn_params( token getters are available. Returns: - A new dictionary representing the subset of required authentication - parameters that are not covered by the provided `auth_services`. + A new dictionary representing the subset of required authentication parameters + that are not covered by the provided `auth_services`. """ req_params = {} for param, services in req_authn_params.items(): From b8757717edd35775a9112ef8457dbc3cd9d1fb61 Mon Sep 17 00:00:00 2001 From: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com> Date: Mon, 31 Mar 2025 23:03:47 +0000 Subject: [PATCH 08/50] chore: add tests for bound parameters --- packages/toolbox-core/tests/test_client.py | 133 +++++++++++++++++---- 1 file changed, 113 insertions(+), 20 deletions(-) diff --git a/packages/toolbox-core/tests/test_client.py b/packages/toolbox-core/tests/test_client.py index 91ea2259..e25573e2 100644 --- a/packages/toolbox-core/tests/test_client.py +++ b/packages/toolbox-core/tests/test_client.py @@ -14,6 +14,7 @@ import inspect +import json import pytest import pytest_asyncio @@ -100,6 +101,31 @@ async def test_load_tool_success(aioresponses, test_tool_str): assert await loaded_tool("some value") == "ok" +@pytest.mark.asyncio +async def test_load_toolset_success(aioresponses, test_tool_str, test_tool_int_bool): + """Tests successfully loading a toolset with multiple tools.""" + TOOLSET_NAME = "my_toolset" + TOOL1 = "tool1" + TOOL2 = "tool2" + manifest = ManifestSchema( + serverVersion="0.0.0", tools={TOOL1: test_tool_str, TOOL2: test_tool_int_bool} + ) + aioresponses.get( + f"{TEST_BASE_URL}/api/toolset/{TOOLSET_NAME}", + payload=manifest.model_dump(), + status=200, + ) + + async with ToolboxClient(TEST_BASE_URL) as client: + tools = await client.load_toolset(TOOLSET_NAME) + + assert isinstance(tools, list) + assert len(tools) == len(manifest.tools) + + # Check if tools were created correctly + assert {t.__name__ for t in tools} == manifest.tools.keys() + + class TestAuth: @pytest.fixture @@ -181,26 +207,93 @@ def token_handler(): res = await tool(5) -@pytest.mark.asyncio -async def test_load_toolset_success(aioresponses, test_tool_str, test_tool_int_bool): - """Tests successfully loading a toolset with multiple tools.""" - TOOLSET_NAME = "my_toolset" - TOOL1 = "tool1" - TOOL2 = "tool2" - manifest = ManifestSchema( - serverVersion="0.0.0", tools={TOOL1: test_tool_str, TOOL2: test_tool_int_bool} - ) - aioresponses.get( - f"{TEST_BASE_URL}/api/toolset/{TOOLSET_NAME}", - payload=manifest.model_dump(), - status=200, - ) +class TestBoundParameter: - async with ToolboxClient(TEST_BASE_URL) as client: - tools = await client.load_toolset(TOOLSET_NAME) + @pytest.fixture + def tool_name(self): + return "tool1" - assert isinstance(tools, list) - assert len(tools) == len(manifest.tools) + @pytest_asyncio.fixture + async def client(self, aioresponses, test_tool_int_bool, tool_name): + manifest = ManifestSchema( + serverVersion="0.0.0", tools={tool_name: test_tool_int_bool} + ) - # Check if tools were created correctly - assert {t.__name__ for t in tools} == manifest.tools.keys() + # mock toolset GET call + aioresponses.get( + f"{TEST_BASE_URL}/api/toolset/", + payload=manifest.model_dump(), + status=200, + ) + + # mock tool GET call + aioresponses.get( + f"{TEST_BASE_URL}/api/tool/{tool_name}", + payload=manifest.model_dump(), + status=200, + ) + + # mock tool INVOKE call + def reflect_parameters(url, **kwargs): + body = {"result": kwargs["json"]} + return CallbackResult(status=200, body=json.dumps(body)) + + aioresponses.post( + f"{TEST_BASE_URL}/api/tool/{tool_name}/invoke", + payload=manifest.model_dump(), + callback=reflect_parameters, + status=200, + ) + + async with ToolboxClient(TEST_BASE_URL) as client: + yield client + + @pytest.mark.asyncio + async def test_load_tool_success(self, tool_name, client): + """Tests 'load_tool' with a bound parameter specified.""" + tool = await client.load_tool(tool_name, bound_params={"argA": lambda: 5}) + + assert len(tool.__signature__.parameters) == 1 + assert "argA" not in tool.__signature__.parameters + + res = await tool(True) + assert "argA" in res + + @pytest.mark.asyncio + async def test_load_toolset_success(self, tool_name, client): + """Tests 'load_toolset' with a bound parameter specified.""" + tools = await client.load_toolset("", bound_params={"argB": lambda: "hello"}) + tool = tools[0] + + assert len(tool.__signature__.parameters) == 1 + assert "argB" not in tool.__signature__.parameters + + res = await tool(True) + assert "argB" in res + + @pytest.mark.asyncio + async def test_bind_param_success(self, tool_name, client): + """Tests 'bind_param' with a bound parameter specified.""" + tool = await client.load_tool(tool_name) + + assert len(tool.__signature__.parameters) == 2 + assert "argA" in tool.__signature__.parameters + + tool = tool.bind_parameters({"argA": lambda: 5}) + + assert len(tool.__signature__.parameters) == 1 + assert "argA" not in tool.__signature__.parameters + + res = await tool(True) + assert "argA" in res + + @pytest.mark.asyncio + async def test_bind_param_fail(self, tool_name, client): + """Tests 'bind_param' with a bound parameter that doesn't exist.""" + tool = await client.load_tool(tool_name) + + assert len(tool.__signature__.parameters) == 2 + assert "argA" in tool.__signature__.parameters + + with pytest.raises(Exception): + tool = tool.bind_parameters({"argC": lambda: 5}) From 138d8d975c11c09870a6f644eb3deb4323a6de5d Mon Sep 17 00:00:00 2001 From: Yuan <45984206+Yuan325@users.noreply.github.com> Date: Mon, 31 Mar 2025 16:30:33 -0700 Subject: [PATCH 09/50] docs: update syntax error on readme (#121) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 78688861..a4370cbd 100644 --- a/README.md +++ b/README.md @@ -291,7 +291,7 @@ bound_tools = [tool.bind_param("param", "value") for tool in tools] ### Binding Parameters While Loading ```py -bound_tool = toolbox.load_tool(bound_params={"param": "value"}) +bound_tool = toolbox.load_tool("my-tool", bound_params={"param": "value"}) bound_tools = toolbox.load_toolset(bound_params={"param": "value"}) ``` From e2f5b4d5b1e74107caa822354bf0b06d31065bc7 Mon Sep 17 00:00:00 2001 From: Twisha Bansal <58483338+twishabansal@users.noreply.github.com> Date: Tue, 1 Apr 2025 10:54:49 +0530 Subject: [PATCH 10/50] ci: added release please config (#112) * ci: add release please config * chore: add initial version * chore: specify initial version as string * chore: Update .release-please-manifest.json * chore: add empty json * chore: small change * chore: try fixing config * chore: try fixing config again * chore: remove release-as * chore: add changelog sections * chore: better release notes * chore: better release notes * chore: change toolbox-langchain version * chore: separate PRs for packages * chore: change PR style --- .release-please-manifest.json | 1 + release-please-config.json | 25 +++++++++++++++++++++++++ 2 files changed, 26 insertions(+) create mode 100644 .release-please-manifest.json create mode 100644 release-please-config.json diff --git a/.release-please-manifest.json b/.release-please-manifest.json new file mode 100644 index 00000000..4a7b2a5b --- /dev/null +++ b/.release-please-manifest.json @@ -0,0 +1 @@ +{"packages/toolbox-langchain": "0.1.0"} diff --git a/release-please-config.json b/release-please-config.json new file mode 100644 index 00000000..4f05e59d --- /dev/null +++ b/release-please-config.json @@ -0,0 +1,25 @@ +{ + "bootstrap-sha": "ac43090822fbf19a8920732e2ec3aa8b9c3130c1", + "release-type": "python", + "bump-minor-pre-major": true, + "bump-patch-for-minor-pre-major": true, + "separate-pull-requests": true, + "include-component-in-tag": true, + "changelog-sections": [ + { "type": "feat", "section": "Features" }, + { "type": "fix", "section": "Bug Fixes" }, + { "type": "chore", "section": "Miscellaneous Chores", "hidden": false } + ], + "packages": { + "packages/toolbox-core": { + "extra-files": [ + "src/toolbox_core/version.py" + ] + }, + "packages/toolbox-langchain": { + "extra-files": [ + "src/toolbox_langchain/version.py" + ] + } + } +} From 280a79fef928fbd880772955fc8a2ec1c48cdd81 Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Tue, 1 Apr 2025 17:41:45 +0530 Subject: [PATCH 11/50] added basic e2e tests --- packages/toolbox-core/tests/conftest.py | 166 +++++++++++++++++++++++ packages/toolbox-core/tests/test_e2e.py | 171 ++++++++++++++++++++++++ 2 files changed, 337 insertions(+) create mode 100644 packages/toolbox-core/tests/conftest.py create mode 100644 packages/toolbox-core/tests/test_e2e.py diff --git a/packages/toolbox-core/tests/conftest.py b/packages/toolbox-core/tests/conftest.py new file mode 100644 index 00000000..76d1ab5c --- /dev/null +++ b/packages/toolbox-core/tests/conftest.py @@ -0,0 +1,166 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Contains pytest fixtures that are accessible from all +files present in the same directory.""" + +from __future__ import annotations + +import os +import platform +import subprocess +import tempfile +import time +from typing import Generator + +import google +import pytest_asyncio +from google.auth import compute_engine +from google.cloud import secretmanager, storage + + +#### Define Utility Functions +def get_env_var(key: str) -> str: + """Gets environment variables.""" + value = os.environ.get(key) + if value is None: + raise ValueError(f"Must set env var {key}") + return value + + +def access_secret_version( + project_id: str, secret_id: str, version_id: str = "latest" +) -> str: + """Accesses the payload of a given secret version from Secret Manager.""" + client = secretmanager.SecretManagerServiceClient() + name = f"projects/{project_id}/secrets/{secret_id}/versions/{version_id}" + response = client.access_secret_version(request={"name": name}) + return response.payload.data.decode("UTF-8") + + +def create_tmpfile(content: str) -> str: + """Creates a temporary file with the given content.""" + with tempfile.NamedTemporaryFile(delete=False, mode="w") as tmpfile: + tmpfile.write(content) + return tmpfile.name + + +def download_blob( + bucket_name: str, source_blob_name: str, destination_file_name: str +) -> None: + """Downloads a blob from a GCS bucket.""" + storage_client = storage.Client() + + bucket = storage_client.bucket(bucket_name) + blob = bucket.blob(source_blob_name) + blob.download_to_filename(destination_file_name) + + print(f"Blob {source_blob_name} downloaded to {destination_file_name}.") + + +def get_toolbox_binary_url(toolbox_version: str) -> str: + """Constructs the GCS path to the toolbox binary.""" + os_system = platform.system().lower() + arch = ( + "arm64" if os_system == "darwin" and platform.machine() == "arm64" else "amd64" + ) + return f"v{toolbox_version}/{os_system}/{arch}/toolbox" + + +def get_auth_token(client_id: str) -> str: + """Retrieves an authentication token""" + request = google.auth.transport.requests.Request() + credentials = compute_engine.IDTokenCredentials( + request=request, + target_audience=client_id, + use_metadata_identity_endpoint=True, + ) + if not credentials.valid: + credentials.refresh(request) + return credentials.token + + +#### Define Fixtures +@pytest_asyncio.fixture(scope="session") +def project_id() -> str: + return get_env_var("GOOGLE_CLOUD_PROJECT") + + +@pytest_asyncio.fixture(scope="session") +def toolbox_version() -> str: + return get_env_var("TOOLBOX_VERSION") + + +@pytest_asyncio.fixture(scope="session") +def tools_file_path(project_id: str) -> Generator[str]: + """Provides a temporary file path containing the tools manifest.""" + tools_manifest = access_secret_version( + project_id=project_id, secret_id="sdk_testing_tools" + ) + tools_file_path = create_tmpfile(tools_manifest) + yield tools_file_path + os.remove(tools_file_path) + + +@pytest_asyncio.fixture(scope="session") +def auth_token1(project_id: str) -> str: + client_id = access_secret_version( + project_id=project_id, secret_id="sdk_testing_client1" + ) + return get_auth_token(client_id) + + +@pytest_asyncio.fixture(scope="session") +def auth_token2(project_id: str) -> str: + client_id = access_secret_version( + project_id=project_id, secret_id="sdk_testing_client2" + ) + return get_auth_token(client_id) + + +@pytest_asyncio.fixture(scope="session") +def toolbox_server(toolbox_version: str, tools_file_path: str) -> Generator[None]: + """Starts the toolbox server as a subprocess.""" + print("Downloading toolbox binary from gcs bucket...") + source_blob_name = get_toolbox_binary_url(toolbox_version) + download_blob("genai-toolbox", source_blob_name, "toolbox") + print("Toolbox binary downloaded successfully.") + try: + print("Opening toolbox server process...") + # Make toolbox executable + os.chmod("toolbox", 0o700) + # Run toolbox binary + toolbox_server = subprocess.Popen( + ["./toolbox", "--tools_file", tools_file_path] + ) + + # Wait for server to start + # Retry logic with a timeout + for _ in range(5): # retries + time.sleep(2) + print("Checking if toolbox is successfully started...") + if toolbox_server.poll() is None: + print("Toolbox server started successfully.") + break + else: + raise RuntimeError("Toolbox server failed to start after 5 retries.") + except subprocess.CalledProcessError as e: + print(e.stderr.decode("utf-8")) + print(e.stdout.decode("utf-8")) + raise RuntimeError(f"{e}\n\n{e.stderr.decode('utf-8')}") from e + yield + + # Clean up toolbox server + toolbox_server.terminate() + toolbox_server.wait() diff --git a/packages/toolbox-core/tests/test_e2e.py b/packages/toolbox-core/tests/test_e2e.py new file mode 100644 index 00000000..4d9c123c --- /dev/null +++ b/packages/toolbox-core/tests/test_e2e.py @@ -0,0 +1,171 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""End-to-end tests for the toolbox SDK interacting with the toolbox server. + +This file covers the following use cases: + +1. Loading a tool. +2. Loading a specific toolset. +3. Loading the default toolset (contains all tools). +4. Running a tool with + a. Missing params. + b. Wrong param type. +5. Running a tool with no required auth, with auth provided. +6. Running a tool with required auth: + a. No auth provided. + b. Wrong auth provided: The tool requires a different authentication + than the one provided. + c. Correct auth provided. +7. Running a tool with a parameter that requires auth: + a. No auth provided. + b. Correct auth provided. + c. Auth provided does not contain the required claim. +""" +import pytest +import pytest_asyncio + +from toolbox_core.client import ToolboxClient +from toolbox_core.tool import ToolboxTool + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("toolbox_server") +class TestE2EClient: + @pytest_asyncio.fixture(scope="function") + async def toolbox(self): + toolbox = ToolboxClient("http://localhost:5000") + try: + yield toolbox + finally: + await toolbox.close() + + @pytest_asyncio.fixture(scope="function") + async def get_n_rows_tool(self, toolbox: ToolboxClient) -> ToolboxTool: + tool = await toolbox.load_tool("get-n-rows") + assert tool.__name__ == "get-n-rows" + return tool + + #### Basic e2e tests + @pytest.mark.parametrize( + "toolset_name, expected_length, expected_tools", + [ + ("my-toolset", 1, ["get-row-by-id"]), + ("my-toolset-2", 2, ["get-n-rows", "get-row-by-id"]), + ], + ) + async def test_load_toolset_specific( + self, + toolbox: ToolboxClient, + toolset_name: str, + expected_length: int, + expected_tools: list[str], + ): + toolset = await toolbox.load_toolset(toolset_name) + assert len(toolset) == expected_length + tool_names = {tool.__name__ for tool in toolset} + assert tool_names == set(expected_tools) + + async def test_run_tool(self, get_n_rows_tool: ToolboxTool): + response = await get_n_rows_tool(num_rows="2") + + assert isinstance(response, str) + assert "row1" in response + assert "row2" in response + assert "row3" not in response + + async def test_run_tool_missing_params(self, get_n_rows_tool): + with pytest.raises(TypeError, match="missing a required argument: 'num_rows'"): + await get_n_rows_tool() + + async def test_run_tool_wrong_param_type(self, get_n_rows_tool: ToolboxTool): + with pytest.raises( + Exception, + match='provided parameters were invalid: unable to parse value for "num_rows": .* not type "string"', + ): + await get_n_rows_tool(num_rows=2) # Pass the integer value + + ##### Auth tests + @pytest.mark.asyncio + async def test_run_tool_unauth_with_auth(self, toolbox, auth_token2): + """Tests running a tool that doesn't require auth, with auth provided.""" + tool = await toolbox.load_tool( + "get-row-by-id", auth_token_getters={"my-test-auth": lambda: auth_token2} + ) + response = await tool(id="2") + assert "row2" in response + + async def test_run_tool_no_auth(self, toolbox): + """Tests running a tool requiring auth without providing auth.""" + tool = await toolbox.load_tool( + "get-row-by-id-auth", + ) + with pytest.raises( + Exception, + match="One of more of the following authn services are required to invoke this tool: my-test-auth", + ): + await tool(id="2") + + async def test_run_tool_wrong_auth(self, toolbox, auth_token2): + """Tests running a tool with incorrect auth.""" + tool = await toolbox.load_tool( + "get-row-by-id-auth", + ) + auth_tool = tool.add_auth_token_getters("my-test-auth", lambda: auth_token2) + with pytest.raises( + Exception, + match="tool invocation not authorized", + ): + await auth_tool(id="2") + + async def test_run_tool_auth(self, toolbox, auth_token1): + """Tests running a tool with correct auth.""" + tool = await toolbox.load_tool( + "get-row-by-id-auth", + ) + auth_tool = tool.add_auth_token_getters("my-test-auth", lambda: auth_token1) + response = await auth_tool(id="2") + assert "row2" in response + + async def test_run_tool_param_auth_no_auth(self, toolbox): + """Tests running a tool with a param requiring auth, without auth.""" + tool = await toolbox.load_tool("get-row-by-email-auth") + with pytest.raises( + Exception, + match="One of more of the following authn services are required to invoke this tool: my-test-auth", + ): + await tool(email="") + + async def test_run_tool_param_auth(self, toolbox, auth_token1): + """Tests running a tool with a param requiring auth, with correct auth.""" + tool = await toolbox.load_tool( + "get-row-by-email-auth", + auth_token_getters={"my-test-auth": lambda: auth_token1}, + ) + response = await tool() + assert "row4" in response + assert "row5" in response + assert "row6" in response + + async def test_run_tool_param_auth_no_field(self, toolbox, auth_token1): + """Tests running a tool with a param requiring auth, with insufficient auth.""" + tool = await toolbox.load_tool( + "get-row-by-content-auth", + auth_token_getters={"my-test-auth": lambda: auth_token1}, + ) + with pytest.raises( + Exception, + match="no field named row_data in claims", + ): + await tool() From 63808ce419e428188a69a01ade5a6128bc766eb5 Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Tue, 1 Apr 2025 17:45:07 +0530 Subject: [PATCH 12/50] change license year --- packages/toolbox-core/tests/conftest.py | 2 +- packages/toolbox-core/tests/test_e2e.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/toolbox-core/tests/conftest.py b/packages/toolbox-core/tests/conftest.py index 76d1ab5c..ee6f5e4f 100644 --- a/packages/toolbox-core/tests/conftest.py +++ b/packages/toolbox-core/tests/conftest.py @@ -1,4 +1,4 @@ -# Copyright 2024 Google LLC +# Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/packages/toolbox-core/tests/test_e2e.py b/packages/toolbox-core/tests/test_e2e.py index 4d9c123c..24f61f57 100644 --- a/packages/toolbox-core/tests/test_e2e.py +++ b/packages/toolbox-core/tests/test_e2e.py @@ -1,4 +1,4 @@ -# Copyright 2024 Google LLC +# Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 55f5ca0b6529df0937ffa12a5e56a7d30b0d7494 Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Tue, 1 Apr 2025 17:48:47 +0530 Subject: [PATCH 13/50] add test deps --- packages/toolbox-core/pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/packages/toolbox-core/pyproject.toml b/packages/toolbox-core/pyproject.toml index ee8a5f73..e9ee66f1 100644 --- a/packages/toolbox-core/pyproject.toml +++ b/packages/toolbox-core/pyproject.toml @@ -46,6 +46,8 @@ test = [ "pytest==8.3.5", "pytest-aioresponses==0.3.0", "pytest-asyncio==0.25.3", + "google-cloud-secret-manager==2.23.2", + "google-cloud-storage==3.1.0", ] [build-system] requires = ["setuptools"] From ed16488dc037053ca5b2269bafd3cd7e97f3aa6c Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Tue, 1 Apr 2025 17:52:14 +0530 Subject: [PATCH 14/50] fix tests --- packages/toolbox-core/tests/test_e2e.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/packages/toolbox-core/tests/test_e2e.py b/packages/toolbox-core/tests/test_e2e.py index 24f61f57..3f6a89d3 100644 --- a/packages/toolbox-core/tests/test_e2e.py +++ b/packages/toolbox-core/tests/test_e2e.py @@ -113,7 +113,7 @@ async def test_run_tool_no_auth(self, toolbox): ) with pytest.raises( Exception, - match="One of more of the following authn services are required to invoke this tool: my-test-auth", + match="tool invocation not authorized. Please make sure your specify correct auth headers", ): await tool(id="2") @@ -122,7 +122,7 @@ async def test_run_tool_wrong_auth(self, toolbox, auth_token2): tool = await toolbox.load_tool( "get-row-by-id-auth", ) - auth_tool = tool.add_auth_token_getters("my-test-auth", lambda: auth_token2) + auth_tool = tool.add_auth_token_getters({"my-test-auth": lambda: auth_token2}) with pytest.raises( Exception, match="tool invocation not authorized", @@ -134,7 +134,7 @@ async def test_run_tool_auth(self, toolbox, auth_token1): tool = await toolbox.load_tool( "get-row-by-id-auth", ) - auth_tool = tool.add_auth_token_getters("my-test-auth", lambda: auth_token1) + auth_tool = tool.add_auth_token_getters({"my-test-auth": lambda: auth_token1}) response = await auth_tool(id="2") assert "row2" in response From f0991ab2172de5c77d35b1a00ef8ba435aef8c25 Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Tue, 1 Apr 2025 17:59:57 +0530 Subject: [PATCH 15/50] fix tests --- packages/toolbox-core/tests/test_e2e.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/toolbox-core/tests/test_e2e.py b/packages/toolbox-core/tests/test_e2e.py index 3f6a89d3..202376f1 100644 --- a/packages/toolbox-core/tests/test_e2e.py +++ b/packages/toolbox-core/tests/test_e2e.py @@ -145,7 +145,7 @@ async def test_run_tool_param_auth_no_auth(self, toolbox): Exception, match="One of more of the following authn services are required to invoke this tool: my-test-auth", ): - await tool(email="") + await tool() async def test_run_tool_param_auth(self, toolbox, auth_token1): """Tests running a tool with a param requiring auth, with correct auth.""" From cc6d568b1280e8b3b4433bd96fc03effc51cf0ad Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Tue, 1 Apr 2025 18:06:25 +0530 Subject: [PATCH 16/50] fix tests --- packages/toolbox-core/tests/test_e2e.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/packages/toolbox-core/tests/test_e2e.py b/packages/toolbox-core/tests/test_e2e.py index 202376f1..90790a34 100644 --- a/packages/toolbox-core/tests/test_e2e.py +++ b/packages/toolbox-core/tests/test_e2e.py @@ -143,7 +143,8 @@ async def test_run_tool_param_auth_no_auth(self, toolbox): tool = await toolbox.load_tool("get-row-by-email-auth") with pytest.raises( Exception, - match="One of more of the following authn services are required to invoke this tool: my-test-auth", + match="provided parameters were invalid: error parsing authenticated parameter " + '"email": missing or invalid authentication header', ): await tool() From 1f63b0d32195b47453b0edc840192cfe3a4380d7 Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Tue, 1 Apr 2025 18:12:03 +0530 Subject: [PATCH 17/50] add new test case --- packages/toolbox-core/tests/test_e2e.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/packages/toolbox-core/tests/test_e2e.py b/packages/toolbox-core/tests/test_e2e.py index 90790a34..e9feac71 100644 --- a/packages/toolbox-core/tests/test_e2e.py +++ b/packages/toolbox-core/tests/test_e2e.py @@ -32,6 +32,7 @@ a. No auth provided. b. Correct auth provided. c. Auth provided does not contain the required claim. + d. Auth Service not registered in manifest """ import pytest import pytest_asyncio @@ -148,6 +149,15 @@ async def test_run_tool_param_auth_no_auth(self, toolbox): ): await tool() + async def test_run_tool_param_auth_no_service(self, toolbox): + """Tests running a tool with a param requiring auth, without a correctly registered auth service.""" + tool = await toolbox.load_tool("get-row-by-email-auth-wrong-auth-source") + with pytest.raises( + Exception, + match="One of more of the following authn services are required to invoke this tool: my-test-auth3", + ): + await tool() + async def test_run_tool_param_auth(self, toolbox, auth_token1): """Tests running a tool with a param requiring auth, with correct auth.""" tool = await toolbox.load_tool( From 66e88ab32ab852dbcbe7270aab2c4f12978d92c5 Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Tue, 1 Apr 2025 18:25:10 +0530 Subject: [PATCH 18/50] fix docstring --- packages/toolbox-core/tests/test_e2e.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/toolbox-core/tests/test_e2e.py b/packages/toolbox-core/tests/test_e2e.py index e9feac71..c427af9c 100644 --- a/packages/toolbox-core/tests/test_e2e.py +++ b/packages/toolbox-core/tests/test_e2e.py @@ -150,7 +150,7 @@ async def test_run_tool_param_auth_no_auth(self, toolbox): await tool() async def test_run_tool_param_auth_no_service(self, toolbox): - """Tests running a tool with a param requiring auth, without a correctly registered auth service.""" + """Tests running a tool with a param requiring auth, without a registered auth service.""" tool = await toolbox.load_tool("get-row-by-email-auth-wrong-auth-source") with pytest.raises( Exception, From 1fa1b7aeee22cff595eee6a0af26f9137264599d Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Tue, 1 Apr 2025 18:25:29 +0530 Subject: [PATCH 19/50] added todo --- packages/toolbox-core/tests/test_e2e.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/packages/toolbox-core/tests/test_e2e.py b/packages/toolbox-core/tests/test_e2e.py index c427af9c..24831ba1 100644 --- a/packages/toolbox-core/tests/test_e2e.py +++ b/packages/toolbox-core/tests/test_e2e.py @@ -149,14 +149,15 @@ async def test_run_tool_param_auth_no_auth(self, toolbox): ): await tool() - async def test_run_tool_param_auth_no_service(self, toolbox): - """Tests running a tool with a param requiring auth, without a registered auth service.""" - tool = await toolbox.load_tool("get-row-by-email-auth-wrong-auth-source") - with pytest.raises( - Exception, - match="One of more of the following authn services are required to invoke this tool: my-test-auth3", - ): - await tool() + # TODO: Uncomment after fix + # async def test_run_tool_param_auth_no_service(self, toolbox): + # """Tests running a tool with a param requiring auth, without a registered auth service.""" + # tool = await toolbox.load_tool("get-row-by-email-auth-wrong-auth-source") + # with pytest.raises( + # Exception, + # match="One of more of the following authn services are required to invoke this tool: my-test-auth3", + # ): + # await tool() async def test_run_tool_param_auth(self, toolbox, auth_token1): """Tests running a tool with a param requiring auth, with correct auth.""" From 12af5fa085f8b24cbfce341c423e21e0427e82d7 Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Tue, 1 Apr 2025 18:26:11 +0530 Subject: [PATCH 20/50] cleanup --- packages/toolbox-core/tests/test_e2e.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/toolbox-core/tests/test_e2e.py b/packages/toolbox-core/tests/test_e2e.py index 24831ba1..43bae602 100644 --- a/packages/toolbox-core/tests/test_e2e.py +++ b/packages/toolbox-core/tests/test_e2e.py @@ -95,7 +95,7 @@ async def test_run_tool_wrong_param_type(self, get_n_rows_tool: ToolboxTool): Exception, match='provided parameters were invalid: unable to parse value for "num_rows": .* not type "string"', ): - await get_n_rows_tool(num_rows=2) # Pass the integer value + await get_n_rows_tool(num_rows=2) ##### Auth tests @pytest.mark.asyncio From 9c3ba38873aad07b08b26a54932121235f8e6a27 Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Tue, 1 Apr 2025 18:31:17 +0530 Subject: [PATCH 21/50] add bind param test case --- packages/toolbox-core/tests/test_e2e.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/packages/toolbox-core/tests/test_e2e.py b/packages/toolbox-core/tests/test_e2e.py index 43bae602..b8c5e411 100644 --- a/packages/toolbox-core/tests/test_e2e.py +++ b/packages/toolbox-core/tests/test_e2e.py @@ -97,8 +97,18 @@ async def test_run_tool_wrong_param_type(self, get_n_rows_tool: ToolboxTool): ): await get_n_rows_tool(num_rows=2) + ##### Bind param tests + async def test_bind_params(self, toolbox, get_n_rows_tool): + new_tool = get_n_rows_tool.bind_parameters({"num_rows": "3"}) + response = await new_tool() + + assert isinstance(response, str) + assert "row1" in response + assert "row2" in response + assert "row3" in response + assert "row4" not in response + ##### Auth tests - @pytest.mark.asyncio async def test_run_tool_unauth_with_auth(self, toolbox, auth_token2): """Tests running a tool that doesn't require auth, with auth provided.""" tool = await toolbox.load_tool( From d75828edbd3bce0776979de540c79295d660940e Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Tue, 1 Apr 2025 18:34:49 +0530 Subject: [PATCH 22/50] make bind params dynamic --- packages/toolbox-core/tests/test_e2e.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/toolbox-core/tests/test_e2e.py b/packages/toolbox-core/tests/test_e2e.py index b8c5e411..00c6a7d1 100644 --- a/packages/toolbox-core/tests/test_e2e.py +++ b/packages/toolbox-core/tests/test_e2e.py @@ -99,7 +99,7 @@ async def test_run_tool_wrong_param_type(self, get_n_rows_tool: ToolboxTool): ##### Bind param tests async def test_bind_params(self, toolbox, get_n_rows_tool): - new_tool = get_n_rows_tool.bind_parameters({"num_rows": "3"}) + new_tool = get_n_rows_tool.bind_parameters({"num_rows": lambda: "3"}) response = await new_tool() assert isinstance(response, str) From fef1f7f627a68376a37da8844dddc8eab09f68ed Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Tue, 1 Apr 2025 18:54:22 +0530 Subject: [PATCH 23/50] try fix test errors --- .../toolbox-core/src/toolbox_core/client.py | 2 +- .../toolbox-core/src/toolbox_core/tool.py | 25 ++++++++++++------- packages/toolbox-core/tests/test_e2e.py | 17 ++++++------- 3 files changed, 25 insertions(+), 19 deletions(-) diff --git a/packages/toolbox-core/src/toolbox_core/client.py b/packages/toolbox-core/src/toolbox_core/client.py index 5afaa4db..6e2a490f 100644 --- a/packages/toolbox-core/src/toolbox_core/client.py +++ b/packages/toolbox-core/src/toolbox_core/client.py @@ -75,7 +75,7 @@ def __parse_tool( else: # regular parameter params.append(p) - authn_params = filter_required_authn_params(authn_params, auth_sources) + authn_params = filter_required_authn_params(authn_params, auth_token_getters.keys()) tool = ToolboxTool( session=self.__session, diff --git a/packages/toolbox-core/src/toolbox_core/tool.py b/packages/toolbox-core/src/toolbox_core/tool.py index 2e88c179..0e146a0a 100644 --- a/packages/toolbox-core/src/toolbox_core/tool.py +++ b/packages/toolbox-core/src/toolbox_core/tool.py @@ -21,6 +21,7 @@ Any, Callable, DefaultDict, + TypeVar, Iterable, Mapping, Optional, @@ -31,6 +32,8 @@ from aiohttp import ClientSession from pytest import Session +T = TypeVar('T') + class ToolboxTool: """ @@ -122,16 +125,20 @@ def __copy( that produce a token) """ + + def _resolve_value(override_value: Optional[T], default_value: T) -> T: + """Returns the override_value if it's not None, otherwise the default_value.""" + return override_value if override_value is not None else default_value + return ToolboxTool( - session=session or self.__session, - base_url=base_url or self.__base_url, - name=name or self.__name__, - desc=desc or self.__desc, - params=params or self.__params, - required_authn_params=required_authn_params or self.__required_authn_params, - auth_service_token_getters=auth_service_token_getters - or self.__auth_service_token_getters, - bound_params=bound_params or self.__bound_parameters, + session=_resolve_value(session, self.__session), + base_url=_resolve_value(base_url, self.__base_url), + name=_resolve_value(name, self.__name__), + desc=_resolve_value(desc, self.__desc), + params=_resolve_value(params, self.__params), + required_authn_params=_resolve_value(required_authn_params, self.__required_authn_params), + auth_service_token_getters=_resolve_value(auth_service_token_getters, self.__auth_service_token_getters), + bound_params=_resolve_value(bound_params, self.__bound_parameters), ) async def __call__(self, *args: Any, **kwargs: Any) -> str: diff --git a/packages/toolbox-core/tests/test_e2e.py b/packages/toolbox-core/tests/test_e2e.py index 00c6a7d1..0d4d402e 100644 --- a/packages/toolbox-core/tests/test_e2e.py +++ b/packages/toolbox-core/tests/test_e2e.py @@ -159,15 +159,14 @@ async def test_run_tool_param_auth_no_auth(self, toolbox): ): await tool() - # TODO: Uncomment after fix - # async def test_run_tool_param_auth_no_service(self, toolbox): - # """Tests running a tool with a param requiring auth, without a registered auth service.""" - # tool = await toolbox.load_tool("get-row-by-email-auth-wrong-auth-source") - # with pytest.raises( - # Exception, - # match="One of more of the following authn services are required to invoke this tool: my-test-auth3", - # ): - # await tool() + async def test_run_tool_param_auth_no_service(self, toolbox): + """Tests running a tool with a param requiring auth, without a registered auth service.""" + tool = await toolbox.load_tool("get-row-by-email-auth-wrong-auth-source") + with pytest.raises( + Exception, + match="One of more of the following authn services are required to invoke this tool: my-test-auth3", + ): + await tool() async def test_run_tool_param_auth(self, toolbox, auth_token1): """Tests running a tool with a param requiring auth, with correct auth.""" From cd701b71478200d607f8db0ba208a1646b5a15f9 Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Tue, 1 Apr 2025 18:55:46 +0530 Subject: [PATCH 24/50] lint --- packages/toolbox-core/src/toolbox_core/client.py | 4 +++- packages/toolbox-core/src/toolbox_core/tool.py | 12 ++++++++---- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/packages/toolbox-core/src/toolbox_core/client.py b/packages/toolbox-core/src/toolbox_core/client.py index 6e2a490f..77ee40f8 100644 --- a/packages/toolbox-core/src/toolbox_core/client.py +++ b/packages/toolbox-core/src/toolbox_core/client.py @@ -75,7 +75,9 @@ def __parse_tool( else: # regular parameter params.append(p) - authn_params = filter_required_authn_params(authn_params, auth_token_getters.keys()) + authn_params = filter_required_authn_params( + authn_params, auth_token_getters.keys() + ) tool = ToolboxTool( session=self.__session, diff --git a/packages/toolbox-core/src/toolbox_core/tool.py b/packages/toolbox-core/src/toolbox_core/tool.py index 0e146a0a..c5cb6c89 100644 --- a/packages/toolbox-core/src/toolbox_core/tool.py +++ b/packages/toolbox-core/src/toolbox_core/tool.py @@ -21,18 +21,18 @@ Any, Callable, DefaultDict, - TypeVar, Iterable, Mapping, Optional, Sequence, + TypeVar, Union, ) from aiohttp import ClientSession from pytest import Session -T = TypeVar('T') +T = TypeVar("T") class ToolboxTool: @@ -136,8 +136,12 @@ def _resolve_value(override_value: Optional[T], default_value: T) -> T: name=_resolve_value(name, self.__name__), desc=_resolve_value(desc, self.__desc), params=_resolve_value(params, self.__params), - required_authn_params=_resolve_value(required_authn_params, self.__required_authn_params), - auth_service_token_getters=_resolve_value(auth_service_token_getters, self.__auth_service_token_getters), + required_authn_params=_resolve_value( + required_authn_params, self.__required_authn_params + ), + auth_service_token_getters=_resolve_value( + auth_service_token_getters, self.__auth_service_token_getters + ), bound_params=_resolve_value(bound_params, self.__bound_parameters), ) From 7a5a1bddb3db8e351f1656e04962591c3892a55b Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Tue, 1 Apr 2025 18:57:56 +0530 Subject: [PATCH 25/50] remove redundant test --- packages/toolbox-core/tests/test_e2e.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/packages/toolbox-core/tests/test_e2e.py b/packages/toolbox-core/tests/test_e2e.py index 0d4d402e..20424bba 100644 --- a/packages/toolbox-core/tests/test_e2e.py +++ b/packages/toolbox-core/tests/test_e2e.py @@ -159,15 +159,6 @@ async def test_run_tool_param_auth_no_auth(self, toolbox): ): await tool() - async def test_run_tool_param_auth_no_service(self, toolbox): - """Tests running a tool with a param requiring auth, without a registered auth service.""" - tool = await toolbox.load_tool("get-row-by-email-auth-wrong-auth-source") - with pytest.raises( - Exception, - match="One of more of the following authn services are required to invoke this tool: my-test-auth3", - ): - await tool() - async def test_run_tool_param_auth(self, toolbox, auth_token1): """Tests running a tool with a param requiring auth, with correct auth.""" tool = await toolbox.load_tool( From e343e788b45560b0b546f968d3f363747ceb1a96 Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Tue, 1 Apr 2025 19:11:02 +0530 Subject: [PATCH 26/50] test fix --- packages/toolbox-core/tests/test_e2e.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/toolbox-core/tests/test_e2e.py b/packages/toolbox-core/tests/test_e2e.py index 20424bba..8e6d8757 100644 --- a/packages/toolbox-core/tests/test_e2e.py +++ b/packages/toolbox-core/tests/test_e2e.py @@ -152,10 +152,10 @@ async def test_run_tool_auth(self, toolbox, auth_token1): async def test_run_tool_param_auth_no_auth(self, toolbox): """Tests running a tool with a param requiring auth, without auth.""" tool = await toolbox.load_tool("get-row-by-email-auth") + # TODO: Change match to {my-auth-service3} instead of email after fix in PR with pytest.raises( Exception, - match="provided parameters were invalid: error parsing authenticated parameter " - '"email": missing or invalid authentication header', + match="One of more of the following authn services are required to invoke this tool: email", ): await tool() From 11d62b4f047e11b15c4e2fb64b3b782b793fc318 Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Tue, 1 Apr 2025 19:26:26 +0530 Subject: [PATCH 27/50] fix docstring --- packages/toolbox-core/tests/test_e2e.py | 1 - 1 file changed, 1 deletion(-) diff --git a/packages/toolbox-core/tests/test_e2e.py b/packages/toolbox-core/tests/test_e2e.py index 8e6d8757..8854eb4d 100644 --- a/packages/toolbox-core/tests/test_e2e.py +++ b/packages/toolbox-core/tests/test_e2e.py @@ -32,7 +32,6 @@ a. No auth provided. b. Correct auth provided. c. Auth provided does not contain the required claim. - d. Auth Service not registered in manifest """ import pytest import pytest_asyncio From 87e00cc4f02c08fdc715583679e8db61459b0baf Mon Sep 17 00:00:00 2001 From: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com> Date: Sun, 30 Mar 2025 17:04:30 +0000 Subject: [PATCH 28/50] feat: add authenticated parameters support --- .../toolbox-core/src/toolbox_core/client.py | 39 ++++- .../toolbox-core/src/toolbox_core/tool.py | 160 ++++++++++++++++-- packages/toolbox-core/tests/test_client.py | 105 +++++++++++- 3 files changed, 279 insertions(+), 25 deletions(-) diff --git a/packages/toolbox-core/src/toolbox_core/client.py b/packages/toolbox-core/src/toolbox_core/client.py index dc59e440..0fcca0d2 100644 --- a/packages/toolbox-core/src/toolbox_core/client.py +++ b/packages/toolbox-core/src/toolbox_core/client.py @@ -11,13 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from typing import Optional +import types +from typing import Any, Callable, Optional from aiohttp import ClientSession from .protocol import ManifestSchema, ToolSchema -from .tool import ToolboxTool +from .tool import ToolboxTool, filter_required_authn_params class ToolboxClient: @@ -53,14 +53,34 @@ def __init__( session = ClientSession() self.__session = session - def __parse_tool(self, name: str, schema: ToolSchema) -> ToolboxTool: + def __parse_tool( + self, + name: str, + schema: ToolSchema, + auth_token_getters: dict[str, Callable[[], str]], + ) -> ToolboxTool: """Internal helper to create a callable tool from its schema.""" + # sort into authenticated and reg params + params = [] + authn_params: dict[str, list[str]] = {} + auth_sources: set[str] = set() + for p in schema.parameters: + if not p.authSources: + params.append(p) + else: + authn_params[p.name] = p.authSources + auth_sources.update(p.authSources) + + authn_params = filter_required_authn_params(authn_params, auth_sources) + tool = ToolboxTool( session=self.__session, base_url=self.__base_url, name=name, desc=schema.description, - params=[p.to_param() for p in schema.parameters], + params=[p.to_param() for p in params], + required_authn_params=types.MappingProxyType(authn_params), + auth_service_token_getters=auth_token_getters, ) return tool @@ -99,6 +119,7 @@ async def close(self): async def load_tool( self, name: str, + auth_service_tokens: dict[str, Callable[[], str]] = {}, ) -> ToolboxTool: """ Asynchronously loads a tool from the server. @@ -127,13 +148,14 @@ async def load_tool( if name not in manifest.tools: # TODO: Better exception raise Exception(f"Tool '{name}' not found!") - tool = self.__parse_tool(name, manifest.tools[name]) + tool = self.__parse_tool(name, manifest.tools[name], auth_service_tokens) return tool async def load_toolset( self, name: str, + auth_token_getters: dict[str, Callable[[], str]] = {}, ) -> list[ToolboxTool]: """ Asynchronously fetches a toolset and loads all tools defined within it. @@ -152,5 +174,8 @@ async def load_toolset( manifest: ManifestSchema = ManifestSchema(**json) # parse each tools name and schema into a list of ToolboxTools - tools = [self.__parse_tool(n, s) for n, s in manifest.tools.items()] + tools = [ + self.__parse_tool(n, s, auth_token_getters) + for n, s in manifest.tools.items() + ] return tools diff --git a/packages/toolbox-core/src/toolbox_core/tool.py b/packages/toolbox-core/src/toolbox_core/tool.py index 48e4626c..494c7c21 100644 --- a/packages/toolbox-core/src/toolbox_core/tool.py +++ b/packages/toolbox-core/src/toolbox_core/tool.py @@ -13,10 +13,13 @@ # limitations under the License. +import types +from collections import defaultdict from inspect import Parameter, Signature -from typing import Any +from typing import Any, Callable, DefaultDict, Iterable, Mapping, Optional, Sequence from aiohttp import ClientSession +from pytest import Session class ToolboxTool: @@ -32,20 +35,19 @@ class ToolboxTool: and `inspect` work as expected. """ - __url: str - __session: ClientSession - __signature__: Signature - def __init__( self, session: ClientSession, base_url: str, name: str, desc: str, - params: list[Parameter], + params: Sequence[Parameter], + required_authn_params: Mapping[str, list[str]], + auth_service_token_getters: Mapping[str, Callable[[], str]], ): """ - Initializes a callable that will trigger the tool invocation through the Toolbox server. + Initializes a callable that will trigger the tool invocation through the + Toolbox server. Args: session: The `aiohttp.ClientSession` used for making API requests. @@ -54,19 +56,69 @@ def __init__( desc: The description of the remote tool (used as its docstring). params: A list of `inspect.Parameter` objects defining the tool's arguments and their types/defaults. + required_authn_params: A dict of required authenticated parameters that + need a auth_service_token_getter set for them yet. + auth_service_tokens: A dict of authService -> token (or callables that + produce a token) """ # used to invoke the toolbox API - self.__session = session + self.__session: ClientSession = session + self.__base_url: str = base_url self.__url = f"{base_url}/api/tool/{name}/invoke" - # the following properties are set to help anyone that might inspect it determine + self.__desc = desc + self.__params = params + + # the following properties are set to help anyone that might inspect it determine usage self.__name__ = name self.__doc__ = desc self.__signature__ = Signature(parameters=params, return_annotation=str) self.__annotations__ = {p.name: p.annotation for p in params} # TODO: self.__qualname__ ?? + # map of parameter name to auth service required by it + self.__required_authn_params = required_authn_params + # map of authService -> token_getter + self.__auth_service_token_getters = auth_service_token_getters + + def __copy( + self, + session: Optional[ClientSession] = None, + base_url: Optional[str] = None, + name: Optional[str] = None, + desc: Optional[str] = None, + params: Optional[list[Parameter]] = None, + required_authn_params: Optional[Mapping[str, list[str]]] = None, + auth_service_token_getters: Optional[Mapping[str, Callable[[], str]]] = None, + ): + """ + Creates a copy of the ToolboxTool, overriding specific fields. + + Args: + session: The `aiohttp.ClientSession` used for making API requests. + base_url: The base URL of the Toolbox server API. + name: The name of the remote tool. + desc: The description of the remote tool (used as its docstring). + params: A list of `inspect.Parameter` objects defining the tool's + arguments and their types/defaults. + required_authn_params: A dict of required authenticated parameters that need + a auth_service_token_getter set for them yet. + auth_service_token_getters: A dict of authService -> token (or callables + that produce a token) + + """ + return ToolboxTool( + session=session or self.__session, + base_url=base_url or self.__base_url, + name=name or self.__name__, + desc=desc or self.__desc, + params=params or self.__params, + required_authn_params=required_authn_params or self.__required_authn_params, + auth_service_token_getters=auth_service_token_getters + or self.__auth_service_token_getters, + ) + async def __call__(self, *args: Any, **kwargs: Any) -> str: """ Asynchronously calls the remote tool with the provided arguments. @@ -81,16 +133,96 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str: Returns: The string result returned by the remote tool execution. """ + + # check if any auth services need to be specified yet + if len(self.__required_authn_params) > 0: + req_auth_services = set(l for l in self.__required_authn_params.keys()) + raise Exception( + f"One of more of the following authn services are required to invoke this tool: {','.join(req_auth_services)}" + ) + + # validate inputs to this call using the signature all_args = self.__signature__.bind(*args, **kwargs) all_args.apply_defaults() # Include default values if not provided payload = all_args.arguments + # create headers for auth services + headers = {} + for auth_service, token_getter in self.__auth_service_token_getters.items(): + headers[f"{auth_service}_token"] = token_getter() + async with self.__session.post( self.__url, json=payload, + headers=headers, ) as resp: - ret = await resp.json() - if "error" in ret: - # TODO: better error - raise Exception(ret["error"]) - return ret.get("result", ret) + body = await resp.json() + if resp.status < 200 or resp.status >= 300: + err = body.get("error", f"unexpected status from server: {resp.status}") + raise Exception(err) + return body.get("result", body) + + def add_auth_token_getters( + self, + auth_token_getters: Mapping[str, Callable[[], str]], + ) -> "ToolboxTool": + """ + Registers a auth token getter function that is used for AuthServices when tools + are invoked. + + Args: + auth_token_getters: A mapping of authentication service names to + callables that return the corresponding authentication token. + + Returns: + A new ToolboxTool instance with the specified authentication token + getters registered. + """ + + # throw an error if the authentication source is already registered + dupes = auth_token_getters.keys() & self.__auth_service_token_getters.keys() + if dupes: + raise ValueError( + f"Authentication source(s) `{', '.join(dupes)}` already registered in tool `{self.__name__}`." + ) + + # create a read-only updated value for new_getters + new_getters = types.MappingProxyType( + dict(self.__auth_service_token_getters, **auth_token_getters) + ) + # create a read-only updated for params that are still required + new_req_authn_params = types.MappingProxyType( + filter_required_authn_params( + self.__required_authn_params, auth_token_getters.keys() + ) + ) + + return self.__copy( + auth_service_token_getters=new_getters, + required_authn_params=new_req_authn_params, + ) + + +def filter_required_authn_params( + req_authn_params: Mapping[str, list[str]], auth_services: Iterable[str] +) -> dict[str, list[str]]: + """ + Utility function for reducing 'req_authn_params' to a subset of parameters that aren't supplied by a least one service in auth_services. + + Args: + req_authn_params: A mapping of parameter names to sets of required + authentication services. + auth_services: An iterable of authentication service names for which + token getters are available. + + Returns: + A new dictionary representing the subset of required authentication + parameters that are not covered by the provided `auth_services`. + """ + req_params = {} + for param, services in req_authn_params.items(): + # if we don't have a token_getter for any of the services required by the param, the param is still required + required = not any(s in services for s in auth_services) + if required: + req_params[param] = services + return req_params diff --git a/packages/toolbox-core/tests/test_client.py b/packages/toolbox-core/tests/test_client.py index b19c575b..101fdb91 100644 --- a/packages/toolbox-core/tests/test_client.py +++ b/packages/toolbox-core/tests/test_client.py @@ -16,7 +16,8 @@ import inspect import pytest - +import pytest_asyncio +from aioresponses import CallbackResult from toolbox_core import ToolboxClient from toolbox_core.protocol import ManifestSchema, ParameterSchema, ToolSchema @@ -26,7 +27,7 @@ @pytest.fixture() def test_tool_str(): return ToolSchema( - description="Test Tool 1 Description", + description="Test Tool with String input", parameters=[ ParameterSchema( name="param1", type="string", description="Description of Param1" @@ -37,9 +38,8 @@ def test_tool_str(): @pytest.fixture() def test_tool_int_bool(): - """Fixture for the second test tool schema.""" return ToolSchema( - description="Test Tool 2 Description", + description="Test Tool with Int, Bool", parameters=[ ParameterSchema(name="argA", type="integer", description="Argument A"), ParameterSchema(name="argB", type="boolean", description="Argument B"), @@ -47,6 +47,22 @@ def test_tool_int_bool(): ) +@pytest.fixture() +def test_tool_auth(): + return ToolSchema( + description="Test Tool with Int,Bool+Auth", + parameters=[ + ParameterSchema(name="argA", type="integer", description="Argument A"), + ParameterSchema( + name="argB", + type="boolean", + description="Argument B", + authSources=["my-auth-service"], + ), + ], + ) + + @pytest.mark.asyncio async def test_load_tool_success(aioresponses, test_tool_str): """ @@ -83,6 +99,87 @@ async def test_load_tool_success(aioresponses, test_tool_str): assert await loaded_tool("some value") == "ok" +class TestAuth: + + @pytest.fixture + def expected_header(self): + return "some_token_for_testing" + + @pytest.fixture + def tool_name(self): + return "tool1" + + @pytest_asyncio.fixture + async def client(self, aioresponses, test_tool_auth, tool_name, expected_header): + manifest = ManifestSchema( + serverVersion="0.0.0", tools={tool_name: test_tool_auth} + ) + + # mock tool GET call + aioresponses.get( + f"{TEST_BASE_URL}/api/tool/{tool_name}", + payload=manifest.model_dump(), + status=200, + ) + + # mock tool INVOKE call + def require_headers(url, **kwargs): + if kwargs["headers"].get("my-auth-service_token") == expected_header: + return CallbackResult(status=200, body="{}") + else: + return CallbackResult(status=400, body="{}") + + aioresponses.post( + f"{TEST_BASE_URL}/api/tool/{tool_name}/invoke", + payload=manifest.model_dump(), + callback=require_headers, + status=200, + ) + + async with ToolboxClient(TEST_BASE_URL) as client: + yield client + + @pytest.mark.asyncio + async def test_auth_with_load_tool_success( + self, tool_name, expected_header, client + ): + """Tests 'load_tool' with auth token is specified.""" + + def token_handler(): + return expected_header + + tool = await client.load_tool( + tool_name, auth_service_tokens={"my-auth-service": token_handler} + ) + res = await tool(5) + + @pytest.mark.asyncio + async def test_auth_with_add_token_success( + self, tool_name, expected_header, client + ): + """Tests 'load_tool' with auth token is specified.""" + + def token_handler(): + return expected_header + + tool = await client.load_tool(tool_name) + tool = await client.add_auth_token_getters({"my-auth-service": token_handler}) + res = await tool(5) + + @pytest.mark.asyncio + async def test_auth_with_load_tool_fail_no_token( + self, tool_name, expected_header, client + ): + """Tests 'load_tool' with auth token is specified.""" + + def token_handler(): + return expected_header + + tool = await client.load_tool(tool_name) + with pytest.raises(Exception): + res = await tool(5) + + @pytest.mark.asyncio async def test_load_toolset_success(aioresponses, test_tool_str, test_tool_int_bool): """Tests successfully loading a toolset with multiple tools.""" From 6b263ad2e18b1379eaa78dd68907216604ca2245 Mon Sep 17 00:00:00 2001 From: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com> Date: Sun, 30 Mar 2025 17:22:03 +0000 Subject: [PATCH 29/50] chore: add asyncio dep --- packages/toolbox-core/pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/packages/toolbox-core/pyproject.toml b/packages/toolbox-core/pyproject.toml index edc45a8a..ee8a5f73 100644 --- a/packages/toolbox-core/pyproject.toml +++ b/packages/toolbox-core/pyproject.toml @@ -44,7 +44,8 @@ test = [ "isort==6.0.1", "mypy==1.15.0", "pytest==8.3.5", - "pytest-aioresponses==0.3.0" + "pytest-aioresponses==0.3.0", + "pytest-asyncio==0.25.3", ] [build-system] requires = ["setuptools"] From 5fe541f863ad20ae20c1a8567481de1a8b6b4b95 Mon Sep 17 00:00:00 2001 From: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com> Date: Sun, 30 Mar 2025 17:24:23 +0000 Subject: [PATCH 30/50] chore: run itest --- packages/toolbox-core/tests/test_client.py | 1 + 1 file changed, 1 insertion(+) diff --git a/packages/toolbox-core/tests/test_client.py b/packages/toolbox-core/tests/test_client.py index 101fdb91..b6b1ff85 100644 --- a/packages/toolbox-core/tests/test_client.py +++ b/packages/toolbox-core/tests/test_client.py @@ -18,6 +18,7 @@ import pytest import pytest_asyncio from aioresponses import CallbackResult + from toolbox_core import ToolboxClient from toolbox_core.protocol import ManifestSchema, ParameterSchema, ToolSchema From e9d7a3103d7bac63c45dd502427a0270e45ae96f Mon Sep 17 00:00:00 2001 From: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com> Date: Sun, 30 Mar 2025 17:26:39 +0000 Subject: [PATCH 31/50] chore: add type hint --- packages/toolbox-core/src/toolbox_core/tool.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/toolbox-core/src/toolbox_core/tool.py b/packages/toolbox-core/src/toolbox_core/tool.py index 494c7c21..17c44f63 100644 --- a/packages/toolbox-core/src/toolbox_core/tool.py +++ b/packages/toolbox-core/src/toolbox_core/tool.py @@ -91,7 +91,7 @@ def __copy( params: Optional[list[Parameter]] = None, required_authn_params: Optional[Mapping[str, list[str]]] = None, auth_service_token_getters: Optional[Mapping[str, Callable[[], str]]] = None, - ): + ) -> "ToolboxTool": """ Creates a copy of the ToolboxTool, overriding specific fields. From 58c55cfc4d565a3255d07bbaa8bbd4b81ff8651e Mon Sep 17 00:00:00 2001 From: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com> Date: Sun, 30 Mar 2025 17:32:36 +0000 Subject: [PATCH 32/50] fix: call tool instead of client --- packages/toolbox-core/tests/test_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/toolbox-core/tests/test_client.py b/packages/toolbox-core/tests/test_client.py index b6b1ff85..aebf3133 100644 --- a/packages/toolbox-core/tests/test_client.py +++ b/packages/toolbox-core/tests/test_client.py @@ -164,7 +164,7 @@ def token_handler(): return expected_header tool = await client.load_tool(tool_name) - tool = await client.add_auth_token_getters({"my-auth-service": token_handler}) + tool = tool.add_auth_token_getters({"my-auth-service": token_handler}) res = await tool(5) @pytest.mark.asyncio From c1a482a762a5f7d97df3ff7667a7f243ad185fc2 Mon Sep 17 00:00:00 2001 From: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com> Date: Mon, 31 Mar 2025 15:33:05 +0000 Subject: [PATCH 33/50] chore: correct arg name --- packages/toolbox-core/src/toolbox_core/client.py | 9 +++++++-- packages/toolbox-core/tests/test_client.py | 2 +- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/packages/toolbox-core/src/toolbox_core/client.py b/packages/toolbox-core/src/toolbox_core/client.py index 0fcca0d2..1acd521a 100644 --- a/packages/toolbox-core/src/toolbox_core/client.py +++ b/packages/toolbox-core/src/toolbox_core/client.py @@ -119,7 +119,7 @@ async def close(self): async def load_tool( self, name: str, - auth_service_tokens: dict[str, Callable[[], str]] = {}, + auth_token_getters: dict[str, Callable[[], str]] = {}, ) -> ToolboxTool: """ Asynchronously loads a tool from the server. @@ -130,6 +130,8 @@ async def load_tool( Args: name: The unique name or identifier of the tool to load. + auth_token_getters: A mapping of authentication service names to + callables that return the corresponding authentication token. Returns: ToolboxTool: A callable object representing the loaded tool, ready @@ -148,7 +150,7 @@ async def load_tool( if name not in manifest.tools: # TODO: Better exception raise Exception(f"Tool '{name}' not found!") - tool = self.__parse_tool(name, manifest.tools[name], auth_service_tokens) + tool = self.__parse_tool(name, manifest.tools[name], auth_token_getters) return tool @@ -162,6 +164,9 @@ async def load_toolset( Args: name: Name of the toolset to load tools. + auth_token_getters: A mapping of authentication service names to + callables that return the corresponding authentication token. + Returns: list[ToolboxTool]: A list of callables, one for each tool defined diff --git a/packages/toolbox-core/tests/test_client.py b/packages/toolbox-core/tests/test_client.py index aebf3133..91ea2259 100644 --- a/packages/toolbox-core/tests/test_client.py +++ b/packages/toolbox-core/tests/test_client.py @@ -150,7 +150,7 @@ def token_handler(): return expected_header tool = await client.load_tool( - tool_name, auth_service_tokens={"my-auth-service": token_handler} + tool_name, auth_token_getters={"my-auth-service": token_handler} ) res = await tool(5) From c1ac2cd3033c64287cc1c273c75bc1c4fea88a75 Mon Sep 17 00:00:00 2001 From: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com> Date: Wed, 2 Apr 2025 02:21:36 +0000 Subject: [PATCH 34/50] chore: address feedback --- .../toolbox-core/src/toolbox_core/client.py | 4 +- .../toolbox-core/src/toolbox_core/tool.py | 50 +++++++++++-------- 2 files changed, 31 insertions(+), 23 deletions(-) diff --git a/packages/toolbox-core/src/toolbox_core/client.py b/packages/toolbox-core/src/toolbox_core/client.py index 1acd521a..ca17a081 100644 --- a/packages/toolbox-core/src/toolbox_core/client.py +++ b/packages/toolbox-core/src/toolbox_core/client.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import re import types from typing import Any, Callable, Optional @@ -79,8 +80,9 @@ def __parse_tool( name=name, desc=schema.description, params=[p.to_param() for p in params], + # create a read-only values for the maps to prevent mutation required_authn_params=types.MappingProxyType(authn_params), - auth_service_token_getters=auth_token_getters, + auth_service_token_getters=types.MappingProxyType(auth_token_getters), ) return tool diff --git a/packages/toolbox-core/src/toolbox_core/tool.py b/packages/toolbox-core/src/toolbox_core/tool.py index 17c44f63..bf9ddc34 100644 --- a/packages/toolbox-core/src/toolbox_core/tool.py +++ b/packages/toolbox-core/src/toolbox_core/tool.py @@ -56,9 +56,9 @@ def __init__( desc: The description of the remote tool (used as its docstring). params: A list of `inspect.Parameter` objects defining the tool's arguments and their types/defaults. - required_authn_params: A dict of required authenticated parameters that - need a auth_service_token_getter set for them yet. - auth_service_tokens: A dict of authService -> token (or callables that + required_authn_params: A dict of required authenticated parameters to a list + of services that provide values for them. + auth_service_token_getters: A dict of authService -> token (or callables that produce a token) """ @@ -108,15 +108,19 @@ def __copy( that produce a token) """ + check = lambda val, default: val if val is not None else default return ToolboxTool( - session=session or self.__session, - base_url=base_url or self.__base_url, - name=name or self.__name__, - desc=desc or self.__desc, - params=params or self.__params, - required_authn_params=required_authn_params or self.__required_authn_params, - auth_service_token_getters=auth_service_token_getters - or self.__auth_service_token_getters, + session=check(session, self.__session), + base_url=check(base_url, self.__base_url), + name=check(name, self.__name__), + desc=check(desc, self.__desc), + params=check(params, self.__params), + required_authn_params=check( + required_authn_params, self.__required_authn_params + ), + auth_service_token_getters=check( + auth_service_token_getters, self.__auth_service_token_getters + ), ) async def __call__(self, *args: Any, **kwargs: Any) -> str: @@ -138,7 +142,7 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str: if len(self.__required_authn_params) > 0: req_auth_services = set(l for l in self.__required_authn_params.keys()) raise Exception( - f"One of more of the following authn services are required to invoke this tool: {','.join(req_auth_services)}" + f"One or more of the following authn services are required to invoke this tool: {','.join(req_auth_services)}" ) # validate inputs to this call using the signature @@ -167,7 +171,7 @@ def add_auth_token_getters( auth_token_getters: Mapping[str, Callable[[], str]], ) -> "ToolboxTool": """ - Registers a auth token getter function that is used for AuthServices when tools + Registers an auth token getter function that is used for AuthServices when tools are invoked. Args: @@ -204,25 +208,27 @@ def add_auth_token_getters( def filter_required_authn_params( - req_authn_params: Mapping[str, list[str]], auth_services: Iterable[str] + req_authn_params: Mapping[str, list[str]], auth_service_names: Iterable[str] ) -> dict[str, list[str]]: """ - Utility function for reducing 'req_authn_params' to a subset of parameters that aren't supplied by a least one service in auth_services. + Utility function for reducing 'req_authn_params' to a subset of parameters that + aren't supplied by at least one service in auth_services. Args: req_authn_params: A mapping of parameter names to sets of required authentication services. - auth_services: An iterable of authentication service names for which + auth_service_names: An iterable of authentication service names for which token getters are available. Returns: A new dictionary representing the subset of required authentication - parameters that are not covered by the provided `auth_services`. + parameters that are not covered by the provided `auth_service_names`. """ - req_params = {} + required_params = {} # params that are still required with provided auth_services for param, services in req_authn_params.items(): - # if we don't have a token_getter for any of the services required by the param, the param is still required - required = not any(s in services for s in auth_services) + # if we don't have a token_getter for any of the services required by the param, + # the param is still required + required = not any(s in services for s in auth_service_names) if required: - req_params[param] = services - return req_params + required_params[param] = services + return required_params From bcba462a8275234da8c81a180935978219ede269 Mon Sep 17 00:00:00 2001 From: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com> Date: Wed, 2 Apr 2025 02:30:48 +0000 Subject: [PATCH 35/50] chore: address more feedback --- .../toolbox-core/src/toolbox_core/client.py | 6 ++++-- .../toolbox-core/src/toolbox_core/tool.py | 21 ++++++++++++------- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/packages/toolbox-core/src/toolbox_core/client.py b/packages/toolbox-core/src/toolbox_core/client.py index ca17a081..9b69f6d1 100644 --- a/packages/toolbox-core/src/toolbox_core/client.py +++ b/packages/toolbox-core/src/toolbox_core/client.py @@ -18,7 +18,7 @@ from aiohttp import ClientSession from .protocol import ManifestSchema, ToolSchema -from .tool import ToolboxTool, filter_required_authn_params +from .tool import ToolboxTool, identify_required_authn_params class ToolboxClient: @@ -72,7 +72,9 @@ def __parse_tool( authn_params[p.name] = p.authSources auth_sources.update(p.authSources) - authn_params = filter_required_authn_params(authn_params, auth_sources) + authn_params = identify_required_authn_params( + authn_params, auth_token_getters.keys() + ) tool = ToolboxTool( session=self.__session, diff --git a/packages/toolbox-core/src/toolbox_core/tool.py b/packages/toolbox-core/src/toolbox_core/tool.py index bf9ddc34..6421cd99 100644 --- a/packages/toolbox-core/src/toolbox_core/tool.py +++ b/packages/toolbox-core/src/toolbox_core/tool.py @@ -140,7 +140,10 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str: # check if any auth services need to be specified yet if len(self.__required_authn_params) > 0: - req_auth_services = set(l for l in self.__required_authn_params.keys()) + # Gather all the required auth services into a set + req_auth_services = set() + for s in self.__required_authn_params.values(): + req_auth_services.update(s) raise Exception( f"One or more of the following authn services are required to invoke this tool: {','.join(req_auth_services)}" ) @@ -184,10 +187,12 @@ def add_auth_token_getters( """ # throw an error if the authentication source is already registered - dupes = auth_token_getters.keys() & self.__auth_service_token_getters.keys() - if dupes: + existing_services = self.__auth_service_token_getters.keys() + incoming_services = auth_token_getters.keys() + duplicates = existing_services & incoming_services + if duplicates: raise ValueError( - f"Authentication source(s) `{', '.join(dupes)}` already registered in tool `{self.__name__}`." + f"Authentication source(s) `{', '.join(duplicates)}` already registered in tool `{self.__name__}`." ) # create a read-only updated value for new_getters @@ -196,7 +201,7 @@ def add_auth_token_getters( ) # create a read-only updated for params that are still required new_req_authn_params = types.MappingProxyType( - filter_required_authn_params( + identify_required_authn_params( self.__required_authn_params, auth_token_getters.keys() ) ) @@ -207,12 +212,12 @@ def add_auth_token_getters( ) -def filter_required_authn_params( +def identify_required_authn_params( req_authn_params: Mapping[str, list[str]], auth_service_names: Iterable[str] ) -> dict[str, list[str]]: """ - Utility function for reducing 'req_authn_params' to a subset of parameters that - aren't supplied by at least one service in auth_services. + Identifies authentication parameters that are still required; or not covered by + the provided `auth_service_names`. Args: req_authn_params: A mapping of parameter names to sets of required From c8491a920c0b246e9da101b6430e8dbd395f7bfc Mon Sep 17 00:00:00 2001 From: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com> Date: Mon, 31 Mar 2025 23:02:52 +0000 Subject: [PATCH 36/50] feat: add support for bound parameters --- .../toolbox-core/src/toolbox_core/client.py | 22 +++++-- .../toolbox-core/src/toolbox_core/tool.py | 61 +++++++++++++++++-- 2 files changed, 72 insertions(+), 11 deletions(-) diff --git a/packages/toolbox-core/src/toolbox_core/client.py b/packages/toolbox-core/src/toolbox_core/client.py index 9b69f6d1..3fd198e2 100644 --- a/packages/toolbox-core/src/toolbox_core/client.py +++ b/packages/toolbox-core/src/toolbox_core/client.py @@ -59,18 +59,22 @@ def __parse_tool( name: str, schema: ToolSchema, auth_token_getters: dict[str, Callable[[], str]], + all_bound_params: dict[str, Callable[[], str]], ) -> ToolboxTool: """Internal helper to create a callable tool from its schema.""" - # sort into authenticated and reg params + # sort into reg, authn, and bound params params = [] authn_params: dict[str, list[str]] = {} + bound_params: dict[str, Callable[[], str]] = {} auth_sources: set[str] = set() for p in schema.parameters: - if not p.authSources: - params.append(p) - else: + if p.authSources: # authn parameter authn_params[p.name] = p.authSources auth_sources.update(p.authSources) + elif p.name in all_bound_params: # bound parameter + bound_params[p.name] = all_bound_params[p.name] + else: # regular parameter + params.append(p) authn_params = identify_required_authn_params( authn_params, auth_token_getters.keys() @@ -85,6 +89,7 @@ def __parse_tool( # create a read-only values for the maps to prevent mutation required_authn_params=types.MappingProxyType(authn_params), auth_service_token_getters=types.MappingProxyType(auth_token_getters), + bound_params=types.MappingProxyType(bound_params), ) return tool @@ -124,6 +129,7 @@ async def load_tool( self, name: str, auth_token_getters: dict[str, Callable[[], str]] = {}, + bound_params: dict[str, Callable[[], str]] = {}, ) -> ToolboxTool: """ Asynchronously loads a tool from the server. @@ -154,7 +160,9 @@ async def load_tool( if name not in manifest.tools: # TODO: Better exception raise Exception(f"Tool '{name}' not found!") - tool = self.__parse_tool(name, manifest.tools[name], auth_token_getters) + tool = self.__parse_tool( + name, manifest.tools[name], auth_token_getters, bound_params + ) return tool @@ -162,6 +170,7 @@ async def load_toolset( self, name: str, auth_token_getters: dict[str, Callable[[], str]] = {}, + bound_params: dict[str, Callable[[], str]] = {}, ) -> list[ToolboxTool]: """ Asynchronously fetches a toolset and loads all tools defined within it. @@ -172,6 +181,7 @@ async def load_toolset( callables that return the corresponding authentication token. + Returns: list[ToolboxTool]: A list of callables, one for each tool defined in the toolset. @@ -184,7 +194,7 @@ async def load_toolset( # parse each tools name and schema into a list of ToolboxTools tools = [ - self.__parse_tool(n, s, auth_token_getters) + self.__parse_tool(n, s, auth_token_getters, bound_params) for n, s in manifest.tools.items() ] return tools diff --git a/packages/toolbox-core/src/toolbox_core/tool.py b/packages/toolbox-core/src/toolbox_core/tool.py index 6421cd99..574706f4 100644 --- a/packages/toolbox-core/src/toolbox_core/tool.py +++ b/packages/toolbox-core/src/toolbox_core/tool.py @@ -13,10 +13,20 @@ # limitations under the License. +import asyncio import types from collections import defaultdict from inspect import Parameter, Signature -from typing import Any, Callable, DefaultDict, Iterable, Mapping, Optional, Sequence +from typing import ( + Any, + Callable, + DefaultDict, + Iterable, + Mapping, + Optional, + Sequence, + Union, +) from aiohttp import ClientSession from pytest import Session @@ -44,6 +54,7 @@ def __init__( params: Sequence[Parameter], required_authn_params: Mapping[str, list[str]], auth_service_token_getters: Mapping[str, Callable[[], str]], + bound_params: Mapping[str, Union[Callable[[], Any], Any]] = {}, ): """ Initializes a callable that will trigger the tool invocation through the @@ -81,6 +92,8 @@ def __init__( self.__required_authn_params = required_authn_params # map of authService -> token_getter self.__auth_service_token_getters = auth_service_token_getters + # map of parameter name to value or Callable + self.__bound_parameters = bound_params def __copy( self, @@ -91,6 +104,7 @@ def __copy( params: Optional[list[Parameter]] = None, required_authn_params: Optional[Mapping[str, list[str]]] = None, auth_service_token_getters: Optional[Mapping[str, Callable[[], str]]] = None, + bound_params: Optional[Mapping[str, Union[Callable[[], Any], Any]]] = None, ) -> "ToolboxTool": """ Creates a copy of the ToolboxTool, overriding specific fields. @@ -121,6 +135,7 @@ def __copy( auth_service_token_getters=check( auth_service_token_getters, self.__auth_service_token_getters ), + bound_params=check(bound_params, self.__bound_parameters), ) async def __call__(self, *args: Any, **kwargs: Any) -> str: @@ -153,6 +168,14 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str: all_args.apply_defaults() # Include default values if not provided payload = all_args.arguments + # apply bounded parameters + for param, value in self.__bound_parameters.items(): + if asyncio.iscoroutinefunction(value): + value = await value() + elif callable(value): + value = value() + payload[param] = value + # create headers for auth services headers = {} for auth_service, token_getter in self.__auth_service_token_getters.items(): @@ -211,13 +234,41 @@ def add_auth_token_getters( required_authn_params=new_req_authn_params, ) + def bind_parameters( + self, bound_params: Mapping[str, Callable[[], str]] + ) -> "ToolboxTool": + """ + Binds parameters to values or callables that produce values. + + Args: + bound_params: A mapping of parameter names to values or callables that + produce values. + + Returns: + A new ToolboxTool instance with the specified parameters bound. + """ + all_params = set(p.name for p in self.__params) + for name in bound_params.keys(): + if name not in all_params: + raise Exception(f"unable to bind parameters: no parameter named {name}") + + new_params = [] + for p in self.__params: + if p.name not in bound_params: + new_params.append(p) + + return self.__copy( + params=new_params, + bound_params=bound_params, + ) + def identify_required_authn_params( req_authn_params: Mapping[str, list[str]], auth_service_names: Iterable[str] ) -> dict[str, list[str]]: """ - Identifies authentication parameters that are still required; or not covered by - the provided `auth_service_names`. + Identifies authentication parameters that are still required; because they + not covered by the provided `auth_service_names`. Args: req_authn_params: A mapping of parameter names to sets of required @@ -226,8 +277,8 @@ def identify_required_authn_params( token getters are available. Returns: - A new dictionary representing the subset of required authentication - parameters that are not covered by the provided `auth_service_names`. + A new dictionary representing the subset of required authentication parameters + that are not covered by the provided `auth_services`. """ required_params = {} # params that are still required with provided auth_services for param, services in req_authn_params.items(): From c26c453bad514f3b50c7c2b46000007d32e18380 Mon Sep 17 00:00:00 2001 From: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com> Date: Mon, 31 Mar 2025 23:03:47 +0000 Subject: [PATCH 37/50] chore: add tests for bound parameters --- packages/toolbox-core/tests/test_client.py | 133 +++++++++++++++++---- 1 file changed, 113 insertions(+), 20 deletions(-) diff --git a/packages/toolbox-core/tests/test_client.py b/packages/toolbox-core/tests/test_client.py index 91ea2259..e25573e2 100644 --- a/packages/toolbox-core/tests/test_client.py +++ b/packages/toolbox-core/tests/test_client.py @@ -14,6 +14,7 @@ import inspect +import json import pytest import pytest_asyncio @@ -100,6 +101,31 @@ async def test_load_tool_success(aioresponses, test_tool_str): assert await loaded_tool("some value") == "ok" +@pytest.mark.asyncio +async def test_load_toolset_success(aioresponses, test_tool_str, test_tool_int_bool): + """Tests successfully loading a toolset with multiple tools.""" + TOOLSET_NAME = "my_toolset" + TOOL1 = "tool1" + TOOL2 = "tool2" + manifest = ManifestSchema( + serverVersion="0.0.0", tools={TOOL1: test_tool_str, TOOL2: test_tool_int_bool} + ) + aioresponses.get( + f"{TEST_BASE_URL}/api/toolset/{TOOLSET_NAME}", + payload=manifest.model_dump(), + status=200, + ) + + async with ToolboxClient(TEST_BASE_URL) as client: + tools = await client.load_toolset(TOOLSET_NAME) + + assert isinstance(tools, list) + assert len(tools) == len(manifest.tools) + + # Check if tools were created correctly + assert {t.__name__ for t in tools} == manifest.tools.keys() + + class TestAuth: @pytest.fixture @@ -181,26 +207,93 @@ def token_handler(): res = await tool(5) -@pytest.mark.asyncio -async def test_load_toolset_success(aioresponses, test_tool_str, test_tool_int_bool): - """Tests successfully loading a toolset with multiple tools.""" - TOOLSET_NAME = "my_toolset" - TOOL1 = "tool1" - TOOL2 = "tool2" - manifest = ManifestSchema( - serverVersion="0.0.0", tools={TOOL1: test_tool_str, TOOL2: test_tool_int_bool} - ) - aioresponses.get( - f"{TEST_BASE_URL}/api/toolset/{TOOLSET_NAME}", - payload=manifest.model_dump(), - status=200, - ) +class TestBoundParameter: - async with ToolboxClient(TEST_BASE_URL) as client: - tools = await client.load_toolset(TOOLSET_NAME) + @pytest.fixture + def tool_name(self): + return "tool1" - assert isinstance(tools, list) - assert len(tools) == len(manifest.tools) + @pytest_asyncio.fixture + async def client(self, aioresponses, test_tool_int_bool, tool_name): + manifest = ManifestSchema( + serverVersion="0.0.0", tools={tool_name: test_tool_int_bool} + ) - # Check if tools were created correctly - assert {t.__name__ for t in tools} == manifest.tools.keys() + # mock toolset GET call + aioresponses.get( + f"{TEST_BASE_URL}/api/toolset/", + payload=manifest.model_dump(), + status=200, + ) + + # mock tool GET call + aioresponses.get( + f"{TEST_BASE_URL}/api/tool/{tool_name}", + payload=manifest.model_dump(), + status=200, + ) + + # mock tool INVOKE call + def reflect_parameters(url, **kwargs): + body = {"result": kwargs["json"]} + return CallbackResult(status=200, body=json.dumps(body)) + + aioresponses.post( + f"{TEST_BASE_URL}/api/tool/{tool_name}/invoke", + payload=manifest.model_dump(), + callback=reflect_parameters, + status=200, + ) + + async with ToolboxClient(TEST_BASE_URL) as client: + yield client + + @pytest.mark.asyncio + async def test_load_tool_success(self, tool_name, client): + """Tests 'load_tool' with a bound parameter specified.""" + tool = await client.load_tool(tool_name, bound_params={"argA": lambda: 5}) + + assert len(tool.__signature__.parameters) == 1 + assert "argA" not in tool.__signature__.parameters + + res = await tool(True) + assert "argA" in res + + @pytest.mark.asyncio + async def test_load_toolset_success(self, tool_name, client): + """Tests 'load_toolset' with a bound parameter specified.""" + tools = await client.load_toolset("", bound_params={"argB": lambda: "hello"}) + tool = tools[0] + + assert len(tool.__signature__.parameters) == 1 + assert "argB" not in tool.__signature__.parameters + + res = await tool(True) + assert "argB" in res + + @pytest.mark.asyncio + async def test_bind_param_success(self, tool_name, client): + """Tests 'bind_param' with a bound parameter specified.""" + tool = await client.load_tool(tool_name) + + assert len(tool.__signature__.parameters) == 2 + assert "argA" in tool.__signature__.parameters + + tool = tool.bind_parameters({"argA": lambda: 5}) + + assert len(tool.__signature__.parameters) == 1 + assert "argA" not in tool.__signature__.parameters + + res = await tool(True) + assert "argA" in res + + @pytest.mark.asyncio + async def test_bind_param_fail(self, tool_name, client): + """Tests 'bind_param' with a bound parameter that doesn't exist.""" + tool = await client.load_tool(tool_name) + + assert len(tool.__signature__.parameters) == 2 + assert "argA" in tool.__signature__.parameters + + with pytest.raises(Exception): + tool = tool.bind_parameters({"argC": lambda: 5}) From 3abe9e469476416067580c62f18d16b7990e2c1b Mon Sep 17 00:00:00 2001 From: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com> Date: Wed, 2 Apr 2025 02:48:14 +0000 Subject: [PATCH 38/50] chore: address feedback --- .../toolbox-core/src/toolbox_core/client.py | 14 ++++++++--- .../toolbox-core/src/toolbox_core/tool.py | 25 +++++++++++-------- 2 files changed, 25 insertions(+), 14 deletions(-) diff --git a/packages/toolbox-core/src/toolbox_core/client.py b/packages/toolbox-core/src/toolbox_core/client.py index 3fd198e2..ecf67f0c 100644 --- a/packages/toolbox-core/src/toolbox_core/client.py +++ b/packages/toolbox-core/src/toolbox_core/client.py @@ -13,7 +13,7 @@ # limitations under the License. import re import types -from typing import Any, Callable, Optional +from typing import Any, Callable, Mapping, Optional, Union from aiohttp import ClientSession @@ -59,7 +59,7 @@ def __parse_tool( name: str, schema: ToolSchema, auth_token_getters: dict[str, Callable[[], str]], - all_bound_params: dict[str, Callable[[], str]], + all_bound_params: Mapping[str, Union[Callable[[], Any], Any]], ) -> ToolboxTool: """Internal helper to create a callable tool from its schema.""" # sort into reg, authn, and bound params @@ -129,7 +129,7 @@ async def load_tool( self, name: str, auth_token_getters: dict[str, Callable[[], str]] = {}, - bound_params: dict[str, Callable[[], str]] = {}, + bound_params: Mapping[str, Union[Callable[[], Any], Any]] = {}, ) -> ToolboxTool: """ Asynchronously loads a tool from the server. @@ -142,6 +142,10 @@ async def load_tool( name: The unique name or identifier of the tool to load. auth_token_getters: A mapping of authentication service names to callables that return the corresponding authentication token. + bound_params: A mapping of parameter names to bind to specific values or + callables that are called to produce values as needed. + + Returns: ToolboxTool: A callable object representing the loaded tool, ready @@ -170,7 +174,7 @@ async def load_toolset( self, name: str, auth_token_getters: dict[str, Callable[[], str]] = {}, - bound_params: dict[str, Callable[[], str]] = {}, + bound_params: Mapping[str, Union[Callable[[], Any], Any]] = {}, ) -> list[ToolboxTool]: """ Asynchronously fetches a toolset and loads all tools defined within it. @@ -179,6 +183,8 @@ async def load_toolset( name: Name of the toolset to load tools. auth_token_getters: A mapping of authentication service names to callables that return the corresponding authentication token. + bound_params: A mapping of parameter names to bind to specific values or + callables that are called to produce values as needed. diff --git a/packages/toolbox-core/src/toolbox_core/tool.py b/packages/toolbox-core/src/toolbox_core/tool.py index 574706f4..58180341 100644 --- a/packages/toolbox-core/src/toolbox_core/tool.py +++ b/packages/toolbox-core/src/toolbox_core/tool.py @@ -54,7 +54,7 @@ def __init__( params: Sequence[Parameter], required_authn_params: Mapping[str, list[str]], auth_service_token_getters: Mapping[str, Callable[[], str]], - bound_params: Mapping[str, Union[Callable[[], Any], Any]] = {}, + bound_params: Mapping[str, Union[Callable[[], Any], Any]], ): """ Initializes a callable that will trigger the tool invocation through the @@ -71,6 +71,9 @@ def __init__( of services that provide values for them. auth_service_token_getters: A dict of authService -> token (or callables that produce a token) + bound_params: A mapping of parameter names to bind to specific values or + callables that are called to produce values as needed. + """ # used to invoke the toolbox API @@ -92,7 +95,7 @@ def __init__( self.__required_authn_params = required_authn_params # map of authService -> token_getter self.__auth_service_token_getters = auth_service_token_getters - # map of parameter name to value or Callable + # map of parameter name to value (or callable that produces that value) self.__bound_parameters = bound_params def __copy( @@ -120,6 +123,8 @@ def __copy( a auth_service_token_getter set for them yet. auth_service_token_getters: A dict of authService -> token (or callables that produce a token) + bound_params: A mapping of parameter names to bind to specific values or + callables that are called to produce values as needed. """ check = lambda val, default: val if val is not None else default @@ -235,7 +240,7 @@ def add_auth_token_getters( ) def bind_parameters( - self, bound_params: Mapping[str, Callable[[], str]] + self, bound_params: Mapping[str, Union[Callable[[], Any], Any]] ) -> "ToolboxTool": """ Binds parameters to values or callables that produce values. @@ -247,9 +252,9 @@ def bind_parameters( Returns: A new ToolboxTool instance with the specified parameters bound. """ - all_params = set(p.name for p in self.__params) + param_names = set(p.name for p in self.__params) for name in bound_params.keys(): - if name not in all_params: + if name not in param_names: raise Exception(f"unable to bind parameters: no parameter named {name}") new_params = [] @@ -270,11 +275,11 @@ def identify_required_authn_params( Identifies authentication parameters that are still required; because they not covered by the provided `auth_service_names`. - Args: - req_authn_params: A mapping of parameter names to sets of required - authentication services. - auth_service_names: An iterable of authentication service names for which - token getters are available. + Args: + req_authn_params: A mapping of parameter names to sets of required + authentication services. + auth_service_names: An iterable of authentication service names for which + token getters are available. Returns: A new dictionary representing the subset of required authentication parameters From 9482f37482a3e0b31f67d2639ecd9c8eefae9421 Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Wed, 2 Apr 2025 11:25:47 +0530 Subject: [PATCH 39/50] revert package file changes --- .../toolbox-core/src/toolbox_core/client.py | 22 ++--- .../toolbox-core/src/toolbox_core/tool.py | 89 ++++++++----------- 2 files changed, 44 insertions(+), 67 deletions(-) diff --git a/packages/toolbox-core/src/toolbox_core/client.py b/packages/toolbox-core/src/toolbox_core/client.py index 27ad72be..5afaa4db 100644 --- a/packages/toolbox-core/src/toolbox_core/client.py +++ b/packages/toolbox-core/src/toolbox_core/client.py @@ -11,15 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -import re import types -from typing import Any, Callable, Mapping, Optional, Union +from typing import Any, Callable, Optional from aiohttp import ClientSession from .protocol import ManifestSchema, ToolSchema -from .tool import ToolboxTool, identify_required_authn_params +from .tool import ToolboxTool, filter_required_authn_params class ToolboxClient: @@ -60,7 +58,7 @@ def __parse_tool( name: str, schema: ToolSchema, auth_token_getters: dict[str, Callable[[], str]], - all_bound_params: Mapping[str, Union[Callable[[], Any], Any]], + all_bound_params: dict[str, Callable[[], str]], ) -> ToolboxTool: """Internal helper to create a callable tool from its schema.""" # sort into reg, authn, and bound params @@ -77,9 +75,7 @@ def __parse_tool( else: # regular parameter params.append(p) - authn_params = identify_required_authn_params( - authn_params, auth_token_getters.keys() - ) + authn_params = filter_required_authn_params(authn_params, auth_sources) tool = ToolboxTool( session=self.__session, @@ -87,8 +83,6 @@ def __parse_tool( name=name, desc=schema.description, params=[p.to_param() for p in params], - - # create a read-only values for the maps to prevent mutation required_authn_params=types.MappingProxyType(authn_params), auth_service_token_getters=types.MappingProxyType(auth_token_getters), bound_params=types.MappingProxyType(bound_params), @@ -131,7 +125,7 @@ async def load_tool( self, name: str, auth_token_getters: dict[str, Callable[[], str]] = {}, - bound_params: Mapping[str, Union[Callable[[], Any], Any]] = {}, + bound_params: dict[str, Callable[[], str]] = {}, ) -> ToolboxTool: """ Asynchronously loads a tool from the server. @@ -144,8 +138,6 @@ async def load_tool( name: The unique name or identifier of the tool to load. auth_token_getters: A mapping of authentication service names to callables that return the corresponding authentication token. - bound_params: A mapping of parameter names to bind to specific values or - callables that are called to produce values as needed. Returns: ToolboxTool: A callable object representing the loaded tool, ready @@ -174,7 +166,7 @@ async def load_toolset( self, name: str, auth_token_getters: dict[str, Callable[[], str]] = {}, - bound_params: Mapping[str, Union[Callable[[], Any], Any]] = {}, + bound_params: dict[str, Callable[[], str]] = {}, ) -> list[ToolboxTool]: """ Asynchronously fetches a toolset and loads all tools defined within it. @@ -183,8 +175,6 @@ async def load_toolset( name: Name of the toolset to load tools. auth_token_getters: A mapping of authentication service names to callables that return the corresponding authentication token. - bound_params: A mapping of parameter names to bind to specific values or - callables that are called to produce values as needed. diff --git a/packages/toolbox-core/src/toolbox_core/tool.py b/packages/toolbox-core/src/toolbox_core/tool.py index d857b6e1..2e88c179 100644 --- a/packages/toolbox-core/src/toolbox_core/tool.py +++ b/packages/toolbox-core/src/toolbox_core/tool.py @@ -54,7 +54,7 @@ def __init__( params: Sequence[Parameter], required_authn_params: Mapping[str, list[str]], auth_service_token_getters: Mapping[str, Callable[[], str]], - bound_params: Mapping[str, Union[Callable[[], Any], Any]], + bound_params: Mapping[str, Union[Callable[[], Any], Any]] = {}, ): """ Initializes a callable that will trigger the tool invocation through the @@ -67,12 +67,10 @@ def __init__( desc: The description of the remote tool (used as its docstring). params: A list of `inspect.Parameter` objects defining the tool's arguments and their types/defaults. - required_authn_params: A dict of required authenticated parameters to a list - of services that provide values for them. - auth_service_token_getters: A dict of authService -> token (or callables that + required_authn_params: A dict of required authenticated parameters that + need a auth_service_token_getter set for them yet. + auth_service_tokens: A dict of authService -> token (or callables that produce a token) - bound_params: A mapping of parameter names to bind to specific values or - callables that are called to produce values as needed. """ # used to invoke the toolbox API @@ -94,7 +92,7 @@ def __init__( self.__required_authn_params = required_authn_params # map of authService -> token_getter self.__auth_service_token_getters = auth_service_token_getters - # map of parameter name to value (or callable that produces that value) + # map of parameter name to value or Callable self.__bound_parameters = bound_params def __copy( @@ -122,24 +120,18 @@ def __copy( a auth_service_token_getter set for them yet. auth_service_token_getters: A dict of authService -> token (or callables that produce a token) - bound_params: A mapping of parameter names to bind to specific values or - callables that are called to produce values as needed. """ - check = lambda val, default: val if val is not None else default return ToolboxTool( - session=check(session, self.__session), - base_url=check(base_url, self.__base_url), - name=check(name, self.__name__), - desc=check(desc, self.__desc), - params=check(params, self.__params), - required_authn_params=check( - required_authn_params, self.__required_authn_params - ), - auth_service_token_getters=check( - auth_service_token_getters, self.__auth_service_token_getters - ), - bound_params=check(bound_params, self.__bound_parameters), + session=session or self.__session, + base_url=base_url or self.__base_url, + name=name or self.__name__, + desc=desc or self.__desc, + params=params or self.__params, + required_authn_params=required_authn_params or self.__required_authn_params, + auth_service_token_getters=auth_service_token_getters + or self.__auth_service_token_getters, + bound_params=bound_params or self.__bound_parameters, ) async def __call__(self, *args: Any, **kwargs: Any) -> str: @@ -159,12 +151,9 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str: # check if any auth services need to be specified yet if len(self.__required_authn_params) > 0: - # Gather all the required auth services into a set - req_auth_services = set() - for s in self.__required_authn_params.values(): - req_auth_services.update(s) + req_auth_services = set(l for l in self.__required_authn_params.keys()) raise Exception( - f"One or more of the following authn services are required to invoke this tool: {','.join(req_auth_services)}" + f"One of more of the following authn services are required to invoke this tool: {','.join(req_auth_services)}" ) # validate inputs to this call using the signature @@ -214,12 +203,10 @@ def add_auth_token_getters( """ # throw an error if the authentication source is already registered - existing_services = self.__auth_service_token_getters.keys() - incoming_services = auth_token_getters.keys() - duplicates = existing_services & incoming_services - if duplicates: + dupes = auth_token_getters.keys() & self.__auth_service_token_getters.keys() + if dupes: raise ValueError( - f"Authentication source(s) `{', '.join(duplicates)}` already registered in tool `{self.__name__}`." + f"Authentication source(s) `{', '.join(dupes)}` already registered in tool `{self.__name__}`." ) # create a read-only updated value for new_getters @@ -228,7 +215,7 @@ def add_auth_token_getters( ) # create a read-only updated for params that are still required new_req_authn_params = types.MappingProxyType( - identify_required_authn_params( + filter_required_authn_params( self.__required_authn_params, auth_token_getters.keys() ) ) @@ -239,7 +226,7 @@ def add_auth_token_getters( ) def bind_parameters( - self, bound_params: Mapping[str, Union[Callable[[], Any], Any]] + self, bound_params: Mapping[str, Callable[[], str]] ) -> "ToolboxTool": """ Binds parameters to values or callables that produce values. @@ -251,9 +238,9 @@ def bind_parameters( Returns: A new ToolboxTool instance with the specified parameters bound. """ - param_names = set(p.name for p in self.__params) + all_params = set(p.name for p in self.__params) for name in bound_params.keys(): - if name not in param_names: + if name not in all_params: raise Exception(f"unable to bind parameters: no parameter named {name}") new_params = [] @@ -266,28 +253,28 @@ def bind_parameters( bound_params=bound_params, ) -def identify_required_authn_params( - req_authn_params: Mapping[str, list[str]], auth_service_names: Iterable[str] + +def filter_required_authn_params( + req_authn_params: Mapping[str, list[str]], auth_services: Iterable[str] ) -> dict[str, list[str]]: """ - Identifies authentication parameters that are still required; because they - not covered by the provided `auth_service_names`. + Utility function for reducing 'req_authn_params' to a subset of parameters that + aren't supplied by a least one service in auth_services. - Args: - req_authn_params: A mapping of parameter names to sets of required - authentication services. - auth_service_names: An iterable of authentication service names for which - token getters are available. + Args: + req_authn_params: A mapping of parameter names to sets of required + authentication services. + auth_services: An iterable of authentication service names for which + token getters are available. Returns: A new dictionary representing the subset of required authentication parameters that are not covered by the provided `auth_services`. """ - required_params = {} # params that are still required with provided auth_services + req_params = {} for param, services in req_authn_params.items(): - # if we don't have a token_getter for any of the services required by the param, - # the param is still required - required = not any(s in services for s in auth_service_names) + # if we don't have a token_getter for any of the services required by the param, the param is still required + required = not any(s in services for s in auth_services) if required: - required_params[param] = services - return required_params \ No newline at end of file + req_params[param] = services + return req_params From 6dca7f9b7e930b2f5725d0812bb8e75ae2a1a891 Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Wed, 2 Apr 2025 11:27:36 +0530 Subject: [PATCH 40/50] fix error message --- packages/toolbox-core/tests/test_e2e.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/toolbox-core/tests/test_e2e.py b/packages/toolbox-core/tests/test_e2e.py index 8854eb4d..5ce8e1c0 100644 --- a/packages/toolbox-core/tests/test_e2e.py +++ b/packages/toolbox-core/tests/test_e2e.py @@ -154,7 +154,7 @@ async def test_run_tool_param_auth_no_auth(self, toolbox): # TODO: Change match to {my-auth-service3} instead of email after fix in PR with pytest.raises( Exception, - match="One of more of the following authn services are required to invoke this tool: email", + match="One of more of the following authn services are required to invoke this tool: my-test-auth", ): await tool() From 4d562bd4d831ce6de03bc0530656eb623d47388b Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Wed, 2 Apr 2025 11:30:20 +0530 Subject: [PATCH 41/50] revert package files --- .../toolbox-core/src/toolbox_core/client.py | 24 +++-- .../toolbox-core/src/toolbox_core/tool.py | 91 +++++++++++-------- 2 files changed, 70 insertions(+), 45 deletions(-) diff --git a/packages/toolbox-core/src/toolbox_core/client.py b/packages/toolbox-core/src/toolbox_core/client.py index 5afaa4db..a5da65e4 100644 --- a/packages/toolbox-core/src/toolbox_core/client.py +++ b/packages/toolbox-core/src/toolbox_core/client.py @@ -11,13 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import re import types -from typing import Any, Callable, Optional +from typing import Any, Callable, Mapping, Optional, Union from aiohttp import ClientSession from .protocol import ManifestSchema, ToolSchema -from .tool import ToolboxTool, filter_required_authn_params +from .tool import ToolboxTool, identify_required_authn_params class ToolboxClient: @@ -58,7 +59,7 @@ def __parse_tool( name: str, schema: ToolSchema, auth_token_getters: dict[str, Callable[[], str]], - all_bound_params: dict[str, Callable[[], str]], + all_bound_params: Mapping[str, Union[Callable[[], Any], Any]], ) -> ToolboxTool: """Internal helper to create a callable tool from its schema.""" # sort into reg, authn, and bound params @@ -75,7 +76,9 @@ def __parse_tool( else: # regular parameter params.append(p) - authn_params = filter_required_authn_params(authn_params, auth_sources) + authn_params = identify_required_authn_params( + authn_params, auth_token_getters.keys() + ) tool = ToolboxTool( session=self.__session, @@ -83,6 +86,7 @@ def __parse_tool( name=name, desc=schema.description, params=[p.to_param() for p in params], + # create a read-only values for the maps to prevent mutation required_authn_params=types.MappingProxyType(authn_params), auth_service_token_getters=types.MappingProxyType(auth_token_getters), bound_params=types.MappingProxyType(bound_params), @@ -125,7 +129,7 @@ async def load_tool( self, name: str, auth_token_getters: dict[str, Callable[[], str]] = {}, - bound_params: dict[str, Callable[[], str]] = {}, + bound_params: Mapping[str, Union[Callable[[], Any], Any]] = {}, ) -> ToolboxTool: """ Asynchronously loads a tool from the server. @@ -138,6 +142,10 @@ async def load_tool( name: The unique name or identifier of the tool to load. auth_token_getters: A mapping of authentication service names to callables that return the corresponding authentication token. + bound_params: A mapping of parameter names to bind to specific values or + callables that are called to produce values as needed. + + Returns: ToolboxTool: A callable object representing the loaded tool, ready @@ -166,7 +174,7 @@ async def load_toolset( self, name: str, auth_token_getters: dict[str, Callable[[], str]] = {}, - bound_params: dict[str, Callable[[], str]] = {}, + bound_params: Mapping[str, Union[Callable[[], Any], Any]] = {}, ) -> list[ToolboxTool]: """ Asynchronously fetches a toolset and loads all tools defined within it. @@ -175,6 +183,8 @@ async def load_toolset( name: Name of the toolset to load tools. auth_token_getters: A mapping of authentication service names to callables that return the corresponding authentication token. + bound_params: A mapping of parameter names to bind to specific values or + callables that are called to produce values as needed. @@ -193,4 +203,4 @@ async def load_toolset( self.__parse_tool(n, s, auth_token_getters, bound_params) for n, s in manifest.tools.items() ] - return tools + return tools \ No newline at end of file diff --git a/packages/toolbox-core/src/toolbox_core/tool.py b/packages/toolbox-core/src/toolbox_core/tool.py index 2e88c179..b6040430 100644 --- a/packages/toolbox-core/src/toolbox_core/tool.py +++ b/packages/toolbox-core/src/toolbox_core/tool.py @@ -54,7 +54,7 @@ def __init__( params: Sequence[Parameter], required_authn_params: Mapping[str, list[str]], auth_service_token_getters: Mapping[str, Callable[[], str]], - bound_params: Mapping[str, Union[Callable[[], Any], Any]] = {}, + bound_params: Mapping[str, Union[Callable[[], Any], Any]], ): """ Initializes a callable that will trigger the tool invocation through the @@ -67,10 +67,13 @@ def __init__( desc: The description of the remote tool (used as its docstring). params: A list of `inspect.Parameter` objects defining the tool's arguments and their types/defaults. - required_authn_params: A dict of required authenticated parameters that - need a auth_service_token_getter set for them yet. - auth_service_tokens: A dict of authService -> token (or callables that + required_authn_params: A dict of required authenticated parameters to a list + of services that provide values for them. + auth_service_token_getters: A dict of authService -> token (or callables that produce a token) + bound_params: A mapping of parameter names to bind to specific values or + callables that are called to produce values as needed. + """ # used to invoke the toolbox API @@ -92,7 +95,7 @@ def __init__( self.__required_authn_params = required_authn_params # map of authService -> token_getter self.__auth_service_token_getters = auth_service_token_getters - # map of parameter name to value or Callable + # map of parameter name to value (or callable that produces that value) self.__bound_parameters = bound_params def __copy( @@ -120,18 +123,24 @@ def __copy( a auth_service_token_getter set for them yet. auth_service_token_getters: A dict of authService -> token (or callables that produce a token) + bound_params: A mapping of parameter names to bind to specific values or + callables that are called to produce values as needed. """ + check = lambda val, default: val if val is not None else default return ToolboxTool( - session=session or self.__session, - base_url=base_url or self.__base_url, - name=name or self.__name__, - desc=desc or self.__desc, - params=params or self.__params, - required_authn_params=required_authn_params or self.__required_authn_params, - auth_service_token_getters=auth_service_token_getters - or self.__auth_service_token_getters, - bound_params=bound_params or self.__bound_parameters, + session=check(session, self.__session), + base_url=check(base_url, self.__base_url), + name=check(name, self.__name__), + desc=check(desc, self.__desc), + params=check(params, self.__params), + required_authn_params=check( + required_authn_params, self.__required_authn_params + ), + auth_service_token_getters=check( + auth_service_token_getters, self.__auth_service_token_getters + ), + bound_params=check(bound_params, self.__bound_parameters), ) async def __call__(self, *args: Any, **kwargs: Any) -> str: @@ -151,9 +160,12 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str: # check if any auth services need to be specified yet if len(self.__required_authn_params) > 0: - req_auth_services = set(l for l in self.__required_authn_params.keys()) + # Gather all the required auth services into a set + req_auth_services = set() + for s in self.__required_authn_params.values(): + req_auth_services.update(s) raise Exception( - f"One of more of the following authn services are required to invoke this tool: {','.join(req_auth_services)}" + f"One or more of the following authn services are required to invoke this tool: {','.join(req_auth_services)}" ) # validate inputs to this call using the signature @@ -190,7 +202,7 @@ def add_auth_token_getters( auth_token_getters: Mapping[str, Callable[[], str]], ) -> "ToolboxTool": """ - Registers a auth token getter function that is used for AuthServices when tools + Registers an auth token getter function that is used for AuthServices when tools are invoked. Args: @@ -203,10 +215,12 @@ def add_auth_token_getters( """ # throw an error if the authentication source is already registered - dupes = auth_token_getters.keys() & self.__auth_service_token_getters.keys() - if dupes: + existing_services = self.__auth_service_token_getters.keys() + incoming_services = auth_token_getters.keys() + duplicates = existing_services & incoming_services + if duplicates: raise ValueError( - f"Authentication source(s) `{', '.join(dupes)}` already registered in tool `{self.__name__}`." + f"Authentication source(s) `{', '.join(duplicates)}` already registered in tool `{self.__name__}`." ) # create a read-only updated value for new_getters @@ -215,7 +229,7 @@ def add_auth_token_getters( ) # create a read-only updated for params that are still required new_req_authn_params = types.MappingProxyType( - filter_required_authn_params( + identify_required_authn_params( self.__required_authn_params, auth_token_getters.keys() ) ) @@ -226,7 +240,7 @@ def add_auth_token_getters( ) def bind_parameters( - self, bound_params: Mapping[str, Callable[[], str]] + self, bound_params: Mapping[str, Union[Callable[[], Any], Any]] ) -> "ToolboxTool": """ Binds parameters to values or callables that produce values. @@ -238,9 +252,9 @@ def bind_parameters( Returns: A new ToolboxTool instance with the specified parameters bound. """ - all_params = set(p.name for p in self.__params) + param_names = set(p.name for p in self.__params) for name in bound_params.keys(): - if name not in all_params: + if name not in param_names: raise Exception(f"unable to bind parameters: no parameter named {name}") new_params = [] @@ -254,27 +268,28 @@ def bind_parameters( ) -def filter_required_authn_params( - req_authn_params: Mapping[str, list[str]], auth_services: Iterable[str] +def identify_required_authn_params( + req_authn_params: Mapping[str, list[str]], auth_service_names: Iterable[str] ) -> dict[str, list[str]]: """ - Utility function for reducing 'req_authn_params' to a subset of parameters that - aren't supplied by a least one service in auth_services. + Identifies authentication parameters that are still required; because they + not covered by the provided `auth_service_names`. - Args: - req_authn_params: A mapping of parameter names to sets of required - authentication services. - auth_services: An iterable of authentication service names for which - token getters are available. + Args: + req_authn_params: A mapping of parameter names to sets of required + authentication services. + auth_service_names: An iterable of authentication service names for which + token getters are available. Returns: A new dictionary representing the subset of required authentication parameters that are not covered by the provided `auth_services`. """ - req_params = {} + required_params = {} # params that are still required with provided auth_services for param, services in req_authn_params.items(): - # if we don't have a token_getter for any of the services required by the param, the param is still required - required = not any(s in services for s in auth_services) + # if we don't have a token_getter for any of the services required by the param, + # the param is still required + required = not any(s in services for s in auth_service_names) if required: - req_params[param] = services - return req_params + required_params[param] = services + return required_params \ No newline at end of file From cc7ccad26511e0014bff00e2b8f660f6bf153797 Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Wed, 2 Apr 2025 11:31:33 +0530 Subject: [PATCH 42/50] lint --- packages/toolbox-core/src/toolbox_core/client.py | 2 +- packages/toolbox-core/src/toolbox_core/tool.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/toolbox-core/src/toolbox_core/client.py b/packages/toolbox-core/src/toolbox_core/client.py index a5da65e4..ecf67f0c 100644 --- a/packages/toolbox-core/src/toolbox_core/client.py +++ b/packages/toolbox-core/src/toolbox_core/client.py @@ -203,4 +203,4 @@ async def load_toolset( self.__parse_tool(n, s, auth_token_getters, bound_params) for n, s in manifest.tools.items() ] - return tools \ No newline at end of file + return tools diff --git a/packages/toolbox-core/src/toolbox_core/tool.py b/packages/toolbox-core/src/toolbox_core/tool.py index b6040430..58180341 100644 --- a/packages/toolbox-core/src/toolbox_core/tool.py +++ b/packages/toolbox-core/src/toolbox_core/tool.py @@ -292,4 +292,4 @@ def identify_required_authn_params( required = not any(s in services for s in auth_service_names) if required: required_params[param] = services - return required_params \ No newline at end of file + return required_params From 84e66a683f2d8c7ab6fdf74f8ad70797ec1b2424 Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Wed, 2 Apr 2025 11:36:24 +0530 Subject: [PATCH 43/50] fix error message --- packages/toolbox-core/tests/test_e2e.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/packages/toolbox-core/tests/test_e2e.py b/packages/toolbox-core/tests/test_e2e.py index 5ce8e1c0..451ae7d2 100644 --- a/packages/toolbox-core/tests/test_e2e.py +++ b/packages/toolbox-core/tests/test_e2e.py @@ -151,10 +151,9 @@ async def test_run_tool_auth(self, toolbox, auth_token1): async def test_run_tool_param_auth_no_auth(self, toolbox): """Tests running a tool with a param requiring auth, without auth.""" tool = await toolbox.load_tool("get-row-by-email-auth") - # TODO: Change match to {my-auth-service3} instead of email after fix in PR with pytest.raises( Exception, - match="One of more of the following authn services are required to invoke this tool: my-test-auth", + match="One or more of the following authn services are required to invoke this tool: my-test-auth", ): await tool() From fe2a3325e16bce2cec753433d67ff3e29a70cd7c Mon Sep 17 00:00:00 2001 From: Twisha Bansal <58483338+twishabansal@users.noreply.github.com> Date: Wed, 2 Apr 2025 12:01:19 +0530 Subject: [PATCH 44/50] Update packages/toolbox-core/tests/test_e2e.py Co-authored-by: Anubhav Dhawan --- packages/toolbox-core/tests/test_e2e.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/toolbox-core/tests/test_e2e.py b/packages/toolbox-core/tests/test_e2e.py index 451ae7d2..2f4d8d0a 100644 --- a/packages/toolbox-core/tests/test_e2e.py +++ b/packages/toolbox-core/tests/test_e2e.py @@ -14,7 +14,7 @@ """End-to-end tests for the toolbox SDK interacting with the toolbox server. -This file covers the following use cases: +This file covers the following test cases: 1. Loading a tool. 2. Loading a specific toolset. From b05ea396d81daa62cfe9a4c23a36cd99ffed14ab Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Wed, 2 Apr 2025 16:34:10 +0530 Subject: [PATCH 45/50] add new test case --- packages/toolbox-core/tests/test_e2e.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/packages/toolbox-core/tests/test_e2e.py b/packages/toolbox-core/tests/test_e2e.py index 2f4d8d0a..3cf2aa8f 100644 --- a/packages/toolbox-core/tests/test_e2e.py +++ b/packages/toolbox-core/tests/test_e2e.py @@ -98,6 +98,16 @@ async def test_run_tool_wrong_param_type(self, get_n_rows_tool: ToolboxTool): ##### Bind param tests async def test_bind_params(self, toolbox, get_n_rows_tool): + new_tool = get_n_rows_tool.bind_parameters({"num_rows": "3"}) + response = await new_tool() + + assert isinstance(response, str) + assert "row1" in response + assert "row2" in response + assert "row3" in response + assert "row4" not in response + + async def test_bind_params_callable(self, toolbox, get_n_rows_tool): new_tool = get_n_rows_tool.bind_parameters({"num_rows": lambda: "3"}) response = await new_tool() From 7d3b77aa0459ba8b7b7327274304a2ce5835c3d5 Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Wed, 2 Apr 2025 16:36:26 +0530 Subject: [PATCH 46/50] change docstring to reflect new test cases --- packages/toolbox-core/tests/test_e2e.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/packages/toolbox-core/tests/test_e2e.py b/packages/toolbox-core/tests/test_e2e.py index 3cf2aa8f..2d6ba650 100644 --- a/packages/toolbox-core/tests/test_e2e.py +++ b/packages/toolbox-core/tests/test_e2e.py @@ -32,6 +32,9 @@ a. No auth provided. b. Correct auth provided. c. Auth provided does not contain the required claim. +8. Bind params to a tool + a. Static param + b. Callable param value """ import pytest import pytest_asyncio From 755d3e9967128371047f57c801e6ae4997c23ae7 Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Thu, 3 Apr 2025 10:46:50 +0530 Subject: [PATCH 47/50] clean up docstring --- packages/toolbox-core/tests/test_e2e.py | 35 +++++++------------------ 1 file changed, 9 insertions(+), 26 deletions(-) diff --git a/packages/toolbox-core/tests/test_e2e.py b/packages/toolbox-core/tests/test_e2e.py index 2d6ba650..3fe513d7 100644 --- a/packages/toolbox-core/tests/test_e2e.py +++ b/packages/toolbox-core/tests/test_e2e.py @@ -11,31 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -"""End-to-end tests for the toolbox SDK interacting with the toolbox server. - -This file covers the following test cases: - -1. Loading a tool. -2. Loading a specific toolset. -3. Loading the default toolset (contains all tools). -4. Running a tool with - a. Missing params. - b. Wrong param type. -5. Running a tool with no required auth, with auth provided. -6. Running a tool with required auth: - a. No auth provided. - b. Wrong auth provided: The tool requires a different authentication - than the one provided. - c. Correct auth provided. -7. Running a tool with a parameter that requires auth: - a. No auth provided. - b. Correct auth provided. - c. Auth provided does not contain the required claim. -8. Bind params to a tool - a. Static param - b. Callable param value -""" import pytest import pytest_asyncio @@ -56,6 +31,7 @@ async def toolbox(self): @pytest_asyncio.fixture(scope="function") async def get_n_rows_tool(self, toolbox: ToolboxClient) -> ToolboxTool: + """Load a tool.""" tool = await toolbox.load_tool("get-n-rows") assert tool.__name__ == "get-n-rows" return tool @@ -75,12 +51,14 @@ async def test_load_toolset_specific( expected_length: int, expected_tools: list[str], ): + """Load a specific toolset""" toolset = await toolbox.load_toolset(toolset_name) assert len(toolset) == expected_length tool_names = {tool.__name__ for tool in toolset} assert tool_names == set(expected_tools) async def test_run_tool(self, get_n_rows_tool: ToolboxTool): + """Invoke a tool.""" response = await get_n_rows_tool(num_rows="2") assert isinstance(response, str) @@ -89,10 +67,12 @@ async def test_run_tool(self, get_n_rows_tool: ToolboxTool): assert "row3" not in response async def test_run_tool_missing_params(self, get_n_rows_tool): + """Invoke a tool with missing params.""" with pytest.raises(TypeError, match="missing a required argument: 'num_rows'"): await get_n_rows_tool() async def test_run_tool_wrong_param_type(self, get_n_rows_tool: ToolboxTool): + """Invoke a tool with wrong param type.""" with pytest.raises( Exception, match='provided parameters were invalid: unable to parse value for "num_rows": .* not type "string"', @@ -101,6 +81,7 @@ async def test_run_tool_wrong_param_type(self, get_n_rows_tool: ToolboxTool): ##### Bind param tests async def test_bind_params(self, toolbox, get_n_rows_tool): + """Bind a param to an existing tool.""" new_tool = get_n_rows_tool.bind_parameters({"num_rows": "3"}) response = await new_tool() @@ -111,6 +92,7 @@ async def test_bind_params(self, toolbox, get_n_rows_tool): assert "row4" not in response async def test_bind_params_callable(self, toolbox, get_n_rows_tool): + """Bind a callable param to an existing tool.""" new_tool = get_n_rows_tool.bind_parameters({"num_rows": lambda: "3"}) response = await new_tool() @@ -141,7 +123,8 @@ async def test_run_tool_no_auth(self, toolbox): await tool(id="2") async def test_run_tool_wrong_auth(self, toolbox, auth_token2): - """Tests running a tool with incorrect auth.""" + """Tests running a tool with incorrect auth. The tool + requires a different authentication than the one provided.""" tool = await toolbox.load_tool( "get-row-by-id-auth", ) From 738b984e62772d35d2834b8c5a55ed07187fc65e Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Thu, 3 Apr 2025 10:49:02 +0530 Subject: [PATCH 48/50] lint --- packages/toolbox-core/tests/test_e2e.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/packages/toolbox-core/tests/test_e2e.py b/packages/toolbox-core/tests/test_e2e.py index 3fe513d7..e776607e 100644 --- a/packages/toolbox-core/tests/test_e2e.py +++ b/packages/toolbox-core/tests/test_e2e.py @@ -113,9 +113,7 @@ async def test_run_tool_unauth_with_auth(self, toolbox, auth_token2): async def test_run_tool_no_auth(self, toolbox): """Tests running a tool requiring auth without providing auth.""" - tool = await toolbox.load_tool( - "get-row-by-id-auth", - ) + tool = await toolbox.load_tool("get-row-by-id-auth") with pytest.raises( Exception, match="tool invocation not authorized. Please make sure your specify correct auth headers", @@ -125,9 +123,7 @@ async def test_run_tool_no_auth(self, toolbox): async def test_run_tool_wrong_auth(self, toolbox, auth_token2): """Tests running a tool with incorrect auth. The tool requires a different authentication than the one provided.""" - tool = await toolbox.load_tool( - "get-row-by-id-auth", - ) + tool = await toolbox.load_tool("get-row-by-id-auth") auth_tool = tool.add_auth_token_getters({"my-test-auth": lambda: auth_token2}) with pytest.raises( Exception, @@ -137,9 +133,7 @@ async def test_run_tool_wrong_auth(self, toolbox, auth_token2): async def test_run_tool_auth(self, toolbox, auth_token1): """Tests running a tool with correct auth.""" - tool = await toolbox.load_tool( - "get-row-by-id-auth", - ) + tool = await toolbox.load_tool("get-row-by-id-auth") auth_tool = tool.add_auth_token_getters({"my-test-auth": lambda: auth_token1}) response = await auth_tool(id="2") assert "row2" in response From 76830638ba7bda3260902b890afdabca9aa3e675 Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Thu, 3 Apr 2025 11:00:39 +0530 Subject: [PATCH 49/50] Move tests to different classes --- packages/toolbox-core/tests/test_e2e.py | 77 +++++++++++++++---------- 1 file changed, 46 insertions(+), 31 deletions(-) diff --git a/packages/toolbox-core/tests/test_e2e.py b/packages/toolbox-core/tests/test_e2e.py index e776607e..43f4d0f8 100644 --- a/packages/toolbox-core/tests/test_e2e.py +++ b/packages/toolbox-core/tests/test_e2e.py @@ -18,25 +18,28 @@ from toolbox_core.tool import ToolboxTool +# --- Shared Fixtures Defined at Module Level --- +@pytest_asyncio.fixture(scope="function") +async def toolbox(): + """Creates a ToolboxClient instance shared by all tests in this module.""" + toolbox = ToolboxClient("http://localhost:5000") + try: + yield toolbox + finally: + await toolbox.close() + + +@pytest_asyncio.fixture(scope="function") +async def get_n_rows_tool(toolbox: ToolboxClient) -> ToolboxTool: + """Load the 'get-n-rows' tool using the shared toolbox client.""" + tool = await toolbox.load_tool("get-n-rows") + assert tool.__name__ == "get-n-rows" + return tool + + @pytest.mark.asyncio @pytest.mark.usefixtures("toolbox_server") -class TestE2EClient: - @pytest_asyncio.fixture(scope="function") - async def toolbox(self): - toolbox = ToolboxClient("http://localhost:5000") - try: - yield toolbox - finally: - await toolbox.close() - - @pytest_asyncio.fixture(scope="function") - async def get_n_rows_tool(self, toolbox: ToolboxClient) -> ToolboxTool: - """Load a tool.""" - tool = await toolbox.load_tool("get-n-rows") - assert tool.__name__ == "get-n-rows" - return tool - - #### Basic e2e tests +class TestBasicE2E: @pytest.mark.parametrize( "toolset_name, expected_length, expected_tools", [ @@ -66,7 +69,7 @@ async def test_run_tool(self, get_n_rows_tool: ToolboxTool): assert "row2" in response assert "row3" not in response - async def test_run_tool_missing_params(self, get_n_rows_tool): + async def test_run_tool_missing_params(self, get_n_rows_tool: ToolboxTool): """Invoke a tool with missing params.""" with pytest.raises(TypeError, match="missing a required argument: 'num_rows'"): await get_n_rows_tool() @@ -79,31 +82,41 @@ async def test_run_tool_wrong_param_type(self, get_n_rows_tool: ToolboxTool): ): await get_n_rows_tool(num_rows=2) - ##### Bind param tests - async def test_bind_params(self, toolbox, get_n_rows_tool): + +@pytest.mark.asyncio +@pytest.mark.usefixtures("toolbox_server") +class TestBindParams: + async def test_bind_params( + self, toolbox: ToolboxClient, get_n_rows_tool: ToolboxTool + ): """Bind a param to an existing tool.""" new_tool = get_n_rows_tool.bind_parameters({"num_rows": "3"}) response = await new_tool() - assert isinstance(response, str) assert "row1" in response assert "row2" in response assert "row3" in response assert "row4" not in response - async def test_bind_params_callable(self, toolbox, get_n_rows_tool): + async def test_bind_params_callable( + self, toolbox: ToolboxClient, get_n_rows_tool: ToolboxTool + ): """Bind a callable param to an existing tool.""" new_tool = get_n_rows_tool.bind_parameters({"num_rows": lambda: "3"}) response = await new_tool() - assert isinstance(response, str) assert "row1" in response assert "row2" in response assert "row3" in response assert "row4" not in response - ##### Auth tests - async def test_run_tool_unauth_with_auth(self, toolbox, auth_token2): + +@pytest.mark.asyncio +@pytest.mark.usefixtures("toolbox_server") +class TestAuth: + async def test_run_tool_unauth_with_auth( + self, toolbox: ToolboxClient, auth_token2: str + ): """Tests running a tool that doesn't require auth, with auth provided.""" tool = await toolbox.load_tool( "get-row-by-id", auth_token_getters={"my-test-auth": lambda: auth_token2} @@ -111,7 +124,7 @@ async def test_run_tool_unauth_with_auth(self, toolbox, auth_token2): response = await tool(id="2") assert "row2" in response - async def test_run_tool_no_auth(self, toolbox): + async def test_run_tool_no_auth(self, toolbox: ToolboxClient): """Tests running a tool requiring auth without providing auth.""" tool = await toolbox.load_tool("get-row-by-id-auth") with pytest.raises( @@ -120,7 +133,7 @@ async def test_run_tool_no_auth(self, toolbox): ): await tool(id="2") - async def test_run_tool_wrong_auth(self, toolbox, auth_token2): + async def test_run_tool_wrong_auth(self, toolbox: ToolboxClient, auth_token2: str): """Tests running a tool with incorrect auth. The tool requires a different authentication than the one provided.""" tool = await toolbox.load_tool("get-row-by-id-auth") @@ -131,14 +144,14 @@ async def test_run_tool_wrong_auth(self, toolbox, auth_token2): ): await auth_tool(id="2") - async def test_run_tool_auth(self, toolbox, auth_token1): + async def test_run_tool_auth(self, toolbox: ToolboxClient, auth_token1: str): """Tests running a tool with correct auth.""" tool = await toolbox.load_tool("get-row-by-id-auth") auth_tool = tool.add_auth_token_getters({"my-test-auth": lambda: auth_token1}) response = await auth_tool(id="2") assert "row2" in response - async def test_run_tool_param_auth_no_auth(self, toolbox): + async def test_run_tool_param_auth_no_auth(self, toolbox: ToolboxClient): """Tests running a tool with a param requiring auth, without auth.""" tool = await toolbox.load_tool("get-row-by-email-auth") with pytest.raises( @@ -147,7 +160,7 @@ async def test_run_tool_param_auth_no_auth(self, toolbox): ): await tool() - async def test_run_tool_param_auth(self, toolbox, auth_token1): + async def test_run_tool_param_auth(self, toolbox: ToolboxClient, auth_token1: str): """Tests running a tool with a param requiring auth, with correct auth.""" tool = await toolbox.load_tool( "get-row-by-email-auth", @@ -158,7 +171,9 @@ async def test_run_tool_param_auth(self, toolbox, auth_token1): assert "row5" in response assert "row6" in response - async def test_run_tool_param_auth_no_field(self, toolbox, auth_token1): + async def test_run_tool_param_auth_no_field( + self, toolbox: ToolboxClient, auth_token1: str + ): """Tests running a tool with a param requiring auth, with insufficient auth.""" tool = await toolbox.load_tool( "get-row-by-content-auth", From 49203ac170e26fa6281cf897acb78d1d163e5ef2 Mon Sep 17 00:00:00 2001 From: Twisha Bansal Date: Thu, 3 Apr 2025 11:11:34 +0530 Subject: [PATCH 50/50] add timeout --- packages/toolbox-core/tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/toolbox-core/tests/conftest.py b/packages/toolbox-core/tests/conftest.py index ee6f5e4f..e579f843 100644 --- a/packages/toolbox-core/tests/conftest.py +++ b/packages/toolbox-core/tests/conftest.py @@ -163,4 +163,4 @@ def toolbox_server(toolbox_version: str, tools_file_path: str) -> Generator[None # Clean up toolbox server toolbox_server.terminate() - toolbox_server.wait() + toolbox_server.wait(timeout=5)