Skip to content

Commit 5d62138

Browse files
twishabansalkurtisvgYuan325anubhav756
authored
feat: add type validation to params when tool is invoked (#129)
* feat: add authenticated parameters support * chore: add asyncio dep * chore: run itest * chore: add type hint * fix: call tool instead of client * chore: correct arg name * feat: add support for bound parameters * chore: add tests for bound parameters * docs: update syntax error on readme (#121) * ci: added release please config (#112) * ci: add release please config * chore: add initial version * chore: specify initial version as string * chore: Update .release-please-manifest.json * chore: add empty json * chore: small change * chore: try fixing config * chore: try fixing config again * chore: remove release-as * chore: add changelog sections * chore: better release notes * chore: better release notes * chore: change toolbox-langchain version * chore: separate PRs for packages * chore: change PR style * added basic e2e tests * change license year * add test deps * fix tests * fix tests * fix tests * add new test case * fix docstring * added todo * cleanup * add bind param test case * make bind params dynamic * try fix test errors * lint * remove redundant test * test fix * fix docstring * feat: add authenticated parameters support * chore: add asyncio dep * chore: run itest * chore: add type hint * fix: call tool instead of client * chore: correct arg name * chore: address feedback * chore: address more feedback * feat: add support for bound parameters * chore: add tests for bound parameters * chore: address feedback * revert package file changes * fix error message * revert package files * lint * fix error message * Update packages/toolbox-core/tests/test_e2e.py Co-authored-by: Anubhav Dhawan <anubhav756@gmail.com> * try changing docstring to include args and their descriptions * fix docstring * lint * fix test * lint * added return type annotation * Add docstrings * use and update schema # Conflicts: # packages/toolbox-core/src/toolbox_core/tool.py * lint * lint * try out pydantic validation # Conflicts: # packages/toolbox-core/src/toolbox_core/client.py # packages/toolbox-core/src/toolbox_core/tool.py * revert changes to e2e test file * add basic pydantic type checking * lint * fix pydantic validation error * fix error string * fix error string * lint * added to_pydantic_model as a method under the toolschema class * lint * revert changes to e2e test file # Conflicts: # packages/toolbox-core/tests/test_e2e.py * change string to str and in arg type * small change * fix imports * remove toolschema usage * lint * create pydantic model at init * lint * move create_docstring method outside the class. * lint * lint * Update packages/toolbox-core/src/toolbox_core/tool.py Co-authored-by: Anubhav Dhawan <anubhav756@gmail.com> * move to_pydantic_model outside the class * lint * rename function * added name to pydantic model * lint --------- Co-authored-by: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com> Co-authored-by: Yuan <45984206+Yuan325@users.noreply.github.com> Co-authored-by: Anubhav Dhawan <anubhav756@gmail.com>
1 parent 6164b09 commit 5d62138

File tree

2 files changed

+28
-2
lines changed

2 files changed

+28
-2
lines changed

packages/toolbox-core/src/toolbox_core/tool.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,13 @@
2323
Mapping,
2424
Optional,
2525
Sequence,
26+
Type,
2627
Union,
28+
cast,
2729
)
2830

2931
from aiohttp import ClientSession
32+
from pydantic import BaseModel, Field, create_model
3033

3134
from toolbox_core.protocol import ParameterSchema
3235

@@ -78,6 +81,8 @@ def __init__(
7881
self.__url = f"{base_url}/api/tool/{name}/invoke"
7982
self.__description = description
8083
self.__params = params
84+
self.__pydantic_model = params_to_pydantic_model(name, self.__params)
85+
8186
inspect_type_params = [param.to_param() for param in self.__params]
8287

8388
# the following properties are set to help anyone that might inspect it determine usage
@@ -86,6 +91,7 @@ def __init__(
8691
self.__signature__ = Signature(
8792
parameters=inspect_type_params, return_annotation=str
8893
)
94+
8995
self.__annotations__ = {p.name: p.annotation for p in inspect_type_params}
9096
# TODO: self.__qualname__ ??
9197

@@ -170,6 +176,9 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:
170176
all_args.apply_defaults() # Include default values if not provided
171177
payload = all_args.arguments
172178

179+
# Perform argument type validations using pydantic
180+
self.__pydantic_model.model_validate(payload)
181+
173182
# apply bounded parameters
174183
for param, value in self.__bound_parameters.items():
175184
if asyncio.iscoroutinefunction(value):
@@ -305,3 +314,19 @@ def identify_required_authn_params(
305314
if required:
306315
required_params[param] = services
307316
return required_params
317+
318+
319+
def params_to_pydantic_model(
320+
tool_name: str, params: Sequence[ParameterSchema]
321+
) -> Type[BaseModel]:
322+
"""Converts the given parameters to a Pydantic BaseModel class."""
323+
field_definitions = {}
324+
for field in params:
325+
field_definitions[field.name] = cast(
326+
Any,
327+
(
328+
field.to_param().annotation,
329+
Field(description=field.description),
330+
),
331+
)
332+
return create_model(tool_name, **field_definitions)

packages/toolbox-core/tests/test_e2e.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
import pytest
1515
import pytest_asyncio
16+
from pydantic import ValidationError
1617

1718
from toolbox_core.client import ToolboxClient
1819
from toolbox_core.tool import ToolboxTool
@@ -77,8 +78,8 @@ async def test_run_tool_missing_params(self, get_n_rows_tool: ToolboxTool):
7778
async def test_run_tool_wrong_param_type(self, get_n_rows_tool: ToolboxTool):
7879
"""Invoke a tool with wrong param type."""
7980
with pytest.raises(
80-
Exception,
81-
match='provided parameters were invalid: unable to parse value for "num_rows": .* not type "string"',
81+
ValidationError,
82+
match=r"num_rows\s+Input should be a valid string\s+\[type=string_type,\s+input_value=2,\s+input_type=int\]",
8283
):
8384
await get_n_rows_tool(num_rows=2)
8485

0 commit comments

Comments
 (0)