Skip to content

feat(toolbox-core): Add support for optional parameters #290

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions packages/toolbox-core/src/toolbox_core/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,31 +25,39 @@ class ParameterSchema(BaseModel):

name: str
type: str
required: bool = True
description: str
authSources: Optional[list[str]] = None
items: Optional["ParameterSchema"] = None

def __get_type(self) -> Type:
base_type: Type
if self.type == "string":
return str
base_type = str
elif self.type == "integer":
return int
base_type = int
elif self.type == "float":
return float
base_type = float
elif self.type == "boolean":
return bool
base_type = bool
elif self.type == "array":
if self.items is None:
raise Exception("Unexpected value: type is 'list' but items is None")
return list[self.items.__get_type()] # type: ignore
base_type = list[self.items.__get_type()] # type: ignore
else:
raise ValueError(f"Unsupported schema type: {self.type}")

raise ValueError(f"Unsupported schema type: {self.type}")
if not self.required:
return Optional[base_type] # type: ignore

return base_type

def to_param(self) -> Parameter:
return Parameter(
self.name,
Parameter.POSITIONAL_OR_KEYWORD,
annotation=self.__get_type(),
default=Parameter.empty if self.required else None,
)


Expand Down
13 changes: 11 additions & 2 deletions packages/toolbox-core/src/toolbox_core/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import copy
import itertools
from inspect import Signature
from types import MappingProxyType
from typing import Any, Awaitable, Callable, Mapping, Optional, Sequence, Union
Expand Down Expand Up @@ -89,7 +90,13 @@ def __init__(
self.__params = params
self.__pydantic_model = params_to_pydantic_model(name, self.__params)

inspect_type_params = [param.to_param() for param in self.__params]
# Separate parameters into required (no default) and optional (with
# default) to prevent the "non-default argument follows default
# argument" error when creating the function signature.
required_params = (p for p in self.__params if p.required)
optional_params = (p for p in self.__params if not p.required)
ordered_params = itertools.chain(required_params, optional_params)
inspect_type_params = [param.to_param() for param in ordered_params]

# the following properties are set to help anyone that might inspect it determine usage
self.__name__ = name
Expand Down Expand Up @@ -268,7 +275,9 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:

# validate inputs to this call using the signature
all_args = self.__signature__.bind(*args, **kwargs)
all_args.apply_defaults() # Include default values if not provided

# The payload will only contain arguments explicitly provided by the user.
# Optional arguments not provided by the user will not be in the payload.
payload = all_args.arguments

# Perform argument type validations using pydantic
Expand Down
11 changes: 10 additions & 1 deletion packages/toolbox-core/src/toolbox_core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,20 @@ def params_to_pydantic_model(
"""Converts the given parameters to a Pydantic BaseModel class."""
field_definitions = {}
for field in params:

# Determine the default value based on the 'required' flag.
# '...' (Ellipsis) signifies a required field in Pydantic.
# 'None' makes the field optional with a default value of None.
default_value = ... if field.required else None

field_definitions[field.name] = cast(
Any,
(
field.to_param().annotation,
Field(description=field.description),
Field(
description=field.description,
default=default_value,
),
),
)
return create_model(tool_name, **field_definitions)
Expand Down
114 changes: 114 additions & 0 deletions packages/toolbox-core/tests/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from inspect import Parameter, signature
from typing import Optional

import pytest
import pytest_asyncio
from pydantic import ValidationError
Expand Down Expand Up @@ -217,3 +221,113 @@ async def test_run_tool_param_auth_no_field(
match="no field named row_data in claims",
):
await tool()


@pytest.mark.asyncio
@pytest.mark.usefixtures("toolbox_server")
class TestOptionalParams:
"""
End-to-end tests for tools with optional parameters.
"""

async def test_tool_signature_is_correct(self, toolbox: ToolboxClient):
"""Verify the client correctly constructs the signature for a tool with optional params."""
tool = await toolbox.load_tool("search-rows")
sig = signature(tool)

assert "query" in sig.parameters
assert "limit" in sig.parameters

# The required parameter should have no default
assert sig.parameters["email"].default is Parameter.empty
assert sig.parameters["email"].annotation is str

# The optional parameter should have a default of None
assert sig.parameters["data"].default is "row2"
assert sig.parameters["limit"].annotation is Optional[str]

# The optional parameter should have a default of None
assert sig.parameters["id"].default is None
assert sig.parameters["id"].annotation is Optional[int]

async def test_run_tool_with_optional_params_omitted(self, toolbox: ToolboxClient):
"""Invoke a tool providing only the required parameter."""
tool = await toolbox.load_tool("search-rows")

response = await tool(email="twishabansal@google.com")
assert isinstance(response, str)
assert 'email="twishabansal@google.com"' in response
assert "row1" not in response
assert "row2" in response
assert "row3" not in response
assert "row4" not in response
assert "row5" not in response
assert "row6" not in response

async def test_run_tool_with_optional_data_provided(self, toolbox: ToolboxClient):
"""Invoke a tool providing both required and optional parameters."""
tool = await toolbox.load_tool("search-rows")

response = await tool(email="twishabansal@google.com", data="row3")
assert isinstance(response, str)
assert 'email="twishabansal@google.com"' in response
assert "row1" not in response
assert "row2" not in response
assert "row3" in response
assert "row4" not in response
assert "row5" not in response
assert "row6" not in response

async def test_run_tool_with_optional_data_null(self, toolbox: ToolboxClient):
"""Invoke a tool providing both required and optional parameters."""
tool = await toolbox.load_tool("search-rows")

response = await tool(email="twishabansal@google.com", data=None)
assert isinstance(response, str)
assert 'email="twishabansal@google.com"' in response
assert "row1" not in response
assert "row2" in response
assert "row3" not in response
assert "row4" not in response
assert "row5" not in response
assert "row6" not in response

async def test_run_tool_with_optional_id_provided(self, toolbox: ToolboxClient):
"""Invoke a tool providing both required and optional parameters."""
tool = await toolbox.load_tool("search-rows")

response = await tool(email="twishabansal@google.com", id=1)
assert isinstance(response, str)
assert 'email="twishabansal@google.com"' in response
assert "row1" in response
assert "row2" not in response
assert "row3" not in response
assert "row4" not in response
assert "row5" not in response
assert "row6" not in response

async def test_run_tool_with_optional_id_null(self, toolbox: ToolboxClient):
"""Invoke a tool providing both required and optional parameters."""
tool = await toolbox.load_tool("search-rows")

response = await tool(email="twishabansal@google.com", id=None)
assert isinstance(response, str)
assert 'email="twishabansal@google.com"' in response
assert "row1" not in response
assert "row2" in response
assert "row3" not in response
assert "row4" not in response
assert "row5" not in response
assert "row6" not in response

async def test_run_tool_with_missing_required_param(self, toolbox: ToolboxClient):
"""Invoke a tool without its required parameter."""
tool = await toolbox.load_tool("search-rows")
with pytest.raises(TypeError, match="missing a required argument: 'email'"):
await tool(id=5, data="row5")

async def test_run_tool_with_required_param_null(self, toolbox: ToolboxClient):
"""Invoke a tool without its required parameter."""
tool = await toolbox.load_tool("search-rows")
with pytest.raises(TypeError, match="missing a required argument: 'email'"):
await tool(email=None, id=5, data="row5")
64 changes: 64 additions & 0 deletions packages/toolbox-core/tests/test_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@


from inspect import Parameter
from typing import Optional

import pytest

Expand Down Expand Up @@ -106,3 +107,66 @@ def test_parameter_schema_unsupported_type_error():

with pytest.raises(ValueError, match=expected_error_msg):
schema.to_param()


def test_parameter_schema_string_optional():
"""Tests an optional ParameterSchema with type 'string'."""
schema = ParameterSchema(
name="nickname",
type="string",
description="An optional nickname",
required=False,
)
expected_type = Optional[str]

# Test __get_type()
assert schema._ParameterSchema__get_type() == expected_type

# Test to_param()
param = schema.to_param()
assert isinstance(param, Parameter)
assert param.name == "nickname"
assert param.annotation == expected_type
assert param.kind == Parameter.POSITIONAL_OR_KEYWORD
assert param.default is None


def test_parameter_schema_required_by_default():
"""Tests that a parameter is required by default."""
# 'required' is not specified, so it should default to True.
schema = ParameterSchema(name="id", type="integer", description="A required ID")
expected_type = int

# Test __get_type()
assert schema._ParameterSchema__get_type() == expected_type

# Test to_param()
param = schema.to_param()
assert isinstance(param, Parameter)
assert param.name == "id"
assert param.annotation == expected_type
assert param.default == Parameter.empty


def test_parameter_schema_array_optional():
"""Tests an optional ParameterSchema with type 'array'."""
item_schema = ParameterSchema(name="", type="integer", description="")
schema = ParameterSchema(
name="optional_scores",
type="array",
description="An optional list of scores",
items=item_schema,
required=False,
)
expected_type = Optional[list[int]]

# Test __get_type()
assert schema._ParameterSchema__get_type() == expected_type

# Test to_param()
param = schema.to_param()
assert isinstance(param, Parameter)
assert param.name == "optional_scores"
assert param.annotation == expected_type
assert param.kind == Parameter.POSITIONAL_OR_KEYWORD
assert param.default is None
1 change: 1 addition & 0 deletions packages/toolbox-core/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def create_param_mock(name: str, description: str, annotation: Type) -> Mock:
param_mock = Mock(spec=ParameterSchema)
param_mock.name = name
param_mock.description = description
param_mock.required = True

mock_param_info = Mock()
mock_param_info.annotation = annotation
Expand Down