diff --git a/packages/toolbox-core/src/toolbox_core/tool.py b/packages/toolbox-core/src/toolbox_core/tool.py index 3436580a..3150be94 100644 --- a/packages/toolbox-core/src/toolbox_core/tool.py +++ b/packages/toolbox-core/src/toolbox_core/tool.py @@ -13,27 +13,28 @@ # limitations under the License. -import asyncio import types from inspect import Signature from typing import ( Any, - Awaitable, Callable, - Iterable, Mapping, Optional, Sequence, - Type, Union, - cast, ) from aiohttp import ClientSession -from pydantic import BaseModel, Field, create_model from toolbox_core.protocol import ParameterSchema +from .utils import ( + create_func_docstring, + identify_required_authn_params, + params_to_pydantic_model, + resolve_value, +) + class ToolboxTool: """ @@ -88,7 +89,7 @@ def __init__( # the following properties are set to help anyone that might inspect it determine usage self.__name__ = name - self.__doc__ = create_docstring(self.__description, self.__params) + self.__doc__ = create_func_docstring(self.__description, self.__params) self.__signature__ = Signature( parameters=inspect_type_params, return_annotation=str ) @@ -271,84 +272,3 @@ def bind_parameters( params=new_params, bound_params=types.MappingProxyType(all_bound_params), ) - - -def create_docstring(description: str, params: Sequence[ParameterSchema]) -> str: - """Convert tool description and params into its function docstring""" - docstring = description - if not params: - return docstring - docstring += "\n\nArgs:" - for p in params: - docstring += ( - f"\n {p.name} ({p.to_param().annotation.__name__}): {p.description}" - ) - return docstring - - -def identify_required_authn_params( - req_authn_params: Mapping[str, list[str]], auth_service_names: Iterable[str] -) -> dict[str, list[str]]: - """ - Identifies authentication parameters that are still required; because they - not covered by the provided `auth_service_names`. - - Args: - req_authn_params: A mapping of parameter names to sets of required - authentication services. - auth_service_names: An iterable of authentication service names for which - token getters are available. - - Returns: - A new dictionary representing the subset of required authentication parameters - that are not covered by the provided `auth_services`. - """ - required_params = {} # params that are still required with provided auth_services - for param, services in req_authn_params.items(): - # if we don't have a token_getter for any of the services required by the param, - # the param is still required - required = not any(s in services for s in auth_service_names) - if required: - required_params[param] = services - return required_params - - -def params_to_pydantic_model( - tool_name: str, params: Sequence[ParameterSchema] -) -> Type[BaseModel]: - """Converts the given parameters to a Pydantic BaseModel class.""" - field_definitions = {} - for field in params: - field_definitions[field.name] = cast( - Any, - ( - field.to_param().annotation, - Field(description=field.description), - ), - ) - return create_model(tool_name, **field_definitions) - - -async def resolve_value( - source: Union[Callable[[], Awaitable[Any]], Callable[[], Any], Any], -) -> Any: - """ - Asynchronously or synchronously resolves a given source to its value. - - If the `source` is a coroutine function, it will be awaited. - If the `source` is a regular callable, it will be called. - Otherwise (if it's not a callable), the `source` itself is returned directly. - - Args: - source: The value, a callable returning a value, or a callable - returning an awaitable value. - - Returns: - The resolved value. - """ - - if asyncio.iscoroutinefunction(source): - return await source() - elif callable(source): - return source() - return source diff --git a/packages/toolbox-core/src/toolbox_core/utils.py b/packages/toolbox-core/src/toolbox_core/utils.py new file mode 100644 index 00000000..4c4ec5a7 --- /dev/null +++ b/packages/toolbox-core/src/toolbox_core/utils.py @@ -0,0 +1,112 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import asyncio +from typing import ( + Any, + Awaitable, + Callable, + Iterable, + Mapping, + Sequence, + Type, + Union, + cast, +) + +from pydantic import BaseModel, Field, create_model + +from toolbox_core.protocol import ParameterSchema + + +def create_func_docstring(description: str, params: Sequence[ParameterSchema]) -> str: + """Convert tool description and params into its function docstring""" + docstring = description + if not params: + return docstring + docstring += "\n\nArgs:" + for p in params: + docstring += ( + f"\n {p.name} ({p.to_param().annotation.__name__}): {p.description}" + ) + return docstring + + +def identify_required_authn_params( + req_authn_params: Mapping[str, list[str]], auth_service_names: Iterable[str] +) -> dict[str, list[str]]: + """ + Identifies authentication parameters that are still required; because they + are not covered by the provided `auth_service_names`. + + Args: + req_authn_params: A mapping of parameter names to sets of required + authentication services. + auth_service_names: An iterable of authentication service names for which + token getters are available. + + Returns: + A new dictionary representing the subset of required authentication parameters + that are not covered by the provided `auth_services`. + """ + required_params = {} # params that are still required with provided auth_services + for param, services in req_authn_params.items(): + # if we don't have a token_getter for any of the services required by the param, + # the param is still required + required = not any(s in services for s in auth_service_names) + if required: + required_params[param] = services + return required_params + + +def params_to_pydantic_model( + tool_name: str, params: Sequence[ParameterSchema] +) -> Type[BaseModel]: + """Converts the given parameters to a Pydantic BaseModel class.""" + field_definitions = {} + for field in params: + field_definitions[field.name] = cast( + Any, + ( + field.to_param().annotation, + Field(description=field.description), + ), + ) + return create_model(tool_name, **field_definitions) + + +async def resolve_value( + source: Union[Callable[[], Awaitable[Any]], Callable[[], Any], Any], +) -> Any: + """ + Asynchronously or synchronously resolves a given source to its value. + + If the `source` is a coroutine function, it will be awaited. + If the `source` is a regular callable, it will be called. + Otherwise (if it's not a callable), the `source` itself is returned directly. + + Args: + source: The value, a callable returning a value, or a callable + returning an awaitable value. + + Returns: + The resolved value. + """ + + if asyncio.iscoroutinefunction(source): + return await source() + elif callable(source): + return source() + return source diff --git a/packages/toolbox-core/tests/test_tools.py b/packages/toolbox-core/tests/test_tools.py index 505aa7f7..7cb4f305 100644 --- a/packages/toolbox-core/tests/test_tools.py +++ b/packages/toolbox-core/tests/test_tools.py @@ -23,7 +23,7 @@ from pydantic import ValidationError from toolbox_core.protocol import ParameterSchema -from toolbox_core.tool import ToolboxTool, create_docstring, resolve_value +from toolbox_core.tool import ToolboxTool, create_func_docstring, resolve_value TEST_BASE_URL = "http://toolbox.example.com" TEST_TOOL_NAME = "sample_tool" @@ -53,9 +53,9 @@ async def http_session() -> AsyncGenerator[ClientSession, None]: yield session -def test_create_docstring_one_param_real_schema(): +def test_create_func_docstring_one_param_real_schema(): """ - Tests create_docstring with one real ParameterSchema instance. + Tests create_func_docstring with one real ParameterSchema instance. """ description = "This tool does one thing." params = [ @@ -64,7 +64,7 @@ def test_create_docstring_one_param_real_schema(): ) ] - result_docstring = create_docstring(description, params) + result_docstring = create_func_docstring(description, params) expected_docstring = ( "This tool does one thing.\n\n" @@ -75,9 +75,9 @@ def test_create_docstring_one_param_real_schema(): assert result_docstring == expected_docstring -def test_create_docstring_multiple_params_real_schema(): +def test_create_func_docstring_multiple_params_real_schema(): """ - Tests create_docstring with multiple real ParameterSchema instances. + Tests create_func_docstring with multiple real ParameterSchema instances. """ description = "This tool does multiple things." params = [ @@ -90,7 +90,7 @@ def test_create_docstring_multiple_params_real_schema(): ), ] - result_docstring = create_docstring(description, params) + result_docstring = create_func_docstring(description, params) expected_docstring = ( "This tool does multiple things.\n\n" @@ -103,9 +103,9 @@ def test_create_docstring_multiple_params_real_schema(): assert result_docstring == expected_docstring -def test_create_docstring_no_description_real_schema(): +def test_create_func_docstring_no_description_real_schema(): """ - Tests create_docstring with empty description and one real ParameterSchema. + Tests create_func_docstring with empty description and one real ParameterSchema. """ description = "" params = [ @@ -114,7 +114,7 @@ def test_create_docstring_no_description_real_schema(): ) ] - result_docstring = create_docstring(description, params) + result_docstring = create_func_docstring(description, params) expected_docstring = ( "\n\nArgs:\n" " config_id (str): The ID of the configuration." @@ -125,14 +125,14 @@ def test_create_docstring_no_description_real_schema(): assert "config_id (str): The ID of the configuration." in result_docstring -def test_create_docstring_no_params(): +def test_create_func_docstring_no_params(): """ - Tests create_docstring when the params list is empty. + Tests create_func_docstring when the params list is empty. """ description = "This is a tool description." params = [] - result_docstring = create_docstring(description, params) + result_docstring = create_func_docstring(description, params) assert result_docstring == description assert "\n\nArgs:" not in result_docstring diff --git a/packages/toolbox-core/tests/test_utils.py b/packages/toolbox-core/tests/test_utils.py new file mode 100644 index 00000000..b71284b6 --- /dev/null +++ b/packages/toolbox-core/tests/test_utils.py @@ -0,0 +1,258 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import asyncio +from typing import Type +from unittest.mock import Mock + +import pytest +from pydantic import BaseModel, ValidationError + +from toolbox_core.protocol import ParameterSchema +from toolbox_core.utils import ( + create_func_docstring, + identify_required_authn_params, + params_to_pydantic_model, + resolve_value, +) + + +def create_param_mock(name: str, description: str, annotation: Type) -> Mock: + """Creates a mock for ParameterSchema.""" + param_mock = Mock(spec=ParameterSchema) + param_mock.name = name + param_mock.description = description + + mock_param_info = Mock() + mock_param_info.annotation = annotation + + param_mock.to_param.return_value = mock_param_info + return param_mock + + +def test_create_func_docstring_no_params(): + """Test create_func_docstring with no parameters.""" + description = "This is a tool description." + params = [] + expected_docstring = "This is a tool description." + assert create_func_docstring(description, params) == expected_docstring + + +def test_create_func_docstring_with_params(): + """Test create_func_docstring with multiple parameters using mocks.""" + description = "Tool description." + params = [ + create_param_mock( + name="param1", description="First parameter.", annotation=str + ), + create_param_mock(name="count", description="A number.", annotation=int), + ] + expected_docstring = """Tool description. + +Args: + param1 (str): First parameter. + count (int): A number.""" + assert create_func_docstring(description, params) == expected_docstring + + +def test_create_func_docstring_empty_description(): + """Test create_func_docstring with an empty description using mocks.""" + description = "" + params = [ + create_param_mock( + name="param1", description="First parameter.", annotation=str + ), + ] + expected_docstring = """ + +Args: + param1 (str): First parameter.""" + assert create_func_docstring(description, params) == expected_docstring + + +def test_identify_required_authn_params_none_required(): + """Test when no authentication parameters are required initially.""" + req_authn_params = {} + auth_service_names = ["service_a", "service_b"] + expected = {} + assert ( + identify_required_authn_params(req_authn_params, auth_service_names) == expected + ) + + +def test_identify_required_authn_params_all_covered(): + """Test when all required parameters are covered by available services.""" + req_authn_params = { + "token_a": ["service_a"], + "token_b": ["service_b", "service_c"], + } + auth_service_names = ["service_a", "service_b"] + expected = {} + assert ( + identify_required_authn_params(req_authn_params, auth_service_names) == expected + ) + + +def test_identify_required_authn_params_some_covered(): + """Test when some parameters are covered, and some are not.""" + req_authn_params = { + "token_a": ["service_a"], + "token_b": ["service_b", "service_c"], + "token_d": ["service_d"], + "token_e": ["service_e", "service_f"], + } + auth_service_names = ["service_a", "service_b"] + expected = { + "token_d": ["service_d"], + "token_e": ["service_e", "service_f"], + } + assert ( + identify_required_authn_params(req_authn_params, auth_service_names) == expected + ) + + +def test_identify_required_authn_params_none_covered(): + """Test when none of the required parameters are covered.""" + req_authn_params = { + "token_d": ["service_d"], + "token_e": ["service_e", "service_f"], + } + auth_service_names = ["service_a", "service_b"] + expected = { + "token_d": ["service_d"], + "token_e": ["service_e", "service_f"], + } + assert ( + identify_required_authn_params(req_authn_params, auth_service_names) == expected + ) + + +def test_identify_required_authn_params_no_available_services(): + """Test when no authentication services are available.""" + req_authn_params = { + "token_a": ["service_a"], + "token_b": ["service_b", "service_c"], + } + auth_service_names = [] + expected = { + "token_a": ["service_a"], + "token_b": ["service_b", "service_c"], + } + assert ( + identify_required_authn_params(req_authn_params, auth_service_names) == expected + ) + + +def test_identify_required_authn_params_empty_services_for_param(): + """Test edge case where a param requires an empty list of services.""" + req_authn_params = { + "token_x": [], + } + auth_service_names = ["service_a"] + expected = { + "token_x": [], + } + assert ( + identify_required_authn_params(req_authn_params, auth_service_names) == expected + ) + + +def test_params_to_pydantic_model_no_params(): + """Test creating a Pydantic model with no parameters.""" + tool_name = "NoParamTool" + params = [] + Model = params_to_pydantic_model(tool_name, params) + + assert issubclass(Model, BaseModel) + assert Model.__name__ == tool_name + assert not Model.model_fields + + instance = Model() + assert isinstance(instance, BaseModel) + + +def test_params_to_pydantic_model_with_params(): + """Test creating a Pydantic model with various parameter types using mocks.""" + tool_name = "MyTool" + params = [ + create_param_mock(name="name", description="User name", annotation=str), + create_param_mock(name="age", description="User age", annotation=int), + create_param_mock( + name="is_active", description="Activity status", annotation=bool + ), + ] + Model = params_to_pydantic_model(tool_name, params) + + assert issubclass(Model, BaseModel) + assert Model.__name__ == tool_name + assert len(Model.model_fields) == 3 + + assert "name" in Model.model_fields + assert Model.model_fields["name"].annotation == str + assert Model.model_fields["name"].description == "User name" + + assert "age" in Model.model_fields + assert Model.model_fields["age"].annotation == int + assert Model.model_fields["age"].description == "User age" + + assert "is_active" in Model.model_fields + assert Model.model_fields["is_active"].annotation == bool + assert Model.model_fields["is_active"].description == "Activity status" + + instance = Model(name="Alice", age=30, is_active=True) + assert instance.name == "Alice" + assert instance.age == 30 + assert instance.is_active is True + + with pytest.raises(ValidationError): + Model(name="Bob", age="thirty", is_active=True) + + +@pytest.mark.asyncio +async def test_resolve_value_plain_value(): + """Test resolving a plain, non-callable value.""" + value = 123 + assert await resolve_value(value) == 123 + + value = "hello" + assert await resolve_value(value) == "hello" + + value = None + assert await resolve_value(value) is None + + +@pytest.mark.asyncio +async def test_resolve_value_sync_callable(): + """Test resolving a synchronous callable using Mock.""" + mock_sync_func = Mock(return_value="sync result") + assert await resolve_value(mock_sync_func) == "sync result" + mock_sync_func.assert_called_once() + assert await resolve_value(lambda: [1, 2, 3]) == [1, 2, 3] + + +@pytest.mark.asyncio +async def test_resolve_value_async_callable(): + """Test resolving an asynchronous callable (coroutine function).""" + + async def async_func(): + await asyncio.sleep(0.01) + return "async result" + + assert await resolve_value(async_func) == "async result" + + async def another_async_func(): + return {"key": "value"} + + assert await resolve_value(another_async_func) == {"key": "value"}