Skip to content

Commit 14bc4ee

Browse files
authored
chore: Add unit tests for the tool, client and protocol files (#170)
* dep: Add pytest-cov package as a test dependency. * chore: Remove unused imports from sync_client.py * chore: Add unit tests for the tool and client classes * chore: Delint * chore: Delint * chore: Cover tool not found case * chore: Add toolbox tool unit test cases * chore: Add additional test cases to cover tool invocation and better docstring validation. * chore: Add test cases for sync and static bound parameter. * chore: Reorder tests in matching classes. This will improve maintainability. * feat: Add support for async token getters to ToolboxTool (#147) * feat: Add support for async token getters to ToolboxTool * chore: Improve variable names and docstring for more clarity * chore: Improve docstring * chore: Add unit test cases * chore: Add e2e test case * chore: Fix e2e test case
1 parent bddc8b5 commit 14bc4ee

File tree

8 files changed

+597
-15
lines changed

8 files changed

+597
-15
lines changed

packages/toolbox-core/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ test = [
4646
"pytest==8.3.5",
4747
"pytest-aioresponses==0.3.0",
4848
"pytest-asyncio==0.26.0",
49+
"pytest-cov==6.1.0",
4950
"google-cloud-secret-manager==2.23.2",
5051
"google-cloud-storage==3.1.0",
5152
]

packages/toolbox-core/src/toolbox_core/client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
15+
1416
import types
1517
from typing import Any, Callable, Mapping, Optional, Union
1618

packages/toolbox-core/src/toolbox_core/sync_client.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,7 @@
1414

1515
import asyncio
1616
from threading import Thread
17-
from typing import Any, Awaitable, Callable, Mapping, Optional, TypeVar, Union
18-
19-
from aiohttp import ClientSession
17+
from typing import Any, Callable, Mapping, Optional, TypeVar, Union
2018

2119
from .client import ToolboxClient
2220
from .sync_tool import ToolboxSyncTool

packages/toolbox-core/src/toolbox_core/tool.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from inspect import Signature
1919
from typing import (
2020
Any,
21+
Awaitable,
2122
Callable,
2223
Iterable,
2324
Mapping,
@@ -181,16 +182,12 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:
181182

182183
# apply bounded parameters
183184
for param, value in self.__bound_parameters.items():
184-
if asyncio.iscoroutinefunction(value):
185-
value = await value()
186-
elif callable(value):
187-
value = value()
188-
payload[param] = value
185+
payload[param] = await resolve_value(value)
189186

190187
# create headers for auth services
191188
headers = {}
192189
for auth_service, token_getter in self.__auth_service_token_getters.items():
193-
headers[f"{auth_service}_token"] = token_getter()
190+
headers[f"{auth_service}_token"] = await resolve_value(token_getter)
194191

195192
async with self.__session.post(
196193
self.__url,
@@ -330,3 +327,28 @@ def params_to_pydantic_model(
330327
),
331328
)
332329
return create_model(tool_name, **field_definitions)
330+
331+
332+
async def resolve_value(
333+
source: Union[Callable[[], Awaitable[Any]], Callable[[], Any], Any],
334+
) -> Any:
335+
"""
336+
Asynchronously or synchronously resolves a given source to its value.
337+
338+
If the `source` is a coroutine function, it will be awaited.
339+
If the `source` is a regular callable, it will be called.
340+
Otherwise (if it's not a callable), the `source` itself is returned directly.
341+
342+
Args:
343+
source: The value, a callable returning a value, or a callable
344+
returning an awaitable value.
345+
346+
Returns:
347+
The resolved value.
348+
"""
349+
350+
if asyncio.iscoroutinefunction(source):
351+
return await source()
352+
elif callable(source):
353+
return source()
354+
return source

packages/toolbox-core/tests/test_client.py

Lines changed: 156 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import inspect
1717
import json
18+
from unittest.mock import AsyncMock, Mock
1819

1920
import pytest
2021
import pytest_asyncio
@@ -130,6 +131,60 @@ async def test_load_toolset_success(aioresponses, test_tool_str, test_tool_int_b
130131
assert {t.__name__ for t in tools} == manifest.tools.keys()
131132

132133

134+
@pytest.mark.asyncio
135+
async def test_invoke_tool_server_error(aioresponses, test_tool_str):
136+
"""Tests that invoking a tool raises an Exception when the server returns an
137+
error status."""
138+
TOOL_NAME = "server_error_tool"
139+
ERROR_MESSAGE = "Simulated Server Error"
140+
manifest = ManifestSchema(serverVersion="0.0.0", tools={TOOL_NAME: test_tool_str})
141+
142+
aioresponses.get(
143+
f"{TEST_BASE_URL}/api/tool/{TOOL_NAME}",
144+
payload=manifest.model_dump(),
145+
status=200,
146+
)
147+
aioresponses.post(
148+
f"{TEST_BASE_URL}/api/tool/{TOOL_NAME}/invoke",
149+
payload={"error": ERROR_MESSAGE},
150+
status=500,
151+
)
152+
153+
async with ToolboxClient(TEST_BASE_URL) as client:
154+
loaded_tool = await client.load_tool(TOOL_NAME)
155+
156+
with pytest.raises(Exception, match=ERROR_MESSAGE):
157+
await loaded_tool(param1="some input")
158+
159+
160+
@pytest.mark.asyncio
161+
async def test_load_tool_not_found_in_manifest(aioresponses, test_tool_str):
162+
"""
163+
Tests that load_tool raises an Exception when the requested tool name
164+
is not found in the manifest returned by the server, using existing fixtures.
165+
"""
166+
ACTUAL_TOOL_IN_MANIFEST = "actual_tool_abc"
167+
REQUESTED_TOOL_NAME = "non_existent_tool_xyz"
168+
169+
manifest = ManifestSchema(
170+
serverVersion="0.0.0", tools={ACTUAL_TOOL_IN_MANIFEST: test_tool_str}
171+
)
172+
173+
aioresponses.get(
174+
f"{TEST_BASE_URL}/api/tool/{REQUESTED_TOOL_NAME}",
175+
payload=manifest.model_dump(),
176+
status=200,
177+
)
178+
179+
async with ToolboxClient(TEST_BASE_URL) as client:
180+
with pytest.raises(Exception, match=f"Tool '{REQUESTED_TOOL_NAME}' not found!"):
181+
await client.load_tool(REQUESTED_TOOL_NAME)
182+
183+
aioresponses.assert_called_once_with(
184+
f"{TEST_BASE_URL}/api/tool/{REQUESTED_TOOL_NAME}", method="GET"
185+
)
186+
187+
133188
class TestAuth:
134189

135190
@pytest.fixture
@@ -182,7 +237,7 @@ def token_handler():
182237
tool = await client.load_tool(
183238
tool_name, auth_token_getters={"my-auth-service": token_handler}
184239
)
185-
res = await tool(5)
240+
await tool(5)
186241

187242
@pytest.mark.asyncio
188243
async def test_auth_with_add_token_success(
@@ -195,20 +250,35 @@ def token_handler():
195250

196251
tool = await client.load_tool(tool_name)
197252
tool = tool.add_auth_token_getters({"my-auth-service": token_handler})
198-
res = await tool(5)
253+
await tool(5)
199254

200255
@pytest.mark.asyncio
201256
async def test_auth_with_load_tool_fail_no_token(
202257
self, tool_name, expected_header, client
203258
):
204259
"""Tests 'load_tool' with auth token is specified."""
205260

206-
def token_handler():
207-
return expected_header
208-
209261
tool = await client.load_tool(tool_name)
210262
with pytest.raises(Exception):
211-
res = await tool(5)
263+
await tool(5)
264+
265+
@pytest.mark.asyncio
266+
async def test_add_auth_token_getters_duplicate_fail(self, tool_name, client):
267+
"""
268+
Tests that adding a duplicate auth token getter raises ValueError.
269+
"""
270+
AUTH_SERVICE = "my-auth-service"
271+
272+
tool = await client.load_tool(tool_name)
273+
274+
authed_tool = tool.add_auth_token_getters({AUTH_SERVICE: {}})
275+
assert AUTH_SERVICE in authed_tool._ToolboxTool__auth_service_token_getters
276+
277+
with pytest.raises(
278+
ValueError,
279+
match=f"Authentication source\\(s\\) `{AUTH_SERVICE}` already registered in tool `{tool_name}`.",
280+
):
281+
authed_tool.add_auth_token_getters({AUTH_SERVICE: {}})
212282

213283

214284
class TestBoundParameter:
@@ -283,6 +353,22 @@ async def test_bind_param_success(self, tool_name, client):
283353
assert len(tool.__signature__.parameters) == 2
284354
assert "argA" in tool.__signature__.parameters
285355

356+
tool = tool.bind_parameters({"argA": 5})
357+
358+
assert len(tool.__signature__.parameters) == 1
359+
assert "argA" not in tool.__signature__.parameters
360+
361+
res = await tool(True)
362+
assert "argA" in res
363+
364+
@pytest.mark.asyncio
365+
async def test_bind_callable_param_success(self, tool_name, client):
366+
"""Tests 'bind_param' with a bound parameter specified."""
367+
tool = await client.load_tool(tool_name)
368+
369+
assert len(tool.__signature__.parameters) == 2
370+
assert "argA" in tool.__signature__.parameters
371+
286372
tool = tool.bind_parameters({"argA": lambda: 5})
287373

288374
assert len(tool.__signature__.parameters) == 1
@@ -301,3 +387,67 @@ async def test_bind_param_fail(self, tool_name, client):
301387

302388
with pytest.raises(Exception):
303389
tool = tool.bind_parameters({"argC": lambda: 5})
390+
391+
@pytest.mark.asyncio
392+
async def test_bind_param_static_value_success(self, tool_name, client):
393+
"""
394+
Tests bind_parameters method with a static value.
395+
"""
396+
397+
bound_value = "Test value"
398+
399+
tool = await client.load_tool(tool_name)
400+
bound_tool = tool.bind_parameters({"argB": bound_value})
401+
402+
assert bound_tool is not tool
403+
assert "argB" not in bound_tool.__signature__.parameters
404+
assert "argA" in bound_tool.__signature__.parameters
405+
406+
passed_value_a = 42
407+
res_payload = await bound_tool(argA=passed_value_a)
408+
409+
assert res_payload == {"argA": passed_value_a, "argB": bound_value}
410+
411+
@pytest.mark.asyncio
412+
async def test_bind_param_sync_callable_value_success(self, tool_name, client):
413+
"""
414+
Tests bind_parameters method with a sync callable value.
415+
"""
416+
417+
bound_value_result = True
418+
bound_sync_callable = Mock(return_value=bound_value_result)
419+
420+
tool = await client.load_tool(tool_name)
421+
bound_tool = tool.bind_parameters({"argB": bound_sync_callable})
422+
423+
assert bound_tool is not tool
424+
assert "argB" not in bound_tool.__signature__.parameters
425+
assert "argA" in bound_tool.__signature__.parameters
426+
427+
passed_value_a = 42
428+
res_payload = await bound_tool(argA=passed_value_a)
429+
430+
assert res_payload == {"argA": passed_value_a, "argB": bound_value_result}
431+
bound_sync_callable.assert_called_once()
432+
433+
@pytest.mark.asyncio
434+
async def test_bind_param_async_callable_value_success(self, tool_name, client):
435+
"""
436+
Tests bind_parameters method with an async callable value.
437+
"""
438+
439+
bound_value_result = True
440+
bound_async_callable = AsyncMock(return_value=bound_value_result)
441+
442+
tool = await client.load_tool(tool_name)
443+
bound_tool = tool.bind_parameters({"argB": bound_async_callable})
444+
445+
assert bound_tool is not tool
446+
assert "argB" not in bound_tool.__signature__.parameters
447+
assert "argA" in bound_tool.__signature__.parameters
448+
449+
passed_value_a = 42
450+
res_payload = await bound_tool(argA=passed_value_a)
451+
452+
assert res_payload == {"argA": passed_value_a, "argB": bound_value_result}
453+
bound_async_callable.assert_awaited_once()

packages/toolbox-core/tests/test_e2e.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,20 @@ async def test_run_tool_auth(self, toolbox: ToolboxClient, auth_token1: str):
166166
response = await auth_tool(id="2")
167167
assert "row2" in response
168168

169+
@pytest.mark.asyncio
170+
async def test_run_tool_async_auth(self, toolbox: ToolboxClient, auth_token1: str):
171+
"""Tests running a tool with correct auth using an async token getter."""
172+
tool = await toolbox.load_tool("get-row-by-id-auth")
173+
174+
async def get_token_asynchronously():
175+
return auth_token1
176+
177+
auth_tool = tool.add_auth_token_getters(
178+
{"my-test-auth": get_token_asynchronously}
179+
)
180+
response = await auth_tool(id="2")
181+
assert "row2" in response
182+
169183
async def test_run_tool_param_auth_no_auth(self, toolbox: ToolboxClient):
170184
"""Tests running a tool with a param requiring auth, without auth."""
171185
tool = await toolbox.load_tool("get-row-by-email-auth")

0 commit comments

Comments
 (0)