diff --git a/packages/toolbox-core/src/toolbox_core/tool.py b/packages/toolbox-core/src/toolbox_core/tool.py index 19b6d710..bdddd09a 100644 --- a/packages/toolbox-core/src/toolbox_core/tool.py +++ b/packages/toolbox-core/src/toolbox_core/tool.py @@ -94,7 +94,7 @@ def __init__( # Validate conflicting Headers/Auth Tokens request_header_names = client_headers.keys() auth_token_names = [ - auth_token_name + "_token" + self.__get_auth_header(auth_token_name) for auth_token_name in auth_service_token_getters.keys() ] duplicates = request_header_names & auth_token_names @@ -187,6 +187,10 @@ def __copy( client_headers=check(client_headers, self.__client_headers), ) + def __get_auth_header(self, auth_token_name: str) -> str: + """Returns the formatted auth token header name.""" + return f"{auth_token_name}_token" + async def __call__(self, *args: Any, **kwargs: Any) -> str: """ Asynchronously calls the remote tool with the provided arguments. @@ -228,7 +232,9 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str: # create headers for auth services headers = {} for auth_service, token_getter in self.__auth_service_token_getters.items(): - headers[f"{auth_service}_token"] = await resolve_value(token_getter) + headers[self.__get_auth_header(auth_service)] = await resolve_value( + token_getter + ) for client_header_name, client_header_val in self.__client_headers.items(): headers[client_header_name] = await resolve_value(client_header_val) @@ -276,7 +282,8 @@ def add_auth_token_getters( # Validate duplicates with client headers request_header_names = self.__client_headers.keys() auth_token_names = [ - auth_token_name + "_token" for auth_token_name in incoming_services + self.__get_auth_header(auth_token_name) + for auth_token_name in incoming_services ] duplicates = request_header_names & auth_token_names if duplicates: diff --git a/packages/toolbox-core/tests/test_client.py b/packages/toolbox-core/tests/test_client.py index d510d73f..82066bca 100644 --- a/packages/toolbox-core/tests/test_client.py +++ b/packages/toolbox-core/tests/test_client.py @@ -696,6 +696,137 @@ async def test_bind_param_async_callable_value_success(self, tool_name, client): assert res_payload == {"argA": passed_value_a, "argB": bound_value_result} bound_async_callable.assert_awaited_once() + @pytest.mark.asyncio + async def test_bind_param_success(self, tool_name, client): + """Tests 'bind_param' with a bound parameter specified.""" + tool = await client.load_tool(tool_name) + + assert len(tool.__signature__.parameters) == 2 + assert "argA" in tool.__signature__.parameters + + tool = tool.bind_param("argA", 5) + + assert len(tool.__signature__.parameters) == 1 + assert "argA" not in tool.__signature__.parameters + + res = await tool(True) + assert "argA" in res + + @pytest.mark.asyncio + async def test_bind_callable_param_success(self, tool_name, client): + """Tests 'bind_param' with a bound parameter specified.""" + tool = await client.load_tool(tool_name) + + assert len(tool.__signature__.parameters) == 2 + assert "argA" in tool.__signature__.parameters + + tool = tool.bind_param("argA", lambda: 5) + + assert len(tool.__signature__.parameters) == 1 + assert "argA" not in tool.__signature__.parameters + + res = await tool(True) + assert "argA" in res + + @pytest.mark.asyncio + async def test_bind_param_fail(self, tool_name, client): + """Tests 'bind_param' with a bound parameter that doesn't exist.""" + tool = await client.load_tool(tool_name) + + assert len(tool.__signature__.parameters) == 2 + assert "argA" in tool.__signature__.parameters + + with pytest.raises(Exception) as e: + tool.bind_param("argC", lambda: 5) + assert "unable to bind parameters: no parameter named argC" in str(e.value) + + @pytest.mark.asyncio + async def test_rebind_param_fail(self, tool_name, client): + """ + Tests that 'bind_param' fails when attempting to re-bind a + parameter that has already been bound. + """ + tool = await client.load_tool(tool_name) + + assert len(tool.__signature__.parameters) == 2 + assert "argA" in tool.__signature__.parameters + + tool_with_bound_param = tool.bind_param("argA", lambda: 10) + + assert len(tool_with_bound_param.__signature__.parameters) == 1 + assert "argA" not in tool_with_bound_param.__signature__.parameters + + with pytest.raises(ValueError) as e: + tool_with_bound_param.bind_param("argA", lambda: 20) + + assert "cannot re-bind parameter: parameter 'argA' is already bound" in str( + e.value + ) + + @pytest.mark.asyncio + async def test_bind_param_static_value_success(self, tool_name, client): + """ + Tests bind_param method with a static value. + """ + + bound_value = "Test value" + + tool = await client.load_tool(tool_name) + bound_tool = tool.bind_param("argB", bound_value) + + assert bound_tool is not tool + assert "argB" not in bound_tool.__signature__.parameters + assert "argA" in bound_tool.__signature__.parameters + + passed_value_a = 42 + res_payload = await bound_tool(argA=passed_value_a) + + assert res_payload == {"argA": passed_value_a, "argB": bound_value} + + @pytest.mark.asyncio + async def test_bind_param_sync_callable_value_success(self, tool_name, client): + """ + Tests bind_param method with a sync callable value. + """ + + bound_value_result = True + bound_sync_callable = Mock(return_value=bound_value_result) + + tool = await client.load_tool(tool_name) + bound_tool = tool.bind_param("argB", bound_sync_callable) + + assert bound_tool is not tool + assert "argB" not in bound_tool.__signature__.parameters + assert "argA" in bound_tool.__signature__.parameters + + passed_value_a = 42 + res_payload = await bound_tool(argA=passed_value_a) + + assert res_payload == {"argA": passed_value_a, "argB": bound_value_result} + bound_sync_callable.assert_called_once() + + @pytest.mark.asyncio + async def test_bind_param_async_callable_value_success(self, tool_name, client): + """ + Tests bind_param method with an async callable value. + """ + + bound_value_result = True + bound_async_callable = AsyncMock(return_value=bound_value_result) + + tool = await client.load_tool(tool_name) + bound_tool = tool.bind_param("argB", bound_async_callable) + + assert bound_tool is not tool + assert "argB" not in bound_tool.__signature__.parameters + assert "argA" in bound_tool.__signature__.parameters + + passed_value_a = 42 + res_payload = await bound_tool(argA=passed_value_a) + + assert res_payload == {"argA": passed_value_a, "argB": bound_value_result} + bound_async_callable.assert_awaited_once() + class TestUnusedParameterValidation: """