diff --git a/packages/toolbox-core/pyproject.toml b/packages/toolbox-core/pyproject.toml index edc45a8a..ee8a5f73 100644 --- a/packages/toolbox-core/pyproject.toml +++ b/packages/toolbox-core/pyproject.toml @@ -44,7 +44,8 @@ test = [ "isort==6.0.1", "mypy==1.15.0", "pytest==8.3.5", - "pytest-aioresponses==0.3.0" + "pytest-aioresponses==0.3.0", + "pytest-asyncio==0.25.3", ] [build-system] requires = ["setuptools"] diff --git a/packages/toolbox-core/src/toolbox_core/client.py b/packages/toolbox-core/src/toolbox_core/client.py index dc59e440..9b69f6d1 100644 --- a/packages/toolbox-core/src/toolbox_core/client.py +++ b/packages/toolbox-core/src/toolbox_core/client.py @@ -11,13 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from typing import Optional +import re +import types +from typing import Any, Callable, Optional from aiohttp import ClientSession from .protocol import ManifestSchema, ToolSchema -from .tool import ToolboxTool +from .tool import ToolboxTool, identify_required_authn_params class ToolboxClient: @@ -53,14 +54,37 @@ def __init__( session = ClientSession() self.__session = session - def __parse_tool(self, name: str, schema: ToolSchema) -> ToolboxTool: + def __parse_tool( + self, + name: str, + schema: ToolSchema, + auth_token_getters: dict[str, Callable[[], str]], + ) -> ToolboxTool: """Internal helper to create a callable tool from its schema.""" + # sort into authenticated and reg params + params = [] + authn_params: dict[str, list[str]] = {} + auth_sources: set[str] = set() + for p in schema.parameters: + if not p.authSources: + params.append(p) + else: + authn_params[p.name] = p.authSources + auth_sources.update(p.authSources) + + authn_params = identify_required_authn_params( + authn_params, auth_token_getters.keys() + ) + tool = ToolboxTool( session=self.__session, base_url=self.__base_url, name=name, desc=schema.description, - params=[p.to_param() for p in schema.parameters], + params=[p.to_param() for p in params], + # create a read-only values for the maps to prevent mutation + required_authn_params=types.MappingProxyType(authn_params), + auth_service_token_getters=types.MappingProxyType(auth_token_getters), ) return tool @@ -99,6 +123,7 @@ async def close(self): async def load_tool( self, name: str, + auth_token_getters: dict[str, Callable[[], str]] = {}, ) -> ToolboxTool: """ Asynchronously loads a tool from the server. @@ -109,6 +134,8 @@ async def load_tool( Args: name: The unique name or identifier of the tool to load. + auth_token_getters: A mapping of authentication service names to + callables that return the corresponding authentication token. Returns: ToolboxTool: A callable object representing the loaded tool, ready @@ -127,19 +154,23 @@ async def load_tool( if name not in manifest.tools: # TODO: Better exception raise Exception(f"Tool '{name}' not found!") - tool = self.__parse_tool(name, manifest.tools[name]) + tool = self.__parse_tool(name, manifest.tools[name], auth_token_getters) return tool async def load_toolset( self, name: str, + auth_token_getters: dict[str, Callable[[], str]] = {}, ) -> list[ToolboxTool]: """ Asynchronously fetches a toolset and loads all tools defined within it. Args: name: Name of the toolset to load tools. + auth_token_getters: A mapping of authentication service names to + callables that return the corresponding authentication token. + Returns: list[ToolboxTool]: A list of callables, one for each tool defined @@ -152,5 +183,8 @@ async def load_toolset( manifest: ManifestSchema = ManifestSchema(**json) # parse each tools name and schema into a list of ToolboxTools - tools = [self.__parse_tool(n, s) for n, s in manifest.tools.items()] + tools = [ + self.__parse_tool(n, s, auth_token_getters) + for n, s in manifest.tools.items() + ] return tools diff --git a/packages/toolbox-core/src/toolbox_core/tool.py b/packages/toolbox-core/src/toolbox_core/tool.py index 48e4626c..6421cd99 100644 --- a/packages/toolbox-core/src/toolbox_core/tool.py +++ b/packages/toolbox-core/src/toolbox_core/tool.py @@ -13,10 +13,13 @@ # limitations under the License. +import types +from collections import defaultdict from inspect import Parameter, Signature -from typing import Any +from typing import Any, Callable, DefaultDict, Iterable, Mapping, Optional, Sequence from aiohttp import ClientSession +from pytest import Session class ToolboxTool: @@ -32,20 +35,19 @@ class ToolboxTool: and `inspect` work as expected. """ - __url: str - __session: ClientSession - __signature__: Signature - def __init__( self, session: ClientSession, base_url: str, name: str, desc: str, - params: list[Parameter], + params: Sequence[Parameter], + required_authn_params: Mapping[str, list[str]], + auth_service_token_getters: Mapping[str, Callable[[], str]], ): """ - Initializes a callable that will trigger the tool invocation through the Toolbox server. + Initializes a callable that will trigger the tool invocation through the + Toolbox server. Args: session: The `aiohttp.ClientSession` used for making API requests. @@ -54,19 +56,73 @@ def __init__( desc: The description of the remote tool (used as its docstring). params: A list of `inspect.Parameter` objects defining the tool's arguments and their types/defaults. + required_authn_params: A dict of required authenticated parameters to a list + of services that provide values for them. + auth_service_token_getters: A dict of authService -> token (or callables that + produce a token) """ # used to invoke the toolbox API - self.__session = session + self.__session: ClientSession = session + self.__base_url: str = base_url self.__url = f"{base_url}/api/tool/{name}/invoke" - # the following properties are set to help anyone that might inspect it determine + self.__desc = desc + self.__params = params + + # the following properties are set to help anyone that might inspect it determine usage self.__name__ = name self.__doc__ = desc self.__signature__ = Signature(parameters=params, return_annotation=str) self.__annotations__ = {p.name: p.annotation for p in params} # TODO: self.__qualname__ ?? + # map of parameter name to auth service required by it + self.__required_authn_params = required_authn_params + # map of authService -> token_getter + self.__auth_service_token_getters = auth_service_token_getters + + def __copy( + self, + session: Optional[ClientSession] = None, + base_url: Optional[str] = None, + name: Optional[str] = None, + desc: Optional[str] = None, + params: Optional[list[Parameter]] = None, + required_authn_params: Optional[Mapping[str, list[str]]] = None, + auth_service_token_getters: Optional[Mapping[str, Callable[[], str]]] = None, + ) -> "ToolboxTool": + """ + Creates a copy of the ToolboxTool, overriding specific fields. + + Args: + session: The `aiohttp.ClientSession` used for making API requests. + base_url: The base URL of the Toolbox server API. + name: The name of the remote tool. + desc: The description of the remote tool (used as its docstring). + params: A list of `inspect.Parameter` objects defining the tool's + arguments and their types/defaults. + required_authn_params: A dict of required authenticated parameters that need + a auth_service_token_getter set for them yet. + auth_service_token_getters: A dict of authService -> token (or callables + that produce a token) + + """ + check = lambda val, default: val if val is not None else default + return ToolboxTool( + session=check(session, self.__session), + base_url=check(base_url, self.__base_url), + name=check(name, self.__name__), + desc=check(desc, self.__desc), + params=check(params, self.__params), + required_authn_params=check( + required_authn_params, self.__required_authn_params + ), + auth_service_token_getters=check( + auth_service_token_getters, self.__auth_service_token_getters + ), + ) + async def __call__(self, *args: Any, **kwargs: Any) -> str: """ Asynchronously calls the remote tool with the provided arguments. @@ -81,16 +137,103 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str: Returns: The string result returned by the remote tool execution. """ + + # check if any auth services need to be specified yet + if len(self.__required_authn_params) > 0: + # Gather all the required auth services into a set + req_auth_services = set() + for s in self.__required_authn_params.values(): + req_auth_services.update(s) + raise Exception( + f"One or more of the following authn services are required to invoke this tool: {','.join(req_auth_services)}" + ) + + # validate inputs to this call using the signature all_args = self.__signature__.bind(*args, **kwargs) all_args.apply_defaults() # Include default values if not provided payload = all_args.arguments + # create headers for auth services + headers = {} + for auth_service, token_getter in self.__auth_service_token_getters.items(): + headers[f"{auth_service}_token"] = token_getter() + async with self.__session.post( self.__url, json=payload, + headers=headers, ) as resp: - ret = await resp.json() - if "error" in ret: - # TODO: better error - raise Exception(ret["error"]) - return ret.get("result", ret) + body = await resp.json() + if resp.status < 200 or resp.status >= 300: + err = body.get("error", f"unexpected status from server: {resp.status}") + raise Exception(err) + return body.get("result", body) + + def add_auth_token_getters( + self, + auth_token_getters: Mapping[str, Callable[[], str]], + ) -> "ToolboxTool": + """ + Registers an auth token getter function that is used for AuthServices when tools + are invoked. + + Args: + auth_token_getters: A mapping of authentication service names to + callables that return the corresponding authentication token. + + Returns: + A new ToolboxTool instance with the specified authentication token + getters registered. + """ + + # throw an error if the authentication source is already registered + existing_services = self.__auth_service_token_getters.keys() + incoming_services = auth_token_getters.keys() + duplicates = existing_services & incoming_services + if duplicates: + raise ValueError( + f"Authentication source(s) `{', '.join(duplicates)}` already registered in tool `{self.__name__}`." + ) + + # create a read-only updated value for new_getters + new_getters = types.MappingProxyType( + dict(self.__auth_service_token_getters, **auth_token_getters) + ) + # create a read-only updated for params that are still required + new_req_authn_params = types.MappingProxyType( + identify_required_authn_params( + self.__required_authn_params, auth_token_getters.keys() + ) + ) + + return self.__copy( + auth_service_token_getters=new_getters, + required_authn_params=new_req_authn_params, + ) + + +def identify_required_authn_params( + req_authn_params: Mapping[str, list[str]], auth_service_names: Iterable[str] +) -> dict[str, list[str]]: + """ + Identifies authentication parameters that are still required; or not covered by + the provided `auth_service_names`. + + Args: + req_authn_params: A mapping of parameter names to sets of required + authentication services. + auth_service_names: An iterable of authentication service names for which + token getters are available. + + Returns: + A new dictionary representing the subset of required authentication + parameters that are not covered by the provided `auth_service_names`. + """ + required_params = {} # params that are still required with provided auth_services + for param, services in req_authn_params.items(): + # if we don't have a token_getter for any of the services required by the param, + # the param is still required + required = not any(s in services for s in auth_service_names) + if required: + required_params[param] = services + return required_params diff --git a/packages/toolbox-core/tests/test_client.py b/packages/toolbox-core/tests/test_client.py index b19c575b..91ea2259 100644 --- a/packages/toolbox-core/tests/test_client.py +++ b/packages/toolbox-core/tests/test_client.py @@ -16,6 +16,8 @@ import inspect import pytest +import pytest_asyncio +from aioresponses import CallbackResult from toolbox_core import ToolboxClient from toolbox_core.protocol import ManifestSchema, ParameterSchema, ToolSchema @@ -26,7 +28,7 @@ @pytest.fixture() def test_tool_str(): return ToolSchema( - description="Test Tool 1 Description", + description="Test Tool with String input", parameters=[ ParameterSchema( name="param1", type="string", description="Description of Param1" @@ -37,9 +39,8 @@ def test_tool_str(): @pytest.fixture() def test_tool_int_bool(): - """Fixture for the second test tool schema.""" return ToolSchema( - description="Test Tool 2 Description", + description="Test Tool with Int, Bool", parameters=[ ParameterSchema(name="argA", type="integer", description="Argument A"), ParameterSchema(name="argB", type="boolean", description="Argument B"), @@ -47,6 +48,22 @@ def test_tool_int_bool(): ) +@pytest.fixture() +def test_tool_auth(): + return ToolSchema( + description="Test Tool with Int,Bool+Auth", + parameters=[ + ParameterSchema(name="argA", type="integer", description="Argument A"), + ParameterSchema( + name="argB", + type="boolean", + description="Argument B", + authSources=["my-auth-service"], + ), + ], + ) + + @pytest.mark.asyncio async def test_load_tool_success(aioresponses, test_tool_str): """ @@ -83,6 +100,87 @@ async def test_load_tool_success(aioresponses, test_tool_str): assert await loaded_tool("some value") == "ok" +class TestAuth: + + @pytest.fixture + def expected_header(self): + return "some_token_for_testing" + + @pytest.fixture + def tool_name(self): + return "tool1" + + @pytest_asyncio.fixture + async def client(self, aioresponses, test_tool_auth, tool_name, expected_header): + manifest = ManifestSchema( + serverVersion="0.0.0", tools={tool_name: test_tool_auth} + ) + + # mock tool GET call + aioresponses.get( + f"{TEST_BASE_URL}/api/tool/{tool_name}", + payload=manifest.model_dump(), + status=200, + ) + + # mock tool INVOKE call + def require_headers(url, **kwargs): + if kwargs["headers"].get("my-auth-service_token") == expected_header: + return CallbackResult(status=200, body="{}") + else: + return CallbackResult(status=400, body="{}") + + aioresponses.post( + f"{TEST_BASE_URL}/api/tool/{tool_name}/invoke", + payload=manifest.model_dump(), + callback=require_headers, + status=200, + ) + + async with ToolboxClient(TEST_BASE_URL) as client: + yield client + + @pytest.mark.asyncio + async def test_auth_with_load_tool_success( + self, tool_name, expected_header, client + ): + """Tests 'load_tool' with auth token is specified.""" + + def token_handler(): + return expected_header + + tool = await client.load_tool( + tool_name, auth_token_getters={"my-auth-service": token_handler} + ) + res = await tool(5) + + @pytest.mark.asyncio + async def test_auth_with_add_token_success( + self, tool_name, expected_header, client + ): + """Tests 'load_tool' with auth token is specified.""" + + def token_handler(): + return expected_header + + tool = await client.load_tool(tool_name) + tool = tool.add_auth_token_getters({"my-auth-service": token_handler}) + res = await tool(5) + + @pytest.mark.asyncio + async def test_auth_with_load_tool_fail_no_token( + self, tool_name, expected_header, client + ): + """Tests 'load_tool' with auth token is specified.""" + + def token_handler(): + return expected_header + + tool = await client.load_tool(tool_name) + with pytest.raises(Exception): + res = await tool(5) + + @pytest.mark.asyncio async def test_load_toolset_success(aioresponses, test_tool_str, test_tool_int_bool): """Tests successfully loading a toolset with multiple tools."""