From ebd317bd8dd8c0f78116a6695448b7c121321215 Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Fri, 2 May 2025 22:08:27 +0530 Subject: [PATCH 1/4] chore: Add unit test coverage. --- packages/toolbox-core/tests/test_client.py | 132 +++++++++++++++++++++ 1 file changed, 132 insertions(+) diff --git a/packages/toolbox-core/tests/test_client.py b/packages/toolbox-core/tests/test_client.py index d510d73f..c1ed1393 100644 --- a/packages/toolbox-core/tests/test_client.py +++ b/packages/toolbox-core/tests/test_client.py @@ -697,6 +697,138 @@ async def test_bind_param_async_callable_value_success(self, tool_name, client): bound_async_callable.assert_awaited_once() + @pytest.mark.asyncio + async def test_bind_param_success(self, tool_name, client): + """Tests 'bind_param' with a bound parameter specified.""" + tool = await client.load_tool(tool_name) + + assert len(tool.__signature__.parameters) == 2 + assert "argA" in tool.__signature__.parameters + + tool = tool.bind_param("argA", 5) + + assert len(tool.__signature__.parameters) == 1 + assert "argA" not in tool.__signature__.parameters + + res = await tool(True) + assert "argA" in res + + @pytest.mark.asyncio + async def test_bind_callable_param_success(self, tool_name, client): + """Tests 'bind_param' with a bound parameter specified.""" + tool = await client.load_tool(tool_name) + + assert len(tool.__signature__.parameters) == 2 + assert "argA" in tool.__signature__.parameters + + tool = tool.bind_param("argA", lambda: 5) + + assert len(tool.__signature__.parameters) == 1 + assert "argA" not in tool.__signature__.parameters + + res = await tool(True) + assert "argA" in res + + @pytest.mark.asyncio + async def test_bind_param_fail(self, tool_name, client): + """Tests 'bind_param' with a bound parameter that doesn't exist.""" + tool = await client.load_tool(tool_name) + + assert len(tool.__signature__.parameters) == 2 + assert "argA" in tool.__signature__.parameters + + with pytest.raises(Exception) as e: + tool.bind_param("argC", lambda: 5) + assert "unable to bind parameters: no parameter named argC" in str(e.value) + + @pytest.mark.asyncio + async def test_rebind_param_fail(self, tool_name, client): + """ + Tests that 'bind_param' fails when attempting to re-bind a + parameter that has already been bound. + """ + tool = await client.load_tool(tool_name) + + assert len(tool.__signature__.parameters) == 2 + assert "argA" in tool.__signature__.parameters + + tool_with_bound_param = tool.bind_param("argA", lambda: 10) + + assert len(tool_with_bound_param.__signature__.parameters) == 1 + assert "argA" not in tool_with_bound_param.__signature__.parameters + + with pytest.raises(ValueError) as e: + tool_with_bound_param.bind_param("argA", lambda: 20) + + assert "cannot re-bind parameter: parameter 'argA' is already bound" in str( + e.value + ) + + @pytest.mark.asyncio + async def test_bind_param_static_value_success(self, tool_name, client): + """ + Tests bind_param method with a static value. + """ + + bound_value = "Test value" + + tool = await client.load_tool(tool_name) + bound_tool = tool.bind_param("argB", bound_value) + + assert bound_tool is not tool + assert "argB" not in bound_tool.__signature__.parameters + assert "argA" in bound_tool.__signature__.parameters + + passed_value_a = 42 + res_payload = await bound_tool(argA=passed_value_a) + + assert res_payload == {"argA": passed_value_a, "argB": bound_value} + + @pytest.mark.asyncio + async def test_bind_param_sync_callable_value_success(self, tool_name, client): + """ + Tests bind_param method with a sync callable value. + """ + + bound_value_result = True + bound_sync_callable = Mock(return_value=bound_value_result) + + tool = await client.load_tool(tool_name) + bound_tool = tool.bind_param("argB", bound_sync_callable) + + assert bound_tool is not tool + assert "argB" not in bound_tool.__signature__.parameters + assert "argA" in bound_tool.__signature__.parameters + + passed_value_a = 42 + res_payload = await bound_tool(argA=passed_value_a) + + assert res_payload == {"argA": passed_value_a, "argB": bound_value_result} + bound_sync_callable.assert_called_once() + + @pytest.mark.asyncio + async def test_bind_param_async_callable_value_success(self, tool_name, client): + """ + Tests bind_param method with an async callable value. + """ + + bound_value_result = True + bound_async_callable = AsyncMock(return_value=bound_value_result) + + tool = await client.load_tool(tool_name) + bound_tool = tool.bind_param("argB", bound_async_callable) + + assert bound_tool is not tool + assert "argB" not in bound_tool.__signature__.parameters + assert "argA" in bound_tool.__signature__.parameters + + passed_value_a = 42 + res_payload = await bound_tool(argA=passed_value_a) + + assert res_payload == {"argA": passed_value_a, "argB": bound_value_result} + bound_async_callable.assert_awaited_once() + + class TestUnusedParameterValidation: """ Tests for validation errors related to unused auth tokens or bound From 7ff635548878fc46f4993cb18895fa0fdcfd9004 Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Mon, 5 May 2025 13:35:59 +0530 Subject: [PATCH 2/4] chore: Consolidate auth header creation logic Post adding the feature of adding client-level auth headers (#178), we have the logic for creating an auth header, from the given auth token getter name, in 3 different places. This PR unifies all of that logic into a single helper to improve maintenance, and make it easier to change the way we add suffix/prefix, and reduces WET code. --- packages/toolbox-core/src/toolbox_core/tool.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/packages/toolbox-core/src/toolbox_core/tool.py b/packages/toolbox-core/src/toolbox_core/tool.py index 19b6d710..e401beed 100644 --- a/packages/toolbox-core/src/toolbox_core/tool.py +++ b/packages/toolbox-core/src/toolbox_core/tool.py @@ -94,7 +94,7 @@ def __init__( # Validate conflicting Headers/Auth Tokens request_header_names = client_headers.keys() auth_token_names = [ - auth_token_name + "_token" + self.__get_auth_header(auth_token_name) for auth_token_name in auth_service_token_getters.keys() ] duplicates = request_header_names & auth_token_names @@ -187,6 +187,11 @@ def __copy( client_headers=check(client_headers, self.__client_headers), ) + def __get_auth_header(self, auth_token_name: str) -> str: + """Returns the formatted auth token header name.""" + return f"{auth_token_name}_token" + + async def __call__(self, *args: Any, **kwargs: Any) -> str: """ Asynchronously calls the remote tool with the provided arguments. @@ -228,7 +233,7 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str: # create headers for auth services headers = {} for auth_service, token_getter in self.__auth_service_token_getters.items(): - headers[f"{auth_service}_token"] = await resolve_value(token_getter) + headers[self.__get_auth_header(auth_service)] = await resolve_value(token_getter) for client_header_name, client_header_val in self.__client_headers.items(): headers[client_header_name] = await resolve_value(client_header_val) @@ -276,7 +281,8 @@ def add_auth_token_getters( # Validate duplicates with client headers request_header_names = self.__client_headers.keys() auth_token_names = [ - auth_token_name + "_token" for auth_token_name in incoming_services + self.__get_auth_header(auth_token_name) + for auth_token_name in incoming_services ] duplicates = request_header_names & auth_token_names if duplicates: From e6396b68cb9b6502485a404c6b8063b46033caae Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Mon, 5 May 2025 13:36:31 +0530 Subject: [PATCH 3/4] chore: Delint --- packages/toolbox-core/src/toolbox_core/tool.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/packages/toolbox-core/src/toolbox_core/tool.py b/packages/toolbox-core/src/toolbox_core/tool.py index e401beed..bdddd09a 100644 --- a/packages/toolbox-core/src/toolbox_core/tool.py +++ b/packages/toolbox-core/src/toolbox_core/tool.py @@ -191,7 +191,6 @@ def __get_auth_header(self, auth_token_name: str) -> str: """Returns the formatted auth token header name.""" return f"{auth_token_name}_token" - async def __call__(self, *args: Any, **kwargs: Any) -> str: """ Asynchronously calls the remote tool with the provided arguments. @@ -233,7 +232,9 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str: # create headers for auth services headers = {} for auth_service, token_getter in self.__auth_service_token_getters.items(): - headers[self.__get_auth_header(auth_service)] = await resolve_value(token_getter) + headers[self.__get_auth_header(auth_service)] = await resolve_value( + token_getter + ) for client_header_name, client_header_val in self.__client_headers.items(): headers[client_header_name] = await resolve_value(client_header_val) From d89dc407fc38ac693cc9b29a2857eec5eca57f54 Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Mon, 12 May 2025 18:53:52 +0530 Subject: [PATCH 4/4] chore: Delint --- packages/toolbox-core/tests/test_client.py | 1 - 1 file changed, 1 deletion(-) diff --git a/packages/toolbox-core/tests/test_client.py b/packages/toolbox-core/tests/test_client.py index c1ed1393..82066bca 100644 --- a/packages/toolbox-core/tests/test_client.py +++ b/packages/toolbox-core/tests/test_client.py @@ -696,7 +696,6 @@ async def test_bind_param_async_callable_value_success(self, tool_name, client): assert res_payload == {"argA": passed_value_a, "argB": bound_value_result} bound_async_callable.assert_awaited_once() - @pytest.mark.asyncio async def test_bind_param_success(self, tool_name, client): """Tests 'bind_param' with a bound parameter specified."""