diff --git a/packages/toolbox-core/src/toolbox_core/protocol.py b/packages/toolbox-core/src/toolbox_core/protocol.py index e2071ba2..6606ef93 100644 --- a/packages/toolbox-core/src/toolbox_core/protocol.py +++ b/packages/toolbox-core/src/toolbox_core/protocol.py @@ -25,31 +25,39 @@ class ParameterSchema(BaseModel): name: str type: str + required: bool = True description: str authSources: Optional[list[str]] = None items: Optional["ParameterSchema"] = None def __get_type(self) -> Type: + base_type: Type if self.type == "string": - return str + base_type = str elif self.type == "integer": - return int + base_type = int elif self.type == "float": - return float + base_type = float elif self.type == "boolean": - return bool + base_type = bool elif self.type == "array": if self.items is None: raise Exception("Unexpected value: type is 'list' but items is None") - return list[self.items.__get_type()] # type: ignore + base_type = list[self.items.__get_type()] # type: ignore + else: + raise ValueError(f"Unsupported schema type: {self.type}") - raise ValueError(f"Unsupported schema type: {self.type}") + if not self.required: + return Optional[base_type] # type: ignore + + return base_type def to_param(self) -> Parameter: return Parameter( self.name, Parameter.POSITIONAL_OR_KEYWORD, annotation=self.__get_type(), + default=Parameter.empty if self.required else None, ) diff --git a/packages/toolbox-core/src/toolbox_core/tool.py b/packages/toolbox-core/src/toolbox_core/tool.py index 48a31dab..8d413970 100644 --- a/packages/toolbox-core/src/toolbox_core/tool.py +++ b/packages/toolbox-core/src/toolbox_core/tool.py @@ -13,6 +13,7 @@ # limitations under the License. import copy +import itertools from inspect import Signature from types import MappingProxyType from typing import Any, Awaitable, Callable, Mapping, Optional, Sequence, Union @@ -89,7 +90,13 @@ def __init__( self.__params = params self.__pydantic_model = params_to_pydantic_model(name, self.__params) - inspect_type_params = [param.to_param() for param in self.__params] + # Separate parameters into required (no default) and optional (with + # default) to prevent the "non-default argument follows default + # argument" error when creating the function signature. + required_params = (p for p in self.__params if p.required) + optional_params = (p for p in self.__params if not p.required) + ordered_params = itertools.chain(required_params, optional_params) + inspect_type_params = [param.to_param() for param in ordered_params] # the following properties are set to help anyone that might inspect it determine usage self.__name__ = name @@ -268,7 +275,9 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str: # validate inputs to this call using the signature all_args = self.__signature__.bind(*args, **kwargs) - all_args.apply_defaults() # Include default values if not provided + + # The payload will only contain arguments explicitly provided by the user. + # Optional arguments not provided by the user will not be in the payload. payload = all_args.arguments # Perform argument type validations using pydantic diff --git a/packages/toolbox-core/src/toolbox_core/utils.py b/packages/toolbox-core/src/toolbox_core/utils.py index 615a23ec..28c046b1 100644 --- a/packages/toolbox-core/src/toolbox_core/utils.py +++ b/packages/toolbox-core/src/toolbox_core/utils.py @@ -111,11 +111,20 @@ def params_to_pydantic_model( """Converts the given parameters to a Pydantic BaseModel class.""" field_definitions = {} for field in params: + + # Determine the default value based on the 'required' flag. + # '...' (Ellipsis) signifies a required field in Pydantic. + # 'None' makes the field optional with a default value of None. + default_value = ... if field.required else None + field_definitions[field.name] = cast( Any, ( field.to_param().annotation, - Field(description=field.description), + Field( + description=field.description, + default=default_value, + ), ), ) return create_model(tool_name, **field_definitions) diff --git a/packages/toolbox-core/tests/test_e2e.py b/packages/toolbox-core/tests/test_e2e.py index 8920bc3b..de25262b 100644 --- a/packages/toolbox-core/tests/test_e2e.py +++ b/packages/toolbox-core/tests/test_e2e.py @@ -11,6 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from inspect import Parameter, signature +from typing import Optional + import pytest import pytest_asyncio from pydantic import ValidationError @@ -217,3 +221,113 @@ async def test_run_tool_param_auth_no_field( match="no field named row_data in claims", ): await tool() + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("toolbox_server") +class TestOptionalParams: + """ + End-to-end tests for tools with optional parameters. + """ + + async def test_tool_signature_is_correct(self, toolbox: ToolboxClient): + """Verify the client correctly constructs the signature for a tool with optional params.""" + tool = await toolbox.load_tool("search-rows") + sig = signature(tool) + + assert "query" in sig.parameters + assert "limit" in sig.parameters + + # The required parameter should have no default + assert sig.parameters["email"].default is Parameter.empty + assert sig.parameters["email"].annotation is str + + # The optional parameter should have a default of None + assert sig.parameters["data"].default is "row2" + assert sig.parameters["limit"].annotation is Optional[str] + + # The optional parameter should have a default of None + assert sig.parameters["id"].default is None + assert sig.parameters["id"].annotation is Optional[int] + + async def test_run_tool_with_optional_params_omitted(self, toolbox: ToolboxClient): + """Invoke a tool providing only the required parameter.""" + tool = await toolbox.load_tool("search-rows") + + response = await tool(email="twishabansal@google.com") + assert isinstance(response, str) + assert 'email="twishabansal@google.com"' in response + assert "row1" not in response + assert "row2" in response + assert "row3" not in response + assert "row4" not in response + assert "row5" not in response + assert "row6" not in response + + async def test_run_tool_with_optional_data_provided(self, toolbox: ToolboxClient): + """Invoke a tool providing both required and optional parameters.""" + tool = await toolbox.load_tool("search-rows") + + response = await tool(email="twishabansal@google.com", data="row3") + assert isinstance(response, str) + assert 'email="twishabansal@google.com"' in response + assert "row1" not in response + assert "row2" not in response + assert "row3" in response + assert "row4" not in response + assert "row5" not in response + assert "row6" not in response + + async def test_run_tool_with_optional_data_null(self, toolbox: ToolboxClient): + """Invoke a tool providing both required and optional parameters.""" + tool = await toolbox.load_tool("search-rows") + + response = await tool(email="twishabansal@google.com", data=None) + assert isinstance(response, str) + assert 'email="twishabansal@google.com"' in response + assert "row1" not in response + assert "row2" in response + assert "row3" not in response + assert "row4" not in response + assert "row5" not in response + assert "row6" not in response + + async def test_run_tool_with_optional_id_provided(self, toolbox: ToolboxClient): + """Invoke a tool providing both required and optional parameters.""" + tool = await toolbox.load_tool("search-rows") + + response = await tool(email="twishabansal@google.com", id=1) + assert isinstance(response, str) + assert 'email="twishabansal@google.com"' in response + assert "row1" in response + assert "row2" not in response + assert "row3" not in response + assert "row4" not in response + assert "row5" not in response + assert "row6" not in response + + async def test_run_tool_with_optional_id_null(self, toolbox: ToolboxClient): + """Invoke a tool providing both required and optional parameters.""" + tool = await toolbox.load_tool("search-rows") + + response = await tool(email="twishabansal@google.com", id=None) + assert isinstance(response, str) + assert 'email="twishabansal@google.com"' in response + assert "row1" not in response + assert "row2" in response + assert "row3" not in response + assert "row4" not in response + assert "row5" not in response + assert "row6" not in response + + async def test_run_tool_with_missing_required_param(self, toolbox: ToolboxClient): + """Invoke a tool without its required parameter.""" + tool = await toolbox.load_tool("search-rows") + with pytest.raises(TypeError, match="missing a required argument: 'email'"): + await tool(id=5, data="row5") + + async def test_run_tool_with_required_param_null(self, toolbox: ToolboxClient): + """Invoke a tool without its required parameter.""" + tool = await toolbox.load_tool("search-rows") + with pytest.raises(TypeError, match="missing a required argument: 'email'"): + await tool(email=None, id=5, data="row5") diff --git a/packages/toolbox-core/tests/test_protocol.py b/packages/toolbox-core/tests/test_protocol.py index a70fa3fe..b7650792 100644 --- a/packages/toolbox-core/tests/test_protocol.py +++ b/packages/toolbox-core/tests/test_protocol.py @@ -14,6 +14,7 @@ from inspect import Parameter +from typing import Optional import pytest @@ -106,3 +107,66 @@ def test_parameter_schema_unsupported_type_error(): with pytest.raises(ValueError, match=expected_error_msg): schema.to_param() + + +def test_parameter_schema_string_optional(): + """Tests an optional ParameterSchema with type 'string'.""" + schema = ParameterSchema( + name="nickname", + type="string", + description="An optional nickname", + required=False, + ) + expected_type = Optional[str] + + # Test __get_type() + assert schema._ParameterSchema__get_type() == expected_type + + # Test to_param() + param = schema.to_param() + assert isinstance(param, Parameter) + assert param.name == "nickname" + assert param.annotation == expected_type + assert param.kind == Parameter.POSITIONAL_OR_KEYWORD + assert param.default is None + + +def test_parameter_schema_required_by_default(): + """Tests that a parameter is required by default.""" + # 'required' is not specified, so it should default to True. + schema = ParameterSchema(name="id", type="integer", description="A required ID") + expected_type = int + + # Test __get_type() + assert schema._ParameterSchema__get_type() == expected_type + + # Test to_param() + param = schema.to_param() + assert isinstance(param, Parameter) + assert param.name == "id" + assert param.annotation == expected_type + assert param.default == Parameter.empty + + +def test_parameter_schema_array_optional(): + """Tests an optional ParameterSchema with type 'array'.""" + item_schema = ParameterSchema(name="", type="integer", description="") + schema = ParameterSchema( + name="optional_scores", + type="array", + description="An optional list of scores", + items=item_schema, + required=False, + ) + expected_type = Optional[list[int]] + + # Test __get_type() + assert schema._ParameterSchema__get_type() == expected_type + + # Test to_param() + param = schema.to_param() + assert isinstance(param, Parameter) + assert param.name == "optional_scores" + assert param.annotation == expected_type + assert param.kind == Parameter.POSITIONAL_OR_KEYWORD + assert param.default is None diff --git a/packages/toolbox-core/tests/test_utils.py b/packages/toolbox-core/tests/test_utils.py index c07f44cb..b3ddd7c3 100644 --- a/packages/toolbox-core/tests/test_utils.py +++ b/packages/toolbox-core/tests/test_utils.py @@ -34,6 +34,7 @@ def create_param_mock(name: str, description: str, annotation: Type) -> Mock: param_mock = Mock(spec=ParameterSchema) param_mock.name = name param_mock.description = description + param_mock.required = True mock_param_info = Mock() mock_param_info.annotation = annotation