Skip to content

fix: Make all fields required in Tool schema #68

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 3 additions & 28 deletions src/toolbox_langchain/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import json
from typing import Any, Callable, Optional, Type, Union, cast
from typing import Any, Callable, Optional, Type, cast
from warnings import warn

from aiohttp import ClientSession
Expand Down Expand Up @@ -99,9 +99,7 @@ def _schema_to_model(model_name: str, schema: list[ParameterSchema]) -> Type[Bas
field_definitions[field.name] = cast(
Any,
(
# TODO: Remove the hardcoded optional types once optional fields
# are supported by Toolbox.
Optional[_parse_type(field)],
_parse_type(field),
Field(description=field.description),
),
)
Expand Down Expand Up @@ -202,37 +200,14 @@ async def _invoke_tool(

async with session.post(
url,
json=_convert_none_to_empty_string(data),
json=data,
headers=auth_tokens,
) as response:
# TODO: Remove as it masks error messages.
response.raise_for_status()
return await response.json()


def _convert_none_to_empty_string(input_dict):
"""
Temporary fix to convert None values to empty strings in the input data.
This is needed because the current version of the Toolbox service does not
support optional fields.

TODO: Remove this once optional fields are supported by Toolbox.

Args:
input_dict: The input data dictionary.

Returns:
A new dictionary with None values replaced by empty strings.
"""
new_dict = {}
for key, value in input_dict.items():
if value is None:
new_dict[key] = ""
else:
new_dict[key] = value
return new_dict


def _find_auth_params(
params: list[ParameterSchema],
) -> tuple[list[ParameterSchema], list[ParameterSchema]]:
Expand Down
16 changes: 5 additions & 11 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@

from toolbox_langchain.utils import (
ParameterSchema,
_convert_none_to_empty_string,
_get_auth_headers,
_invoke_tool,
_load_manifest,
Expand Down Expand Up @@ -154,9 +153,9 @@ def test_schema_to_model(self):
model = _schema_to_model("TestModel", schema)
assert issubclass(model, BaseModel)

assert model.model_fields["param1"].annotation == Union[str, None]
assert model.model_fields["param1"].annotation == str
assert model.model_fields["param1"].description == "Parameter 1"
assert model.model_fields["param2"].annotation == Union[int, None]
assert model.model_fields["param2"].annotation == int
assert model.model_fields["param2"].description == "Parameter 2"

def test_schema_to_model_empty(self):
Expand Down Expand Up @@ -225,7 +224,7 @@ async def test_invoke_tool(self, mock_post):

mock_post.assert_called_once_with(
"http://localhost:8000/api/tool/tool_name/invoke",
json=_convert_none_to_empty_string({"input": "data"}),
json={"input": "data"},
headers={},
)
assert result == {"key": "value"}
Expand All @@ -252,7 +251,7 @@ async def test_invoke_tool_unsecure_with_auth(self, mock_post):

mock_post.assert_called_once_with(
"http://localhost:8000/api/tool/tool_name/invoke",
json=_convert_none_to_empty_string({"input": "data"}),
json={"input": "data"},
headers={"my_test_auth_token": "fake_id_token"},
)
assert result == {"key": "value"}
Expand All @@ -278,16 +277,11 @@ async def test_invoke_tool_secure_with_auth(self, mock_post):

mock_post.assert_called_once_with(
"https://localhost:8000/api/tool/tool_name/invoke",
json=_convert_none_to_empty_string({"input": "data"}),
json={"input": "data"},
headers={"my_test_auth_token": "fake_id_token"},
)
assert result == {"key": "value"}

def test_convert_none_to_empty_string(self):
input_dict = {"a": None, "b": 123}
expected_output = {"a": "", "b": 123}
assert _convert_none_to_empty_string(input_dict) == expected_output

def test_get_auth_headers_deprecation_warning(self):
"""Test _get_auth_headers deprecation warning."""
with pytest.warns(
Expand Down