Skip to content

chore: Add unit tests for the tool, client and protocol files #170

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Apr 16, 2025
Merged
1 change: 1 addition & 0 deletions packages/toolbox-core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ test = [
"pytest==8.3.5",
"pytest-aioresponses==0.3.0",
"pytest-asyncio==0.26.0",
"pytest-cov==6.1.0",
"google-cloud-secret-manager==2.23.2",
"google-cloud-storage==3.1.0",
]
Expand Down
2 changes: 2 additions & 0 deletions packages/toolbox-core/src/toolbox_core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import types
from typing import Any, Callable, Mapping, Optional, Union

Expand Down
4 changes: 1 addition & 3 deletions packages/toolbox-core/src/toolbox_core/sync_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@

import asyncio
from threading import Thread
from typing import Any, Awaitable, Callable, Mapping, Optional, TypeVar, Union

from aiohttp import ClientSession
from typing import Any, Callable, Mapping, Optional, TypeVar, Union

from .client import ToolboxClient
from .sync_tool import ToolboxSyncTool
Expand Down
34 changes: 28 additions & 6 deletions packages/toolbox-core/src/toolbox_core/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from inspect import Signature
from typing import (
Any,
Awaitable,
Callable,
Iterable,
Mapping,
Expand Down Expand Up @@ -181,16 +182,12 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:

# apply bounded parameters
for param, value in self.__bound_parameters.items():
if asyncio.iscoroutinefunction(value):
value = await value()
elif callable(value):
value = value()
payload[param] = value
payload[param] = await resolve_value(value)

# create headers for auth services
headers = {}
for auth_service, token_getter in self.__auth_service_token_getters.items():
headers[f"{auth_service}_token"] = token_getter()
headers[f"{auth_service}_token"] = await resolve_value(token_getter)

async with self.__session.post(
self.__url,
Expand Down Expand Up @@ -330,3 +327,28 @@ def params_to_pydantic_model(
),
)
return create_model(tool_name, **field_definitions)


async def resolve_value(
source: Union[Callable[[], Awaitable[Any]], Callable[[], Any], Any],
) -> Any:
"""
Asynchronously or synchronously resolves a given source to its value.

If the `source` is a coroutine function, it will be awaited.
If the `source` is a regular callable, it will be called.
Otherwise (if it's not a callable), the `source` itself is returned directly.

Args:
source: The value, a callable returning a value, or a callable
returning an awaitable value.

Returns:
The resolved value.
"""

if asyncio.iscoroutinefunction(source):
return await source()
elif callable(source):
return source()
return source
162 changes: 156 additions & 6 deletions packages/toolbox-core/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import inspect
import json
from unittest.mock import AsyncMock, Mock

import pytest
import pytest_asyncio
Expand Down Expand Up @@ -130,6 +131,60 @@ async def test_load_toolset_success(aioresponses, test_tool_str, test_tool_int_b
assert {t.__name__ for t in tools} == manifest.tools.keys()


@pytest.mark.asyncio
async def test_invoke_tool_server_error(aioresponses, test_tool_str):
"""Tests that invoking a tool raises an Exception when the server returns an
error status."""
TOOL_NAME = "server_error_tool"
ERROR_MESSAGE = "Simulated Server Error"
manifest = ManifestSchema(serverVersion="0.0.0", tools={TOOL_NAME: test_tool_str})

aioresponses.get(
f"{TEST_BASE_URL}/api/tool/{TOOL_NAME}",
payload=manifest.model_dump(),
status=200,
)
aioresponses.post(
f"{TEST_BASE_URL}/api/tool/{TOOL_NAME}/invoke",
payload={"error": ERROR_MESSAGE},
status=500,
)

async with ToolboxClient(TEST_BASE_URL) as client:
loaded_tool = await client.load_tool(TOOL_NAME)

with pytest.raises(Exception, match=ERROR_MESSAGE):
await loaded_tool(param1="some input")


@pytest.mark.asyncio
async def test_load_tool_not_found_in_manifest(aioresponses, test_tool_str):
"""
Tests that load_tool raises an Exception when the requested tool name
is not found in the manifest returned by the server, using existing fixtures.
"""
ACTUAL_TOOL_IN_MANIFEST = "actual_tool_abc"
REQUESTED_TOOL_NAME = "non_existent_tool_xyz"

manifest = ManifestSchema(
serverVersion="0.0.0", tools={ACTUAL_TOOL_IN_MANIFEST: test_tool_str}
)

aioresponses.get(
f"{TEST_BASE_URL}/api/tool/{REQUESTED_TOOL_NAME}",
payload=manifest.model_dump(),
status=200,
)

async with ToolboxClient(TEST_BASE_URL) as client:
with pytest.raises(Exception, match=f"Tool '{REQUESTED_TOOL_NAME}' not found!"):
await client.load_tool(REQUESTED_TOOL_NAME)

aioresponses.assert_called_once_with(
f"{TEST_BASE_URL}/api/tool/{REQUESTED_TOOL_NAME}", method="GET"
)


class TestAuth:

@pytest.fixture
Expand Down Expand Up @@ -182,7 +237,7 @@ def token_handler():
tool = await client.load_tool(
tool_name, auth_token_getters={"my-auth-service": token_handler}
)
res = await tool(5)
await tool(5)

@pytest.mark.asyncio
async def test_auth_with_add_token_success(
Expand All @@ -195,20 +250,35 @@ def token_handler():

tool = await client.load_tool(tool_name)
tool = tool.add_auth_token_getters({"my-auth-service": token_handler})
res = await tool(5)
await tool(5)

@pytest.mark.asyncio
async def test_auth_with_load_tool_fail_no_token(
self, tool_name, expected_header, client
):
"""Tests 'load_tool' with auth token is specified."""

def token_handler():
return expected_header

tool = await client.load_tool(tool_name)
with pytest.raises(Exception):
res = await tool(5)
await tool(5)

@pytest.mark.asyncio
async def test_add_auth_token_getters_duplicate_fail(self, tool_name, client):
"""
Tests that adding a duplicate auth token getter raises ValueError.
"""
AUTH_SERVICE = "my-auth-service"

tool = await client.load_tool(tool_name)

authed_tool = tool.add_auth_token_getters({AUTH_SERVICE: {}})
assert AUTH_SERVICE in authed_tool._ToolboxTool__auth_service_token_getters

with pytest.raises(
ValueError,
match=f"Authentication source\\(s\\) `{AUTH_SERVICE}` already registered in tool `{tool_name}`.",
):
authed_tool.add_auth_token_getters({AUTH_SERVICE: {}})


class TestBoundParameter:
Expand Down Expand Up @@ -283,6 +353,22 @@ async def test_bind_param_success(self, tool_name, client):
assert len(tool.__signature__.parameters) == 2
assert "argA" in tool.__signature__.parameters

tool = tool.bind_parameters({"argA": 5})

assert len(tool.__signature__.parameters) == 1
assert "argA" not in tool.__signature__.parameters

res = await tool(True)
assert "argA" in res

@pytest.mark.asyncio
async def test_bind_callable_param_success(self, tool_name, client):
"""Tests 'bind_param' with a bound parameter specified."""
tool = await client.load_tool(tool_name)

assert len(tool.__signature__.parameters) == 2
assert "argA" in tool.__signature__.parameters

tool = tool.bind_parameters({"argA": lambda: 5})

assert len(tool.__signature__.parameters) == 1
Expand All @@ -301,3 +387,67 @@ async def test_bind_param_fail(self, tool_name, client):

with pytest.raises(Exception):
tool = tool.bind_parameters({"argC": lambda: 5})

@pytest.mark.asyncio
async def test_bind_param_static_value_success(self, tool_name, client):
"""
Tests bind_parameters method with a static value.
"""

bound_value = "Test value"

tool = await client.load_tool(tool_name)
bound_tool = tool.bind_parameters({"argB": bound_value})

assert bound_tool is not tool
assert "argB" not in bound_tool.__signature__.parameters
assert "argA" in bound_tool.__signature__.parameters

passed_value_a = 42
res_payload = await bound_tool(argA=passed_value_a)

assert res_payload == {"argA": passed_value_a, "argB": bound_value}

@pytest.mark.asyncio
async def test_bind_param_sync_callable_value_success(self, tool_name, client):
"""
Tests bind_parameters method with a sync callable value.
"""

bound_value_result = True
bound_sync_callable = Mock(return_value=bound_value_result)

tool = await client.load_tool(tool_name)
bound_tool = tool.bind_parameters({"argB": bound_sync_callable})

assert bound_tool is not tool
assert "argB" not in bound_tool.__signature__.parameters
assert "argA" in bound_tool.__signature__.parameters

passed_value_a = 42
res_payload = await bound_tool(argA=passed_value_a)

assert res_payload == {"argA": passed_value_a, "argB": bound_value_result}
bound_sync_callable.assert_called_once()

@pytest.mark.asyncio
async def test_bind_param_async_callable_value_success(self, tool_name, client):
"""
Tests bind_parameters method with an async callable value.
"""

bound_value_result = True
bound_async_callable = AsyncMock(return_value=bound_value_result)

tool = await client.load_tool(tool_name)
bound_tool = tool.bind_parameters({"argB": bound_async_callable})

assert bound_tool is not tool
assert "argB" not in bound_tool.__signature__.parameters
assert "argA" in bound_tool.__signature__.parameters

passed_value_a = 42
res_payload = await bound_tool(argA=passed_value_a)

assert res_payload == {"argA": passed_value_a, "argB": bound_value_result}
bound_async_callable.assert_awaited_once()
14 changes: 14 additions & 0 deletions packages/toolbox-core/tests/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,20 @@ async def test_run_tool_auth(self, toolbox: ToolboxClient, auth_token1: str):
response = await auth_tool(id="2")
assert "row2" in response

@pytest.mark.asyncio
async def test_run_tool_async_auth(self, toolbox: ToolboxClient, auth_token1: str):
"""Tests running a tool with correct auth using an async token getter."""
tool = await toolbox.load_tool("get-row-by-id-auth")

async def get_token_asynchronously():
return auth_token1

auth_tool = tool.add_auth_token_getters(
{"my-test-auth": get_token_asynchronously}
)
response = await auth_tool(id="2")
assert "row2" in response

async def test_run_tool_param_auth_no_auth(self, toolbox: ToolboxClient):
"""Tests running a tool with a param requiring auth, without auth."""
tool = await toolbox.load_tool("get-row-by-email-auth")
Expand Down
Loading