From 87e00cc4f02c08fdc715583679e8db61459b0baf Mon Sep 17 00:00:00 2001 From: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com> Date: Sun, 30 Mar 2025 17:04:30 +0000 Subject: [PATCH 1/8] feat: add authenticated parameters support --- .../toolbox-core/src/toolbox_core/client.py | 39 ++++- .../toolbox-core/src/toolbox_core/tool.py | 160 ++++++++++++++++-- packages/toolbox-core/tests/test_client.py | 105 +++++++++++- 3 files changed, 279 insertions(+), 25 deletions(-) diff --git a/packages/toolbox-core/src/toolbox_core/client.py b/packages/toolbox-core/src/toolbox_core/client.py index dc59e440..0fcca0d2 100644 --- a/packages/toolbox-core/src/toolbox_core/client.py +++ b/packages/toolbox-core/src/toolbox_core/client.py @@ -11,13 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from typing import Optional +import types +from typing import Any, Callable, Optional from aiohttp import ClientSession from .protocol import ManifestSchema, ToolSchema -from .tool import ToolboxTool +from .tool import ToolboxTool, filter_required_authn_params class ToolboxClient: @@ -53,14 +53,34 @@ def __init__( session = ClientSession() self.__session = session - def __parse_tool(self, name: str, schema: ToolSchema) -> ToolboxTool: + def __parse_tool( + self, + name: str, + schema: ToolSchema, + auth_token_getters: dict[str, Callable[[], str]], + ) -> ToolboxTool: """Internal helper to create a callable tool from its schema.""" + # sort into authenticated and reg params + params = [] + authn_params: dict[str, list[str]] = {} + auth_sources: set[str] = set() + for p in schema.parameters: + if not p.authSources: + params.append(p) + else: + authn_params[p.name] = p.authSources + auth_sources.update(p.authSources) + + authn_params = filter_required_authn_params(authn_params, auth_sources) + tool = ToolboxTool( session=self.__session, base_url=self.__base_url, name=name, desc=schema.description, - params=[p.to_param() for p in schema.parameters], + params=[p.to_param() for p in params], + required_authn_params=types.MappingProxyType(authn_params), + auth_service_token_getters=auth_token_getters, ) return tool @@ -99,6 +119,7 @@ async def close(self): async def load_tool( self, name: str, + auth_service_tokens: dict[str, Callable[[], str]] = {}, ) -> ToolboxTool: """ Asynchronously loads a tool from the server. @@ -127,13 +148,14 @@ async def load_tool( if name not in manifest.tools: # TODO: Better exception raise Exception(f"Tool '{name}' not found!") - tool = self.__parse_tool(name, manifest.tools[name]) + tool = self.__parse_tool(name, manifest.tools[name], auth_service_tokens) return tool async def load_toolset( self, name: str, + auth_token_getters: dict[str, Callable[[], str]] = {}, ) -> list[ToolboxTool]: """ Asynchronously fetches a toolset and loads all tools defined within it. @@ -152,5 +174,8 @@ async def load_toolset( manifest: ManifestSchema = ManifestSchema(**json) # parse each tools name and schema into a list of ToolboxTools - tools = [self.__parse_tool(n, s) for n, s in manifest.tools.items()] + tools = [ + self.__parse_tool(n, s, auth_token_getters) + for n, s in manifest.tools.items() + ] return tools diff --git a/packages/toolbox-core/src/toolbox_core/tool.py b/packages/toolbox-core/src/toolbox_core/tool.py index 48e4626c..494c7c21 100644 --- a/packages/toolbox-core/src/toolbox_core/tool.py +++ b/packages/toolbox-core/src/toolbox_core/tool.py @@ -13,10 +13,13 @@ # limitations under the License. +import types +from collections import defaultdict from inspect import Parameter, Signature -from typing import Any +from typing import Any, Callable, DefaultDict, Iterable, Mapping, Optional, Sequence from aiohttp import ClientSession +from pytest import Session class ToolboxTool: @@ -32,20 +35,19 @@ class ToolboxTool: and `inspect` work as expected. """ - __url: str - __session: ClientSession - __signature__: Signature - def __init__( self, session: ClientSession, base_url: str, name: str, desc: str, - params: list[Parameter], + params: Sequence[Parameter], + required_authn_params: Mapping[str, list[str]], + auth_service_token_getters: Mapping[str, Callable[[], str]], ): """ - Initializes a callable that will trigger the tool invocation through the Toolbox server. + Initializes a callable that will trigger the tool invocation through the + Toolbox server. Args: session: The `aiohttp.ClientSession` used for making API requests. @@ -54,19 +56,69 @@ def __init__( desc: The description of the remote tool (used as its docstring). params: A list of `inspect.Parameter` objects defining the tool's arguments and their types/defaults. + required_authn_params: A dict of required authenticated parameters that + need a auth_service_token_getter set for them yet. + auth_service_tokens: A dict of authService -> token (or callables that + produce a token) """ # used to invoke the toolbox API - self.__session = session + self.__session: ClientSession = session + self.__base_url: str = base_url self.__url = f"{base_url}/api/tool/{name}/invoke" - # the following properties are set to help anyone that might inspect it determine + self.__desc = desc + self.__params = params + + # the following properties are set to help anyone that might inspect it determine usage self.__name__ = name self.__doc__ = desc self.__signature__ = Signature(parameters=params, return_annotation=str) self.__annotations__ = {p.name: p.annotation for p in params} # TODO: self.__qualname__ ?? + # map of parameter name to auth service required by it + self.__required_authn_params = required_authn_params + # map of authService -> token_getter + self.__auth_service_token_getters = auth_service_token_getters + + def __copy( + self, + session: Optional[ClientSession] = None, + base_url: Optional[str] = None, + name: Optional[str] = None, + desc: Optional[str] = None, + params: Optional[list[Parameter]] = None, + required_authn_params: Optional[Mapping[str, list[str]]] = None, + auth_service_token_getters: Optional[Mapping[str, Callable[[], str]]] = None, + ): + """ + Creates a copy of the ToolboxTool, overriding specific fields. + + Args: + session: The `aiohttp.ClientSession` used for making API requests. + base_url: The base URL of the Toolbox server API. + name: The name of the remote tool. + desc: The description of the remote tool (used as its docstring). + params: A list of `inspect.Parameter` objects defining the tool's + arguments and their types/defaults. + required_authn_params: A dict of required authenticated parameters that need + a auth_service_token_getter set for them yet. + auth_service_token_getters: A dict of authService -> token (or callables + that produce a token) + + """ + return ToolboxTool( + session=session or self.__session, + base_url=base_url or self.__base_url, + name=name or self.__name__, + desc=desc or self.__desc, + params=params or self.__params, + required_authn_params=required_authn_params or self.__required_authn_params, + auth_service_token_getters=auth_service_token_getters + or self.__auth_service_token_getters, + ) + async def __call__(self, *args: Any, **kwargs: Any) -> str: """ Asynchronously calls the remote tool with the provided arguments. @@ -81,16 +133,96 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str: Returns: The string result returned by the remote tool execution. """ + + # check if any auth services need to be specified yet + if len(self.__required_authn_params) > 0: + req_auth_services = set(l for l in self.__required_authn_params.keys()) + raise Exception( + f"One of more of the following authn services are required to invoke this tool: {','.join(req_auth_services)}" + ) + + # validate inputs to this call using the signature all_args = self.__signature__.bind(*args, **kwargs) all_args.apply_defaults() # Include default values if not provided payload = all_args.arguments + # create headers for auth services + headers = {} + for auth_service, token_getter in self.__auth_service_token_getters.items(): + headers[f"{auth_service}_token"] = token_getter() + async with self.__session.post( self.__url, json=payload, + headers=headers, ) as resp: - ret = await resp.json() - if "error" in ret: - # TODO: better error - raise Exception(ret["error"]) - return ret.get("result", ret) + body = await resp.json() + if resp.status < 200 or resp.status >= 300: + err = body.get("error", f"unexpected status from server: {resp.status}") + raise Exception(err) + return body.get("result", body) + + def add_auth_token_getters( + self, + auth_token_getters: Mapping[str, Callable[[], str]], + ) -> "ToolboxTool": + """ + Registers a auth token getter function that is used for AuthServices when tools + are invoked. + + Args: + auth_token_getters: A mapping of authentication service names to + callables that return the corresponding authentication token. + + Returns: + A new ToolboxTool instance with the specified authentication token + getters registered. + """ + + # throw an error if the authentication source is already registered + dupes = auth_token_getters.keys() & self.__auth_service_token_getters.keys() + if dupes: + raise ValueError( + f"Authentication source(s) `{', '.join(dupes)}` already registered in tool `{self.__name__}`." + ) + + # create a read-only updated value for new_getters + new_getters = types.MappingProxyType( + dict(self.__auth_service_token_getters, **auth_token_getters) + ) + # create a read-only updated for params that are still required + new_req_authn_params = types.MappingProxyType( + filter_required_authn_params( + self.__required_authn_params, auth_token_getters.keys() + ) + ) + + return self.__copy( + auth_service_token_getters=new_getters, + required_authn_params=new_req_authn_params, + ) + + +def filter_required_authn_params( + req_authn_params: Mapping[str, list[str]], auth_services: Iterable[str] +) -> dict[str, list[str]]: + """ + Utility function for reducing 'req_authn_params' to a subset of parameters that aren't supplied by a least one service in auth_services. + + Args: + req_authn_params: A mapping of parameter names to sets of required + authentication services. + auth_services: An iterable of authentication service names for which + token getters are available. + + Returns: + A new dictionary representing the subset of required authentication + parameters that are not covered by the provided `auth_services`. + """ + req_params = {} + for param, services in req_authn_params.items(): + # if we don't have a token_getter for any of the services required by the param, the param is still required + required = not any(s in services for s in auth_services) + if required: + req_params[param] = services + return req_params diff --git a/packages/toolbox-core/tests/test_client.py b/packages/toolbox-core/tests/test_client.py index b19c575b..101fdb91 100644 --- a/packages/toolbox-core/tests/test_client.py +++ b/packages/toolbox-core/tests/test_client.py @@ -16,7 +16,8 @@ import inspect import pytest - +import pytest_asyncio +from aioresponses import CallbackResult from toolbox_core import ToolboxClient from toolbox_core.protocol import ManifestSchema, ParameterSchema, ToolSchema @@ -26,7 +27,7 @@ @pytest.fixture() def test_tool_str(): return ToolSchema( - description="Test Tool 1 Description", + description="Test Tool with String input", parameters=[ ParameterSchema( name="param1", type="string", description="Description of Param1" @@ -37,9 +38,8 @@ def test_tool_str(): @pytest.fixture() def test_tool_int_bool(): - """Fixture for the second test tool schema.""" return ToolSchema( - description="Test Tool 2 Description", + description="Test Tool with Int, Bool", parameters=[ ParameterSchema(name="argA", type="integer", description="Argument A"), ParameterSchema(name="argB", type="boolean", description="Argument B"), @@ -47,6 +47,22 @@ def test_tool_int_bool(): ) +@pytest.fixture() +def test_tool_auth(): + return ToolSchema( + description="Test Tool with Int,Bool+Auth", + parameters=[ + ParameterSchema(name="argA", type="integer", description="Argument A"), + ParameterSchema( + name="argB", + type="boolean", + description="Argument B", + authSources=["my-auth-service"], + ), + ], + ) + + @pytest.mark.asyncio async def test_load_tool_success(aioresponses, test_tool_str): """ @@ -83,6 +99,87 @@ async def test_load_tool_success(aioresponses, test_tool_str): assert await loaded_tool("some value") == "ok" +class TestAuth: + + @pytest.fixture + def expected_header(self): + return "some_token_for_testing" + + @pytest.fixture + def tool_name(self): + return "tool1" + + @pytest_asyncio.fixture + async def client(self, aioresponses, test_tool_auth, tool_name, expected_header): + manifest = ManifestSchema( + serverVersion="0.0.0", tools={tool_name: test_tool_auth} + ) + + # mock tool GET call + aioresponses.get( + f"{TEST_BASE_URL}/api/tool/{tool_name}", + payload=manifest.model_dump(), + status=200, + ) + + # mock tool INVOKE call + def require_headers(url, **kwargs): + if kwargs["headers"].get("my-auth-service_token") == expected_header: + return CallbackResult(status=200, body="{}") + else: + return CallbackResult(status=400, body="{}") + + aioresponses.post( + f"{TEST_BASE_URL}/api/tool/{tool_name}/invoke", + payload=manifest.model_dump(), + callback=require_headers, + status=200, + ) + + async with ToolboxClient(TEST_BASE_URL) as client: + yield client + + @pytest.mark.asyncio + async def test_auth_with_load_tool_success( + self, tool_name, expected_header, client + ): + """Tests 'load_tool' with auth token is specified.""" + + def token_handler(): + return expected_header + + tool = await client.load_tool( + tool_name, auth_service_tokens={"my-auth-service": token_handler} + ) + res = await tool(5) + + @pytest.mark.asyncio + async def test_auth_with_add_token_success( + self, tool_name, expected_header, client + ): + """Tests 'load_tool' with auth token is specified.""" + + def token_handler(): + return expected_header + + tool = await client.load_tool(tool_name) + tool = await client.add_auth_token_getters({"my-auth-service": token_handler}) + res = await tool(5) + + @pytest.mark.asyncio + async def test_auth_with_load_tool_fail_no_token( + self, tool_name, expected_header, client + ): + """Tests 'load_tool' with auth token is specified.""" + + def token_handler(): + return expected_header + + tool = await client.load_tool(tool_name) + with pytest.raises(Exception): + res = await tool(5) + + @pytest.mark.asyncio async def test_load_toolset_success(aioresponses, test_tool_str, test_tool_int_bool): """Tests successfully loading a toolset with multiple tools.""" From 6b263ad2e18b1379eaa78dd68907216604ca2245 Mon Sep 17 00:00:00 2001 From: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com> Date: Sun, 30 Mar 2025 17:22:03 +0000 Subject: [PATCH 2/8] chore: add asyncio dep --- packages/toolbox-core/pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/packages/toolbox-core/pyproject.toml b/packages/toolbox-core/pyproject.toml index edc45a8a..ee8a5f73 100644 --- a/packages/toolbox-core/pyproject.toml +++ b/packages/toolbox-core/pyproject.toml @@ -44,7 +44,8 @@ test = [ "isort==6.0.1", "mypy==1.15.0", "pytest==8.3.5", - "pytest-aioresponses==0.3.0" + "pytest-aioresponses==0.3.0", + "pytest-asyncio==0.25.3", ] [build-system] requires = ["setuptools"] From 5fe541f863ad20ae20c1a8567481de1a8b6b4b95 Mon Sep 17 00:00:00 2001 From: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com> Date: Sun, 30 Mar 2025 17:24:23 +0000 Subject: [PATCH 3/8] chore: run itest --- packages/toolbox-core/tests/test_client.py | 1 + 1 file changed, 1 insertion(+) diff --git a/packages/toolbox-core/tests/test_client.py b/packages/toolbox-core/tests/test_client.py index 101fdb91..b6b1ff85 100644 --- a/packages/toolbox-core/tests/test_client.py +++ b/packages/toolbox-core/tests/test_client.py @@ -18,6 +18,7 @@ import pytest import pytest_asyncio from aioresponses import CallbackResult + from toolbox_core import ToolboxClient from toolbox_core.protocol import ManifestSchema, ParameterSchema, ToolSchema From e9d7a3103d7bac63c45dd502427a0270e45ae96f Mon Sep 17 00:00:00 2001 From: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com> Date: Sun, 30 Mar 2025 17:26:39 +0000 Subject: [PATCH 4/8] chore: add type hint --- packages/toolbox-core/src/toolbox_core/tool.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/toolbox-core/src/toolbox_core/tool.py b/packages/toolbox-core/src/toolbox_core/tool.py index 494c7c21..17c44f63 100644 --- a/packages/toolbox-core/src/toolbox_core/tool.py +++ b/packages/toolbox-core/src/toolbox_core/tool.py @@ -91,7 +91,7 @@ def __copy( params: Optional[list[Parameter]] = None, required_authn_params: Optional[Mapping[str, list[str]]] = None, auth_service_token_getters: Optional[Mapping[str, Callable[[], str]]] = None, - ): + ) -> "ToolboxTool": """ Creates a copy of the ToolboxTool, overriding specific fields. From 58c55cfc4d565a3255d07bbaa8bbd4b81ff8651e Mon Sep 17 00:00:00 2001 From: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com> Date: Sun, 30 Mar 2025 17:32:36 +0000 Subject: [PATCH 5/8] fix: call tool instead of client --- packages/toolbox-core/tests/test_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/toolbox-core/tests/test_client.py b/packages/toolbox-core/tests/test_client.py index b6b1ff85..aebf3133 100644 --- a/packages/toolbox-core/tests/test_client.py +++ b/packages/toolbox-core/tests/test_client.py @@ -164,7 +164,7 @@ def token_handler(): return expected_header tool = await client.load_tool(tool_name) - tool = await client.add_auth_token_getters({"my-auth-service": token_handler}) + tool = tool.add_auth_token_getters({"my-auth-service": token_handler}) res = await tool(5) @pytest.mark.asyncio From c1a482a762a5f7d97df3ff7667a7f243ad185fc2 Mon Sep 17 00:00:00 2001 From: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com> Date: Mon, 31 Mar 2025 15:33:05 +0000 Subject: [PATCH 6/8] chore: correct arg name --- packages/toolbox-core/src/toolbox_core/client.py | 9 +++++++-- packages/toolbox-core/tests/test_client.py | 2 +- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/packages/toolbox-core/src/toolbox_core/client.py b/packages/toolbox-core/src/toolbox_core/client.py index 0fcca0d2..1acd521a 100644 --- a/packages/toolbox-core/src/toolbox_core/client.py +++ b/packages/toolbox-core/src/toolbox_core/client.py @@ -119,7 +119,7 @@ async def close(self): async def load_tool( self, name: str, - auth_service_tokens: dict[str, Callable[[], str]] = {}, + auth_token_getters: dict[str, Callable[[], str]] = {}, ) -> ToolboxTool: """ Asynchronously loads a tool from the server. @@ -130,6 +130,8 @@ async def load_tool( Args: name: The unique name or identifier of the tool to load. + auth_token_getters: A mapping of authentication service names to + callables that return the corresponding authentication token. Returns: ToolboxTool: A callable object representing the loaded tool, ready @@ -148,7 +150,7 @@ async def load_tool( if name not in manifest.tools: # TODO: Better exception raise Exception(f"Tool '{name}' not found!") - tool = self.__parse_tool(name, manifest.tools[name], auth_service_tokens) + tool = self.__parse_tool(name, manifest.tools[name], auth_token_getters) return tool @@ -162,6 +164,9 @@ async def load_toolset( Args: name: Name of the toolset to load tools. + auth_token_getters: A mapping of authentication service names to + callables that return the corresponding authentication token. + Returns: list[ToolboxTool]: A list of callables, one for each tool defined diff --git a/packages/toolbox-core/tests/test_client.py b/packages/toolbox-core/tests/test_client.py index aebf3133..91ea2259 100644 --- a/packages/toolbox-core/tests/test_client.py +++ b/packages/toolbox-core/tests/test_client.py @@ -150,7 +150,7 @@ def token_handler(): return expected_header tool = await client.load_tool( - tool_name, auth_service_tokens={"my-auth-service": token_handler} + tool_name, auth_token_getters={"my-auth-service": token_handler} ) res = await tool(5) From c1ac2cd3033c64287cc1c273c75bc1c4fea88a75 Mon Sep 17 00:00:00 2001 From: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com> Date: Wed, 2 Apr 2025 02:21:36 +0000 Subject: [PATCH 7/8] chore: address feedback --- .../toolbox-core/src/toolbox_core/client.py | 4 +- .../toolbox-core/src/toolbox_core/tool.py | 50 +++++++++++-------- 2 files changed, 31 insertions(+), 23 deletions(-) diff --git a/packages/toolbox-core/src/toolbox_core/client.py b/packages/toolbox-core/src/toolbox_core/client.py index 1acd521a..ca17a081 100644 --- a/packages/toolbox-core/src/toolbox_core/client.py +++ b/packages/toolbox-core/src/toolbox_core/client.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import re import types from typing import Any, Callable, Optional @@ -79,8 +80,9 @@ def __parse_tool( name=name, desc=schema.description, params=[p.to_param() for p in params], + # create a read-only values for the maps to prevent mutation required_authn_params=types.MappingProxyType(authn_params), - auth_service_token_getters=auth_token_getters, + auth_service_token_getters=types.MappingProxyType(auth_token_getters), ) return tool diff --git a/packages/toolbox-core/src/toolbox_core/tool.py b/packages/toolbox-core/src/toolbox_core/tool.py index 17c44f63..bf9ddc34 100644 --- a/packages/toolbox-core/src/toolbox_core/tool.py +++ b/packages/toolbox-core/src/toolbox_core/tool.py @@ -56,9 +56,9 @@ def __init__( desc: The description of the remote tool (used as its docstring). params: A list of `inspect.Parameter` objects defining the tool's arguments and their types/defaults. - required_authn_params: A dict of required authenticated parameters that - need a auth_service_token_getter set for them yet. - auth_service_tokens: A dict of authService -> token (or callables that + required_authn_params: A dict of required authenticated parameters to a list + of services that provide values for them. + auth_service_token_getters: A dict of authService -> token (or callables that produce a token) """ @@ -108,15 +108,19 @@ def __copy( that produce a token) """ + check = lambda val, default: val if val is not None else default return ToolboxTool( - session=session or self.__session, - base_url=base_url or self.__base_url, - name=name or self.__name__, - desc=desc or self.__desc, - params=params or self.__params, - required_authn_params=required_authn_params or self.__required_authn_params, - auth_service_token_getters=auth_service_token_getters - or self.__auth_service_token_getters, + session=check(session, self.__session), + base_url=check(base_url, self.__base_url), + name=check(name, self.__name__), + desc=check(desc, self.__desc), + params=check(params, self.__params), + required_authn_params=check( + required_authn_params, self.__required_authn_params + ), + auth_service_token_getters=check( + auth_service_token_getters, self.__auth_service_token_getters + ), ) async def __call__(self, *args: Any, **kwargs: Any) -> str: @@ -138,7 +142,7 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str: if len(self.__required_authn_params) > 0: req_auth_services = set(l for l in self.__required_authn_params.keys()) raise Exception( - f"One of more of the following authn services are required to invoke this tool: {','.join(req_auth_services)}" + f"One or more of the following authn services are required to invoke this tool: {','.join(req_auth_services)}" ) # validate inputs to this call using the signature @@ -167,7 +171,7 @@ def add_auth_token_getters( auth_token_getters: Mapping[str, Callable[[], str]], ) -> "ToolboxTool": """ - Registers a auth token getter function that is used for AuthServices when tools + Registers an auth token getter function that is used for AuthServices when tools are invoked. Args: @@ -204,25 +208,27 @@ def add_auth_token_getters( def filter_required_authn_params( - req_authn_params: Mapping[str, list[str]], auth_services: Iterable[str] + req_authn_params: Mapping[str, list[str]], auth_service_names: Iterable[str] ) -> dict[str, list[str]]: """ - Utility function for reducing 'req_authn_params' to a subset of parameters that aren't supplied by a least one service in auth_services. + Utility function for reducing 'req_authn_params' to a subset of parameters that + aren't supplied by at least one service in auth_services. Args: req_authn_params: A mapping of parameter names to sets of required authentication services. - auth_services: An iterable of authentication service names for which + auth_service_names: An iterable of authentication service names for which token getters are available. Returns: A new dictionary representing the subset of required authentication - parameters that are not covered by the provided `auth_services`. + parameters that are not covered by the provided `auth_service_names`. """ - req_params = {} + required_params = {} # params that are still required with provided auth_services for param, services in req_authn_params.items(): - # if we don't have a token_getter for any of the services required by the param, the param is still required - required = not any(s in services for s in auth_services) + # if we don't have a token_getter for any of the services required by the param, + # the param is still required + required = not any(s in services for s in auth_service_names) if required: - req_params[param] = services - return req_params + required_params[param] = services + return required_params From bcba462a8275234da8c81a180935978219ede269 Mon Sep 17 00:00:00 2001 From: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com> Date: Wed, 2 Apr 2025 02:30:48 +0000 Subject: [PATCH 8/8] chore: address more feedback --- .../toolbox-core/src/toolbox_core/client.py | 6 ++++-- .../toolbox-core/src/toolbox_core/tool.py | 21 ++++++++++++------- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/packages/toolbox-core/src/toolbox_core/client.py b/packages/toolbox-core/src/toolbox_core/client.py index ca17a081..9b69f6d1 100644 --- a/packages/toolbox-core/src/toolbox_core/client.py +++ b/packages/toolbox-core/src/toolbox_core/client.py @@ -18,7 +18,7 @@ from aiohttp import ClientSession from .protocol import ManifestSchema, ToolSchema -from .tool import ToolboxTool, filter_required_authn_params +from .tool import ToolboxTool, identify_required_authn_params class ToolboxClient: @@ -72,7 +72,9 @@ def __parse_tool( authn_params[p.name] = p.authSources auth_sources.update(p.authSources) - authn_params = filter_required_authn_params(authn_params, auth_sources) + authn_params = identify_required_authn_params( + authn_params, auth_token_getters.keys() + ) tool = ToolboxTool( session=self.__session, diff --git a/packages/toolbox-core/src/toolbox_core/tool.py b/packages/toolbox-core/src/toolbox_core/tool.py index bf9ddc34..6421cd99 100644 --- a/packages/toolbox-core/src/toolbox_core/tool.py +++ b/packages/toolbox-core/src/toolbox_core/tool.py @@ -140,7 +140,10 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str: # check if any auth services need to be specified yet if len(self.__required_authn_params) > 0: - req_auth_services = set(l for l in self.__required_authn_params.keys()) + # Gather all the required auth services into a set + req_auth_services = set() + for s in self.__required_authn_params.values(): + req_auth_services.update(s) raise Exception( f"One or more of the following authn services are required to invoke this tool: {','.join(req_auth_services)}" ) @@ -184,10 +187,12 @@ def add_auth_token_getters( """ # throw an error if the authentication source is already registered - dupes = auth_token_getters.keys() & self.__auth_service_token_getters.keys() - if dupes: + existing_services = self.__auth_service_token_getters.keys() + incoming_services = auth_token_getters.keys() + duplicates = existing_services & incoming_services + if duplicates: raise ValueError( - f"Authentication source(s) `{', '.join(dupes)}` already registered in tool `{self.__name__}`." + f"Authentication source(s) `{', '.join(duplicates)}` already registered in tool `{self.__name__}`." ) # create a read-only updated value for new_getters @@ -196,7 +201,7 @@ def add_auth_token_getters( ) # create a read-only updated for params that are still required new_req_authn_params = types.MappingProxyType( - filter_required_authn_params( + identify_required_authn_params( self.__required_authn_params, auth_token_getters.keys() ) ) @@ -207,12 +212,12 @@ def add_auth_token_getters( ) -def filter_required_authn_params( +def identify_required_authn_params( req_authn_params: Mapping[str, list[str]], auth_service_names: Iterable[str] ) -> dict[str, list[str]]: """ - Utility function for reducing 'req_authn_params' to a subset of parameters that - aren't supplied by at least one service in auth_services. + Identifies authentication parameters that are still required; or not covered by + the provided `auth_service_names`. Args: req_authn_params: A mapping of parameter names to sets of required