diff --git a/packages/toolbox-core/src/toolbox_core/client.py b/packages/toolbox-core/src/toolbox_core/client.py index a534e706..4d4d9db2 100644 --- a/packages/toolbox-core/src/toolbox_core/client.py +++ b/packages/toolbox-core/src/toolbox_core/client.py @@ -11,15 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - import types -from typing import Any, Callable, Mapping, Optional, Union +from typing import Any, Callable, Coroutine, Mapping, Optional, Union from aiohttp import ClientSession from .protocol import ManifestSchema, ToolSchema -from .tool import ToolboxTool, identify_required_authn_params +from .tool import ToolboxTool +from .utils import identify_required_authn_params, resolve_value class ToolboxClient: @@ -37,6 +36,7 @@ def __init__( self, url: str, session: Optional[ClientSession] = None, + client_headers: Optional[Mapping[str, Union[Callable, Coroutine, str]]] = None, ): """ Initializes the ToolboxClient. @@ -47,6 +47,7 @@ def __init__( If None (default), a new session is created internally. Note that if a session is provided, its lifecycle (including closing) should typically be managed externally. + client_headers: Headers to include in each request sent through this client. """ self.__base_url = url @@ -55,12 +56,15 @@ def __init__( session = ClientSession() self.__session = session + self.__client_headers = client_headers if client_headers is not None else {} + def __parse_tool( self, name: str, schema: ToolSchema, auth_token_getters: dict[str, Callable[[], str]], all_bound_params: Mapping[str, Union[Callable[[], Any], Any]], + client_headers: Mapping[str, Union[Callable, Coroutine, str]], ) -> ToolboxTool: """Internal helper to create a callable tool from its schema.""" # sort into reg, authn, and bound params @@ -89,6 +93,7 @@ def __parse_tool( required_authn_params=types.MappingProxyType(authn_params), auth_service_token_getters=types.MappingProxyType(auth_token_getters), bound_params=types.MappingProxyType(bound_params), + client_headers=types.MappingProxyType(client_headers), ) return tool @@ -144,18 +149,21 @@ async def load_tool( bound_params: A mapping of parameter names to bind to specific values or callables that are called to produce values as needed. - - Returns: ToolboxTool: A callable object representing the loaded tool, ready for execution. The specific arguments and behavior of the callable depend on the tool itself. """ + # Resolve client headers + resolved_headers = { + name: await resolve_value(val) + for name, val in self.__client_headers.items() + } # request the definition of the tool from the server url = f"{self.__base_url}/api/tool/{name}" - async with self.__session.get(url) as response: + async with self.__session.get(url, headers=resolved_headers) as response: json = await response.json() manifest: ManifestSchema = ManifestSchema(**json) @@ -164,7 +172,11 @@ async def load_tool( # TODO: Better exception raise Exception(f"Tool '{name}' not found!") tool = self.__parse_tool( - name, manifest.tools[name], auth_token_getters, bound_params + name, + manifest.tools[name], + auth_token_getters, + bound_params, + self.__client_headers, ) return tool @@ -185,21 +197,50 @@ async def load_toolset( bound_params: A mapping of parameter names to bind to specific values or callables that are called to produce values as needed. - - Returns: list[ToolboxTool]: A list of callables, one for each tool defined in the toolset. """ + # Resolve client headers + original_headers = self.__client_headers + resolved_headers = { + header_name: await resolve_value(original_headers[header_name]) + for header_name in original_headers + } # Request the definition of the tool from the server url = f"{self.__base_url}/api/toolset/{name or ''}" - async with self.__session.get(url) as response: + async with self.__session.get(url, headers=resolved_headers) as response: json = await response.json() manifest: ManifestSchema = ManifestSchema(**json) # parse each tools name and schema into a list of ToolboxTools tools = [ - self.__parse_tool(n, s, auth_token_getters, bound_params) + self.__parse_tool( + n, s, auth_token_getters, bound_params, self.__client_headers + ) for n, s in manifest.tools.items() ] return tools + + async def add_headers( + self, headers: Mapping[str, Union[Callable, Coroutine, str]] + ) -> None: + """ + Asynchronously Add headers to be included in each request sent through this client. + + Args: + headers: Headers to include in each request sent through this client. + + Raises: + ValueError: If any of the headers are already registered in the client. + """ + existing_headers = self.__client_headers.keys() + incoming_headers = headers.keys() + duplicates = existing_headers & incoming_headers + if duplicates: + raise ValueError( + f"Client header(s) `{', '.join(duplicates)}` already registered in the client." + ) + + merged_headers = {**self.__client_headers, **headers} + self.__client_headers = merged_headers diff --git a/packages/toolbox-core/src/toolbox_core/tool.py b/packages/toolbox-core/src/toolbox_core/tool.py index 9ad420f6..d26d9287 100644 --- a/packages/toolbox-core/src/toolbox_core/tool.py +++ b/packages/toolbox-core/src/toolbox_core/tool.py @@ -15,14 +15,7 @@ import types from inspect import Signature -from typing import ( - Any, - Callable, - Mapping, - Optional, - Sequence, - Union, -) +from typing import Any, Callable, Coroutine, Mapping, Optional, Sequence, Union from aiohttp import ClientSession @@ -58,6 +51,7 @@ def __init__( required_authn_params: Mapping[str, list[str]], auth_service_token_getters: Mapping[str, Callable[[], str]], bound_params: Mapping[str, Union[Callable[[], Any], Any]], + client_headers: Mapping[str, Union[Callable, Coroutine, str]], ): """ Initializes a callable that will trigger the tool invocation through the @@ -75,6 +69,7 @@ def __init__( produce a token) bound_params: A mapping of parameter names to bind to specific values or callables that are called to produce values as needed. + client_headers: Client specific headers bound to the tool. """ # used to invoke the toolbox API self.__session: ClientSession = session @@ -96,12 +91,27 @@ def __init__( self.__annotations__ = {p.name: p.annotation for p in inspect_type_params} self.__qualname__ = f"{self.__class__.__qualname__}.{self.__name__}" + # Validate conflicting Headers/Auth Tokens + request_header_names = client_headers.keys() + auth_token_names = [ + auth_token_name + "_token" + for auth_token_name in auth_service_token_getters.keys() + ] + duplicates = request_header_names & auth_token_names + if duplicates: + raise ValueError( + f"Client header(s) `{', '.join(duplicates)}` already registered in client. " + f"Cannot register client the same headers in the client as well as tool." + ) + # map of parameter name to auth service required by it self.__required_authn_params = required_authn_params # map of authService -> token_getter self.__auth_service_token_getters = auth_service_token_getters # map of parameter name to value (or callable that produces that value) self.__bound_parameters = bound_params + # map of client headers to their value/callable/coroutine + self.__client_headers = client_headers def __copy( self, @@ -113,6 +123,7 @@ def __copy( required_authn_params: Optional[Mapping[str, list[str]]] = None, auth_service_token_getters: Optional[Mapping[str, Callable[[], str]]] = None, bound_params: Optional[Mapping[str, Union[Callable[[], Any], Any]]] = None, + client_headers: Optional[Mapping[str, Union[Callable, Coroutine, str]]] = None, ) -> "ToolboxTool": """ Creates a copy of the ToolboxTool, overriding specific fields. @@ -129,7 +140,7 @@ def __copy( that produce a token) bound_params: A mapping of parameter names to bind to specific values or callables that are called to produce values as needed. - + client_headers: Client specific headers bound to the tool. """ check = lambda val, default: val if val is not None else default return ToolboxTool( @@ -145,6 +156,7 @@ def __copy( auth_service_token_getters, self.__auth_service_token_getters ), bound_params=check(bound_params, self.__bound_parameters), + client_headers=check(client_headers, self.__client_headers), ) async def __call__(self, *args: Any, **kwargs: Any) -> str: @@ -169,7 +181,8 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str: for s in self.__required_authn_params.values(): req_auth_services.update(s) raise Exception( - f"One or more of the following authn services are required to invoke this tool: {','.join(req_auth_services)}" + f"One or more of the following authn services are required to invoke this tool" + f": {','.join(req_auth_services)}" ) # validate inputs to this call using the signature @@ -188,6 +201,8 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str: headers = {} for auth_service, token_getter in self.__auth_service_token_getters.items(): headers[f"{auth_service}_token"] = await resolve_value(token_getter) + for client_header_name, client_header_val in self.__client_headers.items(): + headers[client_header_name] = await resolve_value(client_header_val) async with self.__session.post( self.__url, @@ -215,6 +230,10 @@ def add_auth_token_getters( Returns: A new ToolboxTool instance with the specified authentication token getters registered. + + Raises + ValueError: If the auth source has already been registered either + to the tool or to the corresponding client. """ # throw an error if the authentication source is already registered @@ -226,6 +245,18 @@ def add_auth_token_getters( f"Authentication source(s) `{', '.join(duplicates)}` already registered in tool `{self.__name__}`." ) + # Validate duplicates with client headers + request_header_names = self.__client_headers.keys() + auth_token_names = [ + auth_token_name + "_token" for auth_token_name in incoming_services + ] + duplicates = request_header_names & auth_token_names + if duplicates: + raise ValueError( + f"Client header(s) `{', '.join(duplicates)}` already registered in client. " + f"Cannot register client the same headers in the client as well as tool." + ) + # create a read-only updated value for new_getters new_getters = types.MappingProxyType( dict(self.__auth_service_token_getters, **auth_token_getters) diff --git a/packages/toolbox-core/tests/test_client.py b/packages/toolbox-core/tests/test_client.py index a9cb091a..e6d12a0c 100644 --- a/packages/toolbox-core/tests/test_client.py +++ b/packages/toolbox-core/tests/test_client.py @@ -15,11 +15,12 @@ import inspect import json +from typing import Any, Callable, Mapping, Optional from unittest.mock import AsyncMock, Mock import pytest import pytest_asyncio -from aioresponses import CallbackResult +from aioresponses import CallbackResult, aioresponses from toolbox_core import ToolboxClient from toolbox_core.protocol import ManifestSchema, ParameterSchema, ToolSchema @@ -66,6 +67,76 @@ def test_tool_auth(): ) +# --- Helper Functions for Mocking --- + + +def mock_tool_load( + aio_resp: aioresponses, + tool_name: str, + tool_schema: ToolSchema, + base_url: str = TEST_BASE_URL, + server_version: str = "0.0.0", + status: int = 200, + callback: Optional[Callable] = None, + payload_override: Optional[Callable] = None, +): + """Mocks the GET /api/tool/{tool_name} endpoint.""" + url = f"{base_url}/api/tool/{tool_name}" + if payload_override is not None: + payload = payload_override + else: + manifest = ManifestSchema( + serverVersion=server_version, tools={tool_name: tool_schema} + ) + payload = manifest.model_dump() + aio_resp.get( + url, + payload=payload, + status=status, + callback=callback, + ) + + +def mock_toolset_load( + aio_resp: aioresponses, + toolset_name: str, + tools_dict: Mapping[str, ToolSchema], + base_url: str = TEST_BASE_URL, + server_version: str = "0.0.0", + status: int = 200, + callback: Optional[Callable] = None, +): + """Mocks the GET /api/toolset/{toolset_name} endpoint.""" + # Handle default toolset name (empty string) + url_path = f"toolset/{toolset_name}" if toolset_name else "toolset/" + url = f"{base_url}/api/{url_path}" + manifest = ManifestSchema(serverVersion=server_version, tools=tools_dict) + aio_resp.get( + url, + payload=manifest.model_dump(), + status=status, + callback=callback, + ) + + +def mock_tool_invoke( + aio_resp: aioresponses, + tool_name: str, + base_url: str = TEST_BASE_URL, + response_payload: Any = {"result": "ok"}, + status: int = 200, + callback: Optional[Callable] = None, +): + """Mocks the POST /api/tool/{tool_name}/invoke endpoint.""" + url = f"{base_url}/api/tool/{tool_name}/invoke" + aio_resp.post( + url, + payload=response_payload, + status=status, + callback=callback, + ) + + @pytest.mark.asyncio async def test_load_tool_success(aioresponses, test_tool_str): """ @@ -73,17 +144,8 @@ async def test_load_tool_success(aioresponses, test_tool_str): """ # Mock out responses from server TOOL_NAME = "test_tool_1" - manifest = ManifestSchema(serverVersion="0.0.0", tools={TOOL_NAME: test_tool_str}) - aioresponses.get( - f"{TEST_BASE_URL}/api/tool/{TOOL_NAME}", - payload=manifest.model_dump(), - status=200, - ) - aioresponses.post( - f"{TEST_BASE_URL}/api/tool/{TOOL_NAME}/invoke", - payload={"result": "ok"}, - status=200, - ) + mock_tool_load(aioresponses, TOOL_NAME, test_tool_str) + mock_tool_invoke(aioresponses, TOOL_NAME) async with ToolboxClient(TEST_BASE_URL) as client: # Load a Tool @@ -115,11 +177,7 @@ async def test_load_toolset_success(aioresponses, test_tool_str, test_tool_int_b manifest = ManifestSchema( serverVersion="0.0.0", tools={TOOL1: test_tool_str, TOOL2: test_tool_int_bool} ) - aioresponses.get( - f"{TEST_BASE_URL}/api/toolset/{TOOLSET_NAME}", - payload=manifest.model_dump(), - status=200, - ) + mock_toolset_load(aioresponses, TOOLSET_NAME, manifest.tools) async with ToolboxClient(TEST_BASE_URL) as client: tools = await client.load_toolset(TOOLSET_NAME) @@ -137,17 +195,10 @@ async def test_invoke_tool_server_error(aioresponses, test_tool_str): error status.""" TOOL_NAME = "server_error_tool" ERROR_MESSAGE = "Simulated Server Error" - manifest = ManifestSchema(serverVersion="0.0.0", tools={TOOL_NAME: test_tool_str}) - aioresponses.get( - f"{TEST_BASE_URL}/api/tool/{TOOL_NAME}", - payload=manifest.model_dump(), - status=200, - ) - aioresponses.post( - f"{TEST_BASE_URL}/api/tool/{TOOL_NAME}/invoke", - payload={"error": ERROR_MESSAGE}, - status=500, + mock_tool_load(aioresponses, TOOL_NAME, test_tool_str) + mock_tool_invoke( + aioresponses, TOOL_NAME, response_payload={"error": ERROR_MESSAGE}, status=500 ) async with ToolboxClient(TEST_BASE_URL) as client: @@ -166,14 +217,15 @@ async def test_load_tool_not_found_in_manifest(aioresponses, test_tool_str): ACTUAL_TOOL_IN_MANIFEST = "actual_tool_abc" REQUESTED_TOOL_NAME = "non_existent_tool_xyz" - manifest = ManifestSchema( + mismatched_manifest_payload = ManifestSchema( serverVersion="0.0.0", tools={ACTUAL_TOOL_IN_MANIFEST: test_tool_str} - ) + ).model_dump() - aioresponses.get( - f"{TEST_BASE_URL}/api/tool/{REQUESTED_TOOL_NAME}", - payload=manifest.model_dump(), - status=200, + mock_tool_load( + aio_resp=aioresponses, + tool_name=REQUESTED_TOOL_NAME, + tool_schema=test_tool_str, + payload_override=mismatched_manifest_payload, ) async with ToolboxClient(TEST_BASE_URL) as client: @@ -181,7 +233,7 @@ async def test_load_tool_not_found_in_manifest(aioresponses, test_tool_str): await client.load_tool(REQUESTED_TOOL_NAME) aioresponses.assert_called_once_with( - f"{TEST_BASE_URL}/api/tool/{REQUESTED_TOOL_NAME}", method="GET" + f"{TEST_BASE_URL}/api/tool/{REQUESTED_TOOL_NAME}", method="GET", headers={} ) @@ -451,3 +503,271 @@ async def test_bind_param_async_callable_value_success(self, tool_name, client): assert res_payload == {"argA": passed_value_a, "argB": bound_value_result} bound_async_callable.assert_awaited_once() + + +class TestClientHeaders: + @pytest.fixture + def static_header(self): + return {"X-Static-Header": "static_value"} + + @pytest.fixture + def sync_callable_header_value(self): + return "sync_callable_value" + + @pytest.fixture + def sync_callable_header(self, sync_callable_header_value): + return {"X-Sync-Callable-Header": Mock(return_value=sync_callable_header_value)} + + @pytest.fixture + def async_callable_header_value(self): + return "async_callable_value" + + @pytest.fixture + def async_callable_header(self, async_callable_header_value): + return { + "X-Async-Callable-Header": AsyncMock( + return_value=async_callable_header_value + ) + } + + @staticmethod + def create_callback_factory( + expected_header, + callback_payload, + callback_status: int = 200, + ) -> Callable: + """ + Factory that RETURNS a callback function for aioresponses. + The returned callback will check headers and return the specified payload/status. + """ + + def actual_callback(url, **kwargs): + received_headers = kwargs.get("headers") + assert received_headers == expected_header + return CallbackResult(status=callback_status, payload=callback_payload) + + return actual_callback + + @pytest.mark.asyncio + async def test_client_init_with_headers(self, static_header): + """Tests client initialization with static headers.""" + async with ToolboxClient(TEST_BASE_URL, client_headers=static_header) as client: + assert client._ToolboxClient__client_headers == static_header + + @pytest.mark.asyncio + async def test_load_tool_with_static_headers( + self, aioresponses, test_tool_str, static_header + ): + """Tests loading and invoking a tool with static client headers.""" + tool_name = "tool_with_static_headers" + manifest = ManifestSchema( + serverVersion="0.0.0", tools={tool_name: test_tool_str} + ) + expected_payload = {"result": "ok"} + + get_callback = self.create_callback_factory( + expected_header=static_header, + callback_payload=manifest.model_dump(), + ) + aioresponses.get(f"{TEST_BASE_URL}/api/tool/{tool_name}", callback=get_callback) + + post_callback = self.create_callback_factory( + expected_header=static_header, + callback_payload=expected_payload, + ) + aioresponses.post( + f"{TEST_BASE_URL}/api/tool/{tool_name}/invoke", callback=post_callback + ) + + async with ToolboxClient(TEST_BASE_URL, client_headers=static_header) as client: + tool = await client.load_tool(tool_name) + result = await tool(param1="test") + assert result == expected_payload["result"] + + @pytest.mark.asyncio + async def test_load_tool_with_sync_callable_headers( + self, + aioresponses, + test_tool_str, + sync_callable_header, + sync_callable_header_value, + ): + """Tests loading and invoking a tool with sync callable client headers.""" + tool_name = "tool_with_sync_callable_headers" + manifest = ManifestSchema( + serverVersion="0.0.0", tools={tool_name: test_tool_str} + ) + expected_payload = {"result": "ok_sync"} + header_key = list(sync_callable_header.keys())[0] + header_mock = sync_callable_header[header_key] + resolved_header = {header_key: sync_callable_header_value} + + get_callback = self.create_callback_factory( + expected_header=resolved_header, + callback_payload=manifest.model_dump(), + ) + aioresponses.get(f"{TEST_BASE_URL}/api/tool/{tool_name}", callback=get_callback) + + post_callback = self.create_callback_factory( + expected_header=resolved_header, + callback_payload=expected_payload, + ) + aioresponses.post( + f"{TEST_BASE_URL}/api/tool/{tool_name}/invoke", callback=post_callback + ) + + async with ToolboxClient( + TEST_BASE_URL, client_headers=sync_callable_header + ) as client: + tool = await client.load_tool(tool_name) + header_mock.assert_called_once() # GET + + header_mock.reset_mock() # Reset before invoke + + result = await tool(param1="test") + assert result == expected_payload["result"] + header_mock.assert_called_once() # POST/invoke + + @pytest.mark.asyncio + async def test_load_tool_with_async_callable_headers( + self, + aioresponses, + test_tool_str, + async_callable_header, + async_callable_header_value, + ): + """Tests loading and invoking a tool with async callable client headers.""" + tool_name = "tool_with_async_callable_headers" + manifest = ManifestSchema( + serverVersion="0.0.0", tools={tool_name: test_tool_str} + ) + expected_payload = {"result": "ok_async"} + + header_key = list(async_callable_header.keys())[0] + header_mock: AsyncMock = async_callable_header[header_key] # Get the AsyncMock + + # Calculate expected result using the VALUE fixture + resolved_header = {header_key: async_callable_header_value} + + get_callback = self.create_callback_factory( + expected_header=resolved_header, + callback_payload=manifest.model_dump(), + ) + aioresponses.get(f"{TEST_BASE_URL}/api/tool/{tool_name}", callback=get_callback) + + post_callback = self.create_callback_factory( + expected_header=resolved_header, + callback_payload=expected_payload, + ) + aioresponses.post( + f"{TEST_BASE_URL}/api/tool/{tool_name}/invoke", callback=post_callback + ) + + async with ToolboxClient( + TEST_BASE_URL, client_headers=async_callable_header + ) as client: + tool = await client.load_tool(tool_name) + header_mock.assert_awaited_once() # GET + + header_mock.reset_mock() + + result = await tool(param1="test") + assert result == expected_payload["result"] + header_mock.assert_awaited_once() # POST/invoke + + @pytest.mark.asyncio + async def test_load_toolset_with_headers( + self, aioresponses, test_tool_str, static_header + ): + """Tests loading a toolset with client headers.""" + toolset_name = "toolset_with_headers" + tool_name = "tool_in_set" + manifest = ManifestSchema( + serverVersion="0.0.0", tools={tool_name: test_tool_str} + ) + + get_callback = self.create_callback_factory( + expected_header=static_header, + callback_payload=manifest.model_dump(), + ) + aioresponses.get( + f"{TEST_BASE_URL}/api/toolset/{toolset_name}", callback=get_callback + ) + async with ToolboxClient(TEST_BASE_URL, client_headers=static_header) as client: + tools = await client.load_toolset(toolset_name) + assert len(tools) == 1 + assert tools[0].__name__ == tool_name + + @pytest.mark.asyncio + async def test_add_headers_success( + self, aioresponses, test_tool_str, static_header + ): + """Tests adding headers after client initialization.""" + tool_name = "tool_after_add_headers" + manifest = ManifestSchema( + serverVersion="0.0.0", tools={tool_name: test_tool_str} + ) + expected_payload = {"result": "added_ok"} + + get_callback = self.create_callback_factory( + expected_header=static_header, + callback_payload=manifest.model_dump(), + ) + aioresponses.get(f"{TEST_BASE_URL}/api/tool/{tool_name}", callback=get_callback) + + post_callback = self.create_callback_factory( + expected_header=static_header, + callback_payload=expected_payload, + ) + aioresponses.post( + f"{TEST_BASE_URL}/api/tool/{tool_name}/invoke", callback=post_callback + ) + + async with ToolboxClient(TEST_BASE_URL) as client: + await client.add_headers(static_header) + assert client._ToolboxClient__client_headers == static_header + + tool = await client.load_tool(tool_name) + result = await tool(param1="test") + assert result == expected_payload["result"] + + @pytest.mark.asyncio + async def test_add_headers_duplicate_fail(self, static_header): + """Tests that adding a duplicate header via add_headers raises ValueError.""" + async with ToolboxClient(TEST_BASE_URL, client_headers=static_header) as client: + with pytest.raises( + ValueError, + match=f"Client header\\(s\\) `X-Static-Header` already registered", + ): + await client.add_headers(static_header) + + @pytest.mark.asyncio + async def test_client_header_auth_token_conflict_fail( + self, aioresponses, test_tool_auth + ): + """ + Tests that loading a tool fails if a client header conflicts with an auth token name. + """ + tool_name = "auth_conflict_tool" + conflict_key = "my-auth-service_token" + manifest = ManifestSchema( + serverVersion="0.0.0", tools={tool_name: test_tool_auth} + ) + + conflicting_headers = {conflict_key: "some_value"} + auth_getters = {"my-auth-service": lambda: "token_val"} + + aioresponses.get( + f"{TEST_BASE_URL}/api/tool/{tool_name}", + payload=manifest.model_dump(), + status=200, + ) + + async with ToolboxClient( + TEST_BASE_URL, client_headers=conflicting_headers + ) as client: + with pytest.raises( + ValueError, + match=f"Client header\\(s\\) `{conflict_key}` already registered", + ): + await client.load_tool(tool_name, auth_token_getters=auth_getters) diff --git a/packages/toolbox-core/tests/test_tools.py b/packages/toolbox-core/tests/test_tools.py index 7cb4f305..a12fe83f 100644 --- a/packages/toolbox-core/tests/test_tools.py +++ b/packages/toolbox-core/tests/test_tools.py @@ -11,9 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - -from typing import AsyncGenerator +import inspect +from typing import AsyncGenerator, Callable from unittest.mock import AsyncMock, Mock import pytest @@ -40,6 +39,20 @@ def sample_tool_params() -> list[ParameterSchema]: ] +@pytest.fixture +def sample_tool_auth_params() -> list[ParameterSchema]: + """Parameters for a sample tool requiring authentication.""" + return [ + ParameterSchema(name="target", type="string", description="Target system"), + ParameterSchema( + name="token", + type="string", + description="Auth token", + authSources=["test-auth"], + ), + ] + + @pytest.fixture def sample_tool_description() -> str: """Description for the sample tool.""" @@ -53,6 +66,61 @@ async def http_session() -> AsyncGenerator[ClientSession, None]: yield session +# --- Fixtures for Client Headers --- + + +@pytest.fixture +def static_client_header() -> dict[str, str]: + return {"X-Client-Static": "client-static-value"} + + +@pytest.fixture +def sync_callable_client_header_value() -> str: + return "client-sync-callable-value" + + +@pytest.fixture +def sync_callable_client_header(sync_callable_client_header_value) -> dict[str, Mock]: + return {"X-Client-Sync": Mock(return_value=sync_callable_client_header_value)} + + +@pytest.fixture +def async_callable_client_header_value() -> str: + return "client-async-callable-value" + + +@pytest.fixture +def async_callable_client_header( + async_callable_client_header_value, +) -> dict[str, AsyncMock]: + return { + "X-Client-Async": AsyncMock(return_value=async_callable_client_header_value) + } + + +# --- Fixtures for Auth Getters --- + + +@pytest.fixture +def auth_token_value() -> str: + return "auth-token-123" + + +@pytest.fixture +def auth_getters(auth_token_value) -> dict[str, Callable[[], str]]: + return {"test-auth": lambda: auth_token_value} + + +@pytest.fixture +def auth_getters_mock(auth_token_value) -> dict[str, Mock]: + return {"test-auth": Mock(return_value=auth_token_value)} + + +@pytest.fixture +def auth_header_key() -> str: + return "test-auth_token" + + def test_create_func_docstring_one_param_real_schema(): """ Tests create_func_docstring with one real ParameterSchema instance. @@ -168,6 +236,7 @@ async def test_tool_creation_callable_and_run( required_authn_params={}, auth_service_token_getters={}, bound_params={}, + client_headers={}, ) assert callable(tool_instance), "ToolboxTool instance should be callable" @@ -212,17 +281,15 @@ async def test_tool_run_with_pydantic_validation_error( required_authn_params={}, auth_service_token_getters={}, bound_params={}, + client_headers={}, ) assert callable(tool_instance) - with pytest.raises(ValidationError) as exc_info: + expected_pattern = r"1 validation error for sample_tool\ncount\n Input should be a valid integer, unable to parse string as an integer \[\s*type=int_parsing,\s*input_value='not-a-number',\s*input_type=str\s*\]*" + with pytest.raises(ValidationError, match=expected_pattern): await tool_instance(message="hello", count="not-a-number") - assert ( - "1 validation error for sample_tool\ncount\n Input should be a valid integer, unable to parse string as an integer [type=int_parsing, input_value='not-a-number', input_type=str]\n For further information visit https://errors.pydantic.dev/2.11/v/int_parsing" - in str(exc_info.value) - ) m.assert_not_called() @@ -285,3 +352,72 @@ async def test_resolve_value_async_callable(): async_callable.assert_awaited_once() assert resolved == expected_value + + +# --- Tests for ToolboxTool Initialization and Validation --- + + +def test_tool_init_basic(http_session, sample_tool_params, sample_tool_description): + """Tests basic tool initialization without headers or auth.""" + tool_instance = ToolboxTool( + session=http_session, + base_url=TEST_BASE_URL, + name=TEST_TOOL_NAME, + description=sample_tool_description, + params=sample_tool_params, + required_authn_params={}, + auth_service_token_getters={}, + bound_params={}, + client_headers={}, + ) + assert tool_instance.__name__ == TEST_TOOL_NAME + assert inspect.iscoroutinefunction(tool_instance.__call__) + assert "message" in tool_instance.__signature__.parameters + assert "count" in tool_instance.__signature__.parameters + + assert tool_instance._ToolboxTool__client_headers == {} + assert tool_instance._ToolboxTool__auth_service_token_getters == {} + + +def test_tool_init_with_client_headers( + http_session, sample_tool_params, sample_tool_description, static_client_header +): + """Tests tool initialization *with* client headers.""" + tool_instance = ToolboxTool( + session=http_session, + base_url=TEST_BASE_URL, + name=TEST_TOOL_NAME, + description=sample_tool_description, + params=sample_tool_params, + required_authn_params={}, + auth_service_token_getters={}, + bound_params={}, + client_headers=static_client_header, + ) + assert tool_instance._ToolboxTool__client_headers == static_client_header + + +def test_tool_init_header_auth_conflict( + http_session, + sample_tool_auth_params, + sample_tool_description, + auth_getters, + auth_header_key, +): + """Tests ValueError on init if client header conflicts with auth token.""" + conflicting_client_header = {auth_header_key: "some-client-value"} + + with pytest.raises( + ValueError, match=f"Client header\\(s\\) `{auth_header_key}` already registered" + ): + ToolboxTool( + session=http_session, + base_url=TEST_BASE_URL, + name="auth_conflict_tool", + description=sample_tool_description, + params=sample_tool_auth_params, + required_authn_params={}, + auth_service_token_getters=auth_getters, + bound_params={}, + client_headers=conflicting_client_header, + )