Skip to content

chore: Add unit tests for the tool, client and protocol files #170

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Apr 16, 2025
Merged
1 change: 1 addition & 0 deletions packages/toolbox-core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ test = [
"pytest==8.3.5",
"pytest-aioresponses==0.3.0",
"pytest-asyncio==0.26.0",
"pytest-cov==6.1.0",
"google-cloud-secret-manager==2.23.2",
"google-cloud-storage==3.1.0",
]
Expand Down
2 changes: 2 additions & 0 deletions packages/toolbox-core/src/toolbox_core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import types
from typing import Any, Callable, Mapping, Optional, Union

Expand Down
4 changes: 1 addition & 3 deletions packages/toolbox-core/src/toolbox_core/sync_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@

import asyncio
from threading import Thread
from typing import Any, Awaitable, Callable, Mapping, Optional, TypeVar, Union

from aiohttp import ClientSession
from typing import Any, Callable, Mapping, Optional, TypeVar, Union

from .client import ToolboxClient
from .sync_tool import ToolboxSyncTool
Expand Down
162 changes: 156 additions & 6 deletions packages/toolbox-core/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import inspect
import json
from unittest.mock import AsyncMock, Mock

import pytest
import pytest_asyncio
Expand Down Expand Up @@ -130,6 +131,60 @@ async def test_load_toolset_success(aioresponses, test_tool_str, test_tool_int_b
assert {t.__name__ for t in tools} == manifest.tools.keys()


@pytest.mark.asyncio
async def test_invoke_tool_server_error(aioresponses, test_tool_str):
"""Tests that invoking a tool raises an Exception when the server returns an
error status."""
TOOL_NAME = "server_error_tool"
ERROR_MESSAGE = "Simulated Server Error"
manifest = ManifestSchema(serverVersion="0.0.0", tools={TOOL_NAME: test_tool_str})

aioresponses.get(
f"{TEST_BASE_URL}/api/tool/{TOOL_NAME}",
payload=manifest.model_dump(),
status=200,
)
aioresponses.post(
f"{TEST_BASE_URL}/api/tool/{TOOL_NAME}/invoke",
payload={"error": ERROR_MESSAGE},
status=500,
)

async with ToolboxClient(TEST_BASE_URL) as client:
loaded_tool = await client.load_tool(TOOL_NAME)

with pytest.raises(Exception, match=ERROR_MESSAGE):
await loaded_tool(param1="some input")


@pytest.mark.asyncio
async def test_load_tool_not_found_in_manifest(aioresponses, test_tool_str):
"""
Tests that load_tool raises an Exception when the requested tool name
is not found in the manifest returned by the server, using existing fixtures.
"""
ACTUAL_TOOL_IN_MANIFEST = "actual_tool_abc"
REQUESTED_TOOL_NAME = "non_existent_tool_xyz"

manifest = ManifestSchema(
serverVersion="0.0.0", tools={ACTUAL_TOOL_IN_MANIFEST: test_tool_str}
)

aioresponses.get(
f"{TEST_BASE_URL}/api/tool/{REQUESTED_TOOL_NAME}",
payload=manifest.model_dump(),
status=200,
)

async with ToolboxClient(TEST_BASE_URL) as client:
with pytest.raises(Exception, match=f"Tool '{REQUESTED_TOOL_NAME}' not found!"):
await client.load_tool(REQUESTED_TOOL_NAME)

aioresponses.assert_called_once_with(
f"{TEST_BASE_URL}/api/tool/{REQUESTED_TOOL_NAME}", method="GET"
)


class TestAuth:

@pytest.fixture
Expand Down Expand Up @@ -182,7 +237,7 @@ def token_handler():
tool = await client.load_tool(
tool_name, auth_token_getters={"my-auth-service": token_handler}
)
res = await tool(5)
await tool(5)

@pytest.mark.asyncio
async def test_auth_with_add_token_success(
Expand All @@ -195,20 +250,35 @@ def token_handler():

tool = await client.load_tool(tool_name)
tool = tool.add_auth_token_getters({"my-auth-service": token_handler})
res = await tool(5)
await tool(5)

@pytest.mark.asyncio
async def test_auth_with_load_tool_fail_no_token(
self, tool_name, expected_header, client
):
"""Tests 'load_tool' with auth token is specified."""

def token_handler():
return expected_header

tool = await client.load_tool(tool_name)
with pytest.raises(Exception):
res = await tool(5)
await tool(5)

@pytest.mark.asyncio
async def test_add_auth_token_getters_duplicate_fail(self, tool_name, client):
"""
Tests that adding a duplicate auth token getter raises ValueError.
"""
AUTH_SERVICE = "my-auth-service"

tool = await client.load_tool(tool_name)

authed_tool = tool.add_auth_token_getters({AUTH_SERVICE: {}})
assert AUTH_SERVICE in authed_tool._ToolboxTool__auth_service_token_getters

with pytest.raises(
ValueError,
match=f"Authentication source\\(s\\) `{AUTH_SERVICE}` already registered in tool `{tool_name}`.",
):
authed_tool.add_auth_token_getters({AUTH_SERVICE: {}})


class TestBoundParameter:
Expand Down Expand Up @@ -283,6 +353,22 @@ async def test_bind_param_success(self, tool_name, client):
assert len(tool.__signature__.parameters) == 2
assert "argA" in tool.__signature__.parameters

tool = tool.bind_parameters({"argA": 5})

assert len(tool.__signature__.parameters) == 1
assert "argA" not in tool.__signature__.parameters

res = await tool(True)
assert "argA" in res

@pytest.mark.asyncio
async def test_bind_callable_param_success(self, tool_name, client):
"""Tests 'bind_param' with a bound parameter specified."""
tool = await client.load_tool(tool_name)

assert len(tool.__signature__.parameters) == 2
assert "argA" in tool.__signature__.parameters

tool = tool.bind_parameters({"argA": lambda: 5})

assert len(tool.__signature__.parameters) == 1
Expand All @@ -301,3 +387,67 @@ async def test_bind_param_fail(self, tool_name, client):

with pytest.raises(Exception):
tool = tool.bind_parameters({"argC": lambda: 5})

@pytest.mark.asyncio
async def test_bind_param_static_value_success(self, tool_name, client):
"""
Tests bind_parameters method with a static value.
"""

bound_value = "Test value"

tool = await client.load_tool(tool_name)
bound_tool = tool.bind_parameters({"argB": bound_value})

assert bound_tool is not tool
assert "argB" not in bound_tool.__signature__.parameters
assert "argA" in bound_tool.__signature__.parameters

passed_value_a = 42
res_payload = await bound_tool(argA=passed_value_a)

assert res_payload == {"argA": passed_value_a, "argB": bound_value}

@pytest.mark.asyncio
async def test_bind_param_sync_callable_value_success(self, tool_name, client):
"""
Tests bind_parameters method with a sync callable value.
"""

bound_value_result = True
bound_sync_callable = Mock(return_value=bound_value_result)

tool = await client.load_tool(tool_name)
bound_tool = tool.bind_parameters({"argB": bound_sync_callable})

assert bound_tool is not tool
assert "argB" not in bound_tool.__signature__.parameters
assert "argA" in bound_tool.__signature__.parameters

passed_value_a = 42
res_payload = await bound_tool(argA=passed_value_a)

assert res_payload == {"argA": passed_value_a, "argB": bound_value_result}
bound_sync_callable.assert_called_once()

@pytest.mark.asyncio
async def test_bind_param_async_callable_value_success(self, tool_name, client):
"""
Tests bind_parameters method with an async callable value.
"""

bound_value_result = True
bound_async_callable = AsyncMock(return_value=bound_value_result)

tool = await client.load_tool(tool_name)
bound_tool = tool.bind_parameters({"argB": bound_async_callable})

assert bound_tool is not tool
assert "argB" not in bound_tool.__signature__.parameters
assert "argA" in bound_tool.__signature__.parameters

passed_value_a = 42
res_payload = await bound_tool(argA=passed_value_a)

assert res_payload == {"argA": passed_value_a, "argB": bound_value_result}
bound_async_callable.assert_awaited_once()
108 changes: 108 additions & 0 deletions packages/toolbox-core/tests/test_protocol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from inspect import Parameter

import pytest

from toolbox_core.protocol import ParameterSchema


def test_parameter_schema_float():
"""Tests ParameterSchema with type 'float'."""
schema = ParameterSchema(name="price", type="float", description="The item price")
expected_type = float
assert schema._ParameterSchema__get_type() == expected_type

param = schema.to_param()
assert isinstance(param, Parameter)
assert param.name == "price"
assert param.annotation == expected_type
assert param.kind == Parameter.POSITIONAL_OR_KEYWORD
assert param.default == Parameter.empty


def test_parameter_schema_boolean():
"""Tests ParameterSchema with type 'boolean'."""
schema = ParameterSchema(
name="is_active", type="boolean", description="Activity status"
)
expected_type = bool
assert schema._ParameterSchema__get_type() == expected_type

param = schema.to_param()
assert isinstance(param, Parameter)
assert param.name == "is_active"
assert param.annotation == expected_type
assert param.kind == Parameter.POSITIONAL_OR_KEYWORD


def test_parameter_schema_array_string():
"""Tests ParameterSchema with type 'array' containing strings."""
item_schema = ParameterSchema(name="", type="string", description="")
schema = ParameterSchema(
name="tags", type="array", description="List of tags", items=item_schema
)

assert schema._ParameterSchema__get_type() == list[str]

param = schema.to_param()
assert isinstance(param, Parameter)
assert param.name == "tags"
assert param.annotation == list[str]
assert param.kind == Parameter.POSITIONAL_OR_KEYWORD


def test_parameter_schema_array_integer():
"""Tests ParameterSchema with type 'array' containing integers."""
item_schema = ParameterSchema(name="", type="integer", description="")
schema = ParameterSchema(
name="scores", type="array", description="List of scores", items=item_schema
)

param = schema.to_param()
assert isinstance(param, Parameter)
assert param.name == "scores"
assert param.annotation == list[int]
assert param.kind == Parameter.POSITIONAL_OR_KEYWORD


def test_parameter_schema_array_no_items_error():
"""Tests that 'array' type raises error if 'items' is None."""
schema = ParameterSchema(
name="bad_list", type="array", description="List without item type"
)

expected_error_msg = "Unexpected value: type is 'list' but items is None"
with pytest.raises(Exception, match=expected_error_msg):
schema._ParameterSchema__get_type()

with pytest.raises(Exception, match=expected_error_msg):
schema.to_param()


def test_parameter_schema_unsupported_type_error():
"""Tests that an unsupported type raises ValueError."""
unsupported_type = "datetime"
schema = ParameterSchema(
name="event_time", type=unsupported_type, description="When it happened"
)

expected_error_msg = f"Unsupported schema type: {unsupported_type}"
with pytest.raises(ValueError, match=expected_error_msg):
schema._ParameterSchema__get_type()

with pytest.raises(ValueError, match=expected_error_msg):
schema.to_param()
Loading