From 2f5eae125b58ed6dce1720a65dcdced2bdc33474 Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Thu, 27 Feb 2025 18:57:54 +0530 Subject: [PATCH 1/2] fix: Make all fields required in Tool schema. Earlier we made all fields as optional since we wanted to keep some fields optional for the LLM. Since Toolbox did not support optional fields, there was no way to know which fields were optional, so as a worst-case, we did a temporary workaround of keeping all fields as optional in the schema generated by Toolbox SDK. Now, there has been some evidence that the LLMs do not work very well with optional parameters, and so we have decided not to support optional fields for now, neither in Toolbox service nor in the SDK. This PR removes that temporary fix of making all the fields optional. This PR also removes an augmentation to the request body where `None` values were converted to empty strings (`''`). This is because now that LLM knows no fields are optional, we can be sure that we would not be getting any `None` values as inputs to the tools. So the function `_convert_none_to_empty_string` is not required anymore. --- src/toolbox_langchain/utils.py | 31 +++---------------------------- 1 file changed, 3 insertions(+), 28 deletions(-) diff --git a/src/toolbox_langchain/utils.py b/src/toolbox_langchain/utils.py index 53ab2edf..c63332f7 100644 --- a/src/toolbox_langchain/utils.py +++ b/src/toolbox_langchain/utils.py @@ -13,7 +13,7 @@ # limitations under the License. import json -from typing import Any, Callable, Optional, Type, Union, cast +from typing import Any, Callable, Optional, Type, cast from warnings import warn from aiohttp import ClientSession @@ -99,9 +99,7 @@ def _schema_to_model(model_name: str, schema: list[ParameterSchema]) -> Type[Bas field_definitions[field.name] = cast( Any, ( - # TODO: Remove the hardcoded optional types once optional fields - # are supported by Toolbox. - Optional[_parse_type(field)], + _parse_type(field), Field(description=field.description), ), ) @@ -202,7 +200,7 @@ async def _invoke_tool( async with session.post( url, - json=_convert_none_to_empty_string(data), + json=data, headers=auth_tokens, ) as response: # TODO: Remove as it masks error messages. @@ -210,29 +208,6 @@ async def _invoke_tool( return await response.json() -def _convert_none_to_empty_string(input_dict): - """ - Temporary fix to convert None values to empty strings in the input data. - This is needed because the current version of the Toolbox service does not - support optional fields. - - TODO: Remove this once optional fields are supported by Toolbox. - - Args: - input_dict: The input data dictionary. - - Returns: - A new dictionary with None values replaced by empty strings. - """ - new_dict = {} - for key, value in input_dict.items(): - if value is None: - new_dict[key] = "" - else: - new_dict[key] = value - return new_dict - - def _find_auth_params( params: list[ParameterSchema], ) -> tuple[list[ParameterSchema], list[ParameterSchema]]: From 3ec2c8eda28a8b1ac3b1c6b794380862d6ecabd2 Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Thu, 27 Feb 2025 19:03:57 +0530 Subject: [PATCH 2/2] chore: Update unit tests. --- tests/test_utils.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index cb70d055..8e5139ed 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -25,7 +25,6 @@ from toolbox_langchain.utils import ( ParameterSchema, - _convert_none_to_empty_string, _get_auth_headers, _invoke_tool, _load_manifest, @@ -154,9 +153,9 @@ def test_schema_to_model(self): model = _schema_to_model("TestModel", schema) assert issubclass(model, BaseModel) - assert model.model_fields["param1"].annotation == Union[str, None] + assert model.model_fields["param1"].annotation == str assert model.model_fields["param1"].description == "Parameter 1" - assert model.model_fields["param2"].annotation == Union[int, None] + assert model.model_fields["param2"].annotation == int assert model.model_fields["param2"].description == "Parameter 2" def test_schema_to_model_empty(self): @@ -225,7 +224,7 @@ async def test_invoke_tool(self, mock_post): mock_post.assert_called_once_with( "http://localhost:8000/api/tool/tool_name/invoke", - json=_convert_none_to_empty_string({"input": "data"}), + json={"input": "data"}, headers={}, ) assert result == {"key": "value"} @@ -252,7 +251,7 @@ async def test_invoke_tool_unsecure_with_auth(self, mock_post): mock_post.assert_called_once_with( "http://localhost:8000/api/tool/tool_name/invoke", - json=_convert_none_to_empty_string({"input": "data"}), + json={"input": "data"}, headers={"my_test_auth_token": "fake_id_token"}, ) assert result == {"key": "value"} @@ -278,16 +277,11 @@ async def test_invoke_tool_secure_with_auth(self, mock_post): mock_post.assert_called_once_with( "https://localhost:8000/api/tool/tool_name/invoke", - json=_convert_none_to_empty_string({"input": "data"}), + json={"input": "data"}, headers={"my_test_auth_token": "fake_id_token"}, ) assert result == {"key": "value"} - def test_convert_none_to_empty_string(self): - input_dict = {"a": None, "b": 123} - expected_output = {"a": "", "b": 123} - assert _convert_none_to_empty_string(input_dict) == expected_output - def test_get_auth_headers_deprecation_warning(self): """Test _get_auth_headers deprecation warning.""" with pytest.warns(