Skip to content

Commit c2b52aa

Browse files
authored
fix: Make all fields required in Tool schema (#68)
* fix: Make all fields required in Tool schema. Earlier we made all fields as optional since we wanted to keep some fields optional for the LLM. Since Toolbox did not support optional fields, there was no way to know which fields were optional, so as a worst-case, we did a temporary workaround of keeping all fields as optional in the schema generated by Toolbox SDK. Now, there has been some evidence that the LLMs do not work very well with optional parameters, and so we have decided not to support optional fields for now, neither in Toolbox service nor in the SDK. This PR removes that temporary fix of making all the fields optional. This PR also removes an augmentation to the request body where `None` values were converted to empty strings (`''`). This is because now that LLM knows no fields are optional, we can be sure that we would not be getting any `None` values as inputs to the tools. So the function `_convert_none_to_empty_string` is not required anymore. * chore: Update unit tests.
1 parent 57671e8 commit c2b52aa

File tree

2 files changed

+8
-39
lines changed

2 files changed

+8
-39
lines changed

src/toolbox_langchain/utils.py

Lines changed: 3 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import json
16-
from typing import Any, Callable, Optional, Type, Union, cast
16+
from typing import Any, Callable, Optional, Type, cast
1717
from warnings import warn
1818

1919
from aiohttp import ClientSession
@@ -99,9 +99,7 @@ def _schema_to_model(model_name: str, schema: list[ParameterSchema]) -> Type[Bas
9999
field_definitions[field.name] = cast(
100100
Any,
101101
(
102-
# TODO: Remove the hardcoded optional types once optional fields
103-
# are supported by Toolbox.
104-
Optional[_parse_type(field)],
102+
_parse_type(field),
105103
Field(description=field.description),
106104
),
107105
)
@@ -202,37 +200,14 @@ async def _invoke_tool(
202200

203201
async with session.post(
204202
url,
205-
json=_convert_none_to_empty_string(data),
203+
json=data,
206204
headers=auth_tokens,
207205
) as response:
208206
# TODO: Remove as it masks error messages.
209207
response.raise_for_status()
210208
return await response.json()
211209

212210

213-
def _convert_none_to_empty_string(input_dict):
214-
"""
215-
Temporary fix to convert None values to empty strings in the input data.
216-
This is needed because the current version of the Toolbox service does not
217-
support optional fields.
218-
219-
TODO: Remove this once optional fields are supported by Toolbox.
220-
221-
Args:
222-
input_dict: The input data dictionary.
223-
224-
Returns:
225-
A new dictionary with None values replaced by empty strings.
226-
"""
227-
new_dict = {}
228-
for key, value in input_dict.items():
229-
if value is None:
230-
new_dict[key] = ""
231-
else:
232-
new_dict[key] = value
233-
return new_dict
234-
235-
236211
def _find_auth_params(
237212
params: list[ParameterSchema],
238213
) -> tuple[list[ParameterSchema], list[ParameterSchema]]:

tests/test_utils.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525

2626
from toolbox_langchain.utils import (
2727
ParameterSchema,
28-
_convert_none_to_empty_string,
2928
_get_auth_headers,
3029
_invoke_tool,
3130
_load_manifest,
@@ -154,9 +153,9 @@ def test_schema_to_model(self):
154153
model = _schema_to_model("TestModel", schema)
155154
assert issubclass(model, BaseModel)
156155

157-
assert model.model_fields["param1"].annotation == Union[str, None]
156+
assert model.model_fields["param1"].annotation == str
158157
assert model.model_fields["param1"].description == "Parameter 1"
159-
assert model.model_fields["param2"].annotation == Union[int, None]
158+
assert model.model_fields["param2"].annotation == int
160159
assert model.model_fields["param2"].description == "Parameter 2"
161160

162161
def test_schema_to_model_empty(self):
@@ -225,7 +224,7 @@ async def test_invoke_tool(self, mock_post):
225224

226225
mock_post.assert_called_once_with(
227226
"http://localhost:8000/api/tool/tool_name/invoke",
228-
json=_convert_none_to_empty_string({"input": "data"}),
227+
json={"input": "data"},
229228
headers={},
230229
)
231230
assert result == {"key": "value"}
@@ -252,7 +251,7 @@ async def test_invoke_tool_unsecure_with_auth(self, mock_post):
252251

253252
mock_post.assert_called_once_with(
254253
"http://localhost:8000/api/tool/tool_name/invoke",
255-
json=_convert_none_to_empty_string({"input": "data"}),
254+
json={"input": "data"},
256255
headers={"my_test_auth_token": "fake_id_token"},
257256
)
258257
assert result == {"key": "value"}
@@ -278,16 +277,11 @@ async def test_invoke_tool_secure_with_auth(self, mock_post):
278277

279278
mock_post.assert_called_once_with(
280279
"https://localhost:8000/api/tool/tool_name/invoke",
281-
json=_convert_none_to_empty_string({"input": "data"}),
280+
json={"input": "data"},
282281
headers={"my_test_auth_token": "fake_id_token"},
283282
)
284283
assert result == {"key": "value"}
285284

286-
def test_convert_none_to_empty_string(self):
287-
input_dict = {"a": None, "b": 123}
288-
expected_output = {"a": "", "b": 123}
289-
assert _convert_none_to_empty_string(input_dict) == expected_output
290-
291285
def test_get_auth_headers_deprecation_warning(self):
292286
"""Test _get_auth_headers deprecation warning."""
293287
with pytest.warns(

0 commit comments

Comments
 (0)