diff --git a/packages/toolbox-core/src/toolbox_core/tool.py b/packages/toolbox-core/src/toolbox_core/tool.py index 48e4626c..01a5ff76 100644 --- a/packages/toolbox-core/src/toolbox_core/tool.py +++ b/packages/toolbox-core/src/toolbox_core/tool.py @@ -13,11 +13,14 @@ # limitations under the License. +import warnings from inspect import Parameter, Signature -from typing import Any +from typing import Any, Callable, TypeVar, Union from aiohttp import ClientSession +T = TypeVar("T", bound="ToolboxTool") + class ToolboxTool: """ @@ -25,7 +28,10 @@ class ToolboxTool: Instances of this class behave like asynchronous functions. When called, they send a request to the corresponding tool's endpoint on the Toolbox server with - the provided arguments. + the provided arguments, including any bound parameters. + + Methods like `bind_param` return *new* instances + with the added state, ensuring immutability of the original tool object. It utilizes Python's introspection features (`__name__`, `__doc__`, `__signature__`, `__annotations__`) so that standard tools like `help()` @@ -43,6 +49,7 @@ def __init__( name: str, desc: str, params: list[Parameter], + bound_params: Union[dict[str, Union[Any, Callable[[], Any]]], None] = None, ): """ Initializes a callable that will trigger the tool invocation through the Toolbox server. @@ -54,43 +61,252 @@ def __init__( desc: The description of the remote tool (used as its docstring). params: A list of `inspect.Parameter` objects defining the tool's arguments and their types/defaults. + bound_params: Pre-existing bound parameters. """ + self.__base_url = base_url # used to invoke the toolbox API self.__session = session self.__url = f"{base_url}/api/tool/{name}/invoke" + self.__original_params = params + + # Store bound params + self.__bound_params = bound_params or {} + + # Filter out bound parameters from the signature exposed to the user + visible_params = [p for p in params if p.name not in self.__bound_params] # the following properties are set to help anyone that might inspect it determine self.__name__ = name self.__doc__ = desc - self.__signature__ = Signature(parameters=params, return_annotation=str) - self.__annotations__ = {p.name: p.annotation for p in params} + # The signature only shows non-bound parameters + self.__signature__ = Signature(parameters=visible_params, return_annotation=str) + self.__annotations__ = {p.name: p.annotation for p in visible_params} # TODO: self.__qualname__ ?? + def _evaluate_param_vals( + self, params: dict[str, Union[Any, Callable[[], Any]]] + ) -> dict[str, Any]: + """ + Evaluate any callable parameter values. + + Iterates through the input dictionary, calling any callable values + to get their actual result. Non-callable values are kept as is. + + Args: + params: A dictionary where keys are parameter names + and values are either static values or callables returning a value. + + Returns: + A dictionary containing all parameter names with their resolved, static values. + + Raises: + RuntimeError: If evaluating a callable parameter value fails. + """ + resolved_parameters: dict[str, Any] = {} + for param_name, param_value in params.items(): + try: + resolved_parameters[param_name] = ( + param_value() if callable(param_value) else param_value + ) + except Exception as e: + raise RuntimeError( + f"Error evaluating parameter '{param_name}' for tool '{self.__name__}': {e}" + ) from e + return resolved_parameters + + def _prepare_arguments(self, *args: Any, **kwargs: Any) -> dict[str, Any]: + """ + Evaluates parameters, merges with call arguments, and binds them + to the tool's full signature. + + Args: + *args: Positional arguments provided at call time. + **kwargs: Keyword arguments provided at call time. + + Returns: + A dictionary of all arguments ready to be sent to the API. + + Raises: + TypeError: If call-time arguments conflict with bound parameters, + or if arguments don't match the tool's signature. + RuntimeError: If evaluating a bound parameter fails. + """ + # Check for conflicts between resolved bound params and keyword arguments + conflicts = self.__bound_params.keys() & kwargs.keys() + if conflicts: + raise TypeError( + f"Tool '{self.__name__}': Cannot provide value during call for already bound argument(s): {', '.join(conflicts)}" + ) + + evaluated_bound_params = self._evaluate_param_vals(self.__bound_params) + # Merge params with provided keyword arguments + merged_kwargs = {**evaluated_bound_params, **kwargs} + + # Bind *args and merged_kwargs using the *original* full signature + full_signature = Signature( + parameters=self.__original_params, return_annotation=str + ) + try: + bound_args = full_signature.bind(*args, **merged_kwargs) + except TypeError as e: + raise TypeError( + f"Argument binding error for tool '{self.__name__}' (check arguments against signature {full_signature} and bound params {list(self.__bound_params.keys())}): {e}" + ) from e + + # Apply default values for any missing arguments + bound_args.apply_defaults() + return bound_args.arguments + async def __call__(self, *args: Any, **kwargs: Any) -> str: """ - Asynchronously calls the remote tool with the provided arguments. + Asynchronously calls the remote tool with the provided arguments and bound parameters. - Validates arguments against the tool's signature, then sends them - as a JSON payload in a POST request to the tool's invoke URL. + Validates arguments against the tool's signature (excluding bound parameters), + then sends bound parameters and call arguments as a JSON payload in a POST request to the tool's invoke URL. Args: - *args: Positional arguments for the tool. - **kwargs: Keyword arguments for the tool. + *args: Positional arguments for the tool (for non-bound parameters). + **kwargs: Keyword arguments for the tool (for non-bound parameters). Returns: The string result returned by the remote tool execution. + + Raises: + TypeError: If a bound parameter conflicts with a parameter provided at call time. + Exception: If the remote tool call results in an error. """ - all_args = self.__signature__.bind(*args, **kwargs) - all_args.apply_defaults() # Include default values if not provided - payload = all_args.arguments + arguments_payload = self._prepare_arguments(*args, **kwargs) + # Make the API call async with self.__session.post( self.__url, - json=payload, + json=arguments_payload, ) as resp: - ret = await resp.json() - if "error" in ret: - # TODO: better error - raise Exception(ret["error"]) - return ret.get("result", ret) + try: + ret = await resp.json() + except Exception as e: + raise Exception( + f"Failed to decode JSON response from tool '{self.__name__}': {e}. Status: {resp.status}, Body: {await resp.text()}" + ) from e + + if resp.status >= 400 or "error" in ret: + error_detail = ret.get("error", ret) if isinstance(ret, dict) else ret + raise Exception( + f"Tool '{self.__name__}' invocation failed with status {resp.status}: {error_detail}" + ) + + # Handle cases where 'result' might be missing but no explicit error given + return ret.get( + "result", str(ret) + ) # Return string representation if 'result' key missing + + # --- Methods for adding state (return new instances) --- + def _copy_with_updates( + self: T, + *, + add_bound_params: Union[dict[str, Union[Any, Callable[[], Any]]], None] = None, + ) -> T: + """Creates a new instance with updated bound params.""" + new_bound_params = self.__bound_params.copy() + if add_bound_params: + new_bound_params.update(add_bound_params) + + return self.__class__( + session=self.__session, + base_url=self.__base_url, + name=self.__name__, + desc=self.__doc__ or "", + params=self.__original_params, + bound_params=new_bound_params, + ) + + def bind_params( + self: T, + params_to_bind: dict[str, Union[Any, Callable[[], Any]]], + strict: bool = True, + ) -> T: + """ + Returns a *new* tool instance with the provided parameters bound. + + Bound parameters are pre-filled values or callables that resolve to values + when the tool is called. They are not part of the signature of the + returned tool instance. + + Args: + params_to_bind: A dictionary mapping parameter names to their + values or callables that return the value. + strict: If True (default), raises ValueError if attempting to bind + a parameter that doesn't exist in the original tool signature + or is already bound in this instance. If False, issues a warning. + + Returns: + A new ToolboxTool instance with the specified parameters bound. + + Raises: + ValueError: If strict is True and a parameter name is invalid or + already bound. + """ + invalid_params: list[str] = [] + duplicate_params: list[str] = [] + original_param_names = {p.name for p in self.__original_params} + + for name in params_to_bind: + if name not in original_param_names: + invalid_params.append(name) + elif name in self.__bound_params: + duplicate_params.append(name) + + messages: list[str] = [] + if invalid_params: + messages.append( + f"Parameter(s) {', '.join(invalid_params)} do not exist in the signature for tool '{self.__name__}'." + ) + if duplicate_params: + messages.append( + f"Parameter(s) {', '.join(duplicate_params)} are already bound in this instance of tool '{self.__name__}'." + ) + + if messages: + message = "\n".join(messages) + if strict: + raise ValueError(message) + else: + warnings.warn(message) + # Filter out problematic params if not strict + params_to_bind = { + k: v + for k, v in params_to_bind.items() + if k not in invalid_params and k not in duplicate_params + } + + if not params_to_bind: + return self + + return self._copy_with_updates(add_bound_params=params_to_bind) + + def bind_param( + self: T, + param_name: str, + param_value: Union[Any, Callable[[], Any]], + strict: bool = True, + ) -> T: + """ + Returns a *new* tool instance with the provided parameter bound. + + Convenience method for binding a single parameter. + + Args: + param_name: The name of the parameter to bind. + param_value: The value or callable for the parameter. + strict: If True (default), raises ValueError if the parameter name + is invalid or already bound. If False, issues a warning. + + Returns: + A new ToolboxTool instance with the specified parameter bound. + + Raises: + ValueError: If strict is True and the parameter name is invalid or + already bound. + """ + return self.bind_params({param_name: param_value}, strict=strict) diff --git a/packages/toolbox-core/tests/test_tool.py b/packages/toolbox-core/tests/test_tool.py index 593f50fe..e2eccf0e 100644 --- a/packages/toolbox-core/tests/test_tool.py +++ b/packages/toolbox-core/tests/test_tool.py @@ -29,10 +29,8 @@ def mock_session(self) -> MagicMock: # Added self return session @pytest.fixture - def tool_details(self) -> dict: - base_url = "http://fake-toolbox.com" - tool_name = "test_tool" - params = [ + def tool_params(self) -> list[Parameter]: + return [ Parameter("arg1", Parameter.POSITIONAL_OR_KEYWORD, annotation=str), Parameter( "opt_arg", @@ -40,15 +38,31 @@ def tool_details(self) -> dict: default=123, annotation=Optional[int], ), + Parameter( + "req_kwarg", Parameter.KEYWORD_ONLY, annotation=bool + ), # Added back ] + + @pytest.fixture + def tool_details(self, tool_params: list[Parameter]) -> dict[str, Any]: + """Provides common details for constructing the test tool.""" + base_url = "http://fake-toolbox.com" + tool_name = "test_tool" + params = tool_params + full_signature = Signature(parameters=params, return_annotation=str) + public_signature = Signature(parameters=params, return_annotation=str) + full_annotations = {"arg1": str, "opt_arg": Optional[int], "req_kwarg": bool} + public_annotations = full_annotations.copy() + return { "base_url": base_url, "name": tool_name, "desc": "A tool for testing.", "params": params, - "signature": Signature(parameters=params, return_annotation=str), + "full_signature": full_signature, "expected_url": f"{base_url}/api/tool/{tool_name}/invoke", - "annotations": {"arg1": str, "opt_arg": Optional[int]}, + "public_signature": public_signature, + "public_annotations": public_annotations, } @pytest.fixture @@ -59,6 +73,7 @@ def tool(self, mock_session: MagicMock, tool_details: dict) -> ToolboxTool: name=tool_details["name"], desc=tool_details["desc"], params=tool_details["params"], + bound_params=None, ) @pytest.fixture @@ -81,9 +96,9 @@ async def test_initialization_and_introspection( assert tool.__name__ == tool_details["name"] assert tool.__doc__ == tool_details["desc"] assert tool._ToolboxTool__url == tool_details["expected_url"] - assert tool._ToolboxTool__session is tool._ToolboxTool__session - assert tool.__signature__ == tool_details["signature"] - assert tool.__annotations__ == tool_details["annotations"] + assert tool.__signature__ == tool_details["public_signature"] + assert tool.__annotations__ == tool_details["public_annotations"] + assert tool._ToolboxTool__bound_params == {} # assert hasattr(tool, "__qualname__") @pytest.mark.asyncio @@ -99,91 +114,119 @@ async def test_call_success( arg1_val = "test_value" opt_arg_val = 456 - result = await tool(arg1_val, opt_arg=opt_arg_val) + req_kwarg_val = True + result = await tool(arg1_val, opt_arg=opt_arg_val, req_kwarg=req_kwarg_val) assert result == expected_result mock_session.post.assert_called_once_with( tool_details["expected_url"], - json={"arg1": arg1_val, "opt_arg": opt_arg_val}, + json={"arg1": arg1_val, "opt_arg": opt_arg_val, "req_kwarg": req_kwarg_val}, ) mock_session.post.return_value.__aenter__.return_value.json.assert_awaited_once() @pytest.mark.asyncio - async def test_call_success_with_defaults( - self, - tool: ToolboxTool, - mock_session: MagicMock, - tool_details: dict, - configure_mock_response: Callable, + async def test_call_invalid_arguments_type_error( + self, tool: ToolboxTool, mock_session: MagicMock ): - expected_result = "Default success!" - configure_mock_response({"result": expected_result}) + with pytest.raises(TypeError): + await tool("val1", 2, 3) - arg1_val = "another_test" - default_opt_val = tool_details["params"][1].default - result = await tool(arg1_val) + with pytest.raises(TypeError): + await tool("val1", non_existent_arg="bad") - assert result == expected_result - mock_session.post.assert_called_once_with( - tool_details["expected_url"], - json={"arg1": arg1_val, "opt_arg": default_opt_val}, + with pytest.raises(TypeError): + await tool(opt_arg=500) + + mock_session.post.assert_not_called() + + # Bound Params tests + @pytest.fixture + def bound_arg1_value(self) -> str: + return "statically_bound_arg1" + + @pytest.fixture + def tool_with_bound_arg1( + self, + mock_session: MagicMock, + tool_details: dict[str, Any], + bound_arg1_value: str, + ) -> ToolboxTool: + bound_params = {"arg1": bound_arg1_value} + return ToolboxTool( + session=mock_session, + base_url=tool_details["base_url"], + name=tool_details["name"], + desc=tool_details["desc"], + params=tool_details["params"], # Use corrected params + bound_params=bound_params, ) - mock_session.post.return_value.__aenter__.return_value.json.assert_awaited_once() @pytest.mark.asyncio - async def test_call_api_error( + async def test_bound_parameter_static_value_call( self, - tool: ToolboxTool, + tool_with_bound_arg1: ToolboxTool, mock_session: MagicMock, - tool_details: dict, + tool_details: dict[str, Any], configure_mock_response: Callable, + bound_arg1_value: str, ): - error_message = "Tool execution failed on server" - configure_mock_response({"error": error_message}) - default_opt_val = tool_details["params"][1].default + """Test calling a tool with a statically bound parameter.""" + expected_result = "Bound call success!" + configure_mock_response(json_data={"result": expected_result}) - with pytest.raises(Exception) as exc_info: - await tool("some_arg") + opt_arg_val = 789 + req_kwarg_val = True # The only remaining required arg - assert str(exc_info.value) == error_message + # Call *without* providing arg1, but provide the others + result = await tool_with_bound_arg1( + opt_arg=opt_arg_val, req_kwarg=req_kwarg_val + ) + + assert result == expected_result mock_session.post.assert_called_once_with( tool_details["expected_url"], - json={"arg1": "some_arg", "opt_arg": default_opt_val}, + # Payload should include the bound value for arg1 + json={ + "arg1": bound_arg1_value, + "opt_arg": opt_arg_val, + "req_kwarg": req_kwarg_val, + }, ) mock_session.post.return_value.__aenter__.return_value.json.assert_awaited_once() + @pytest.fixture + def tool_with_bound_arg2(self, tool: ToolboxTool) -> ToolboxTool: + new_tool = tool.bind_params( + {"opt_arg": lambda: 88} + ) # Tried passing a string callable and it failed + return new_tool + @pytest.mark.asyncio - async def test_call_missing_result_key( + async def test_bound_parameter_static_value_call2( self, - tool: ToolboxTool, + tool_with_bound_arg2: ToolboxTool, mock_session: MagicMock, - tool_details: dict, + tool_details: dict[str, Any], configure_mock_response: Callable, + bound_arg1_value: str, ): - fallback_response = {"status": "completed", "details": "some info"} - configure_mock_response(fallback_response) - default_opt_val = tool_details["params"][1].default + """Test calling a tool with a statically bound parameter.""" + expected_result = "Bound call success!" + configure_mock_response(json_data={"result": expected_result}) + + req_kwarg_val = True # The only remaining required arg - result = await tool("value_for_arg1") + # Call *without* providing arg1, but provide the others + result = await tool_with_bound_arg2(arg1="random_val", req_kwarg=req_kwarg_val) - assert result == fallback_response + assert result == expected_result mock_session.post.assert_called_once_with( tool_details["expected_url"], - json={"arg1": "value_for_arg1", "opt_arg": default_opt_val}, + # Payload should include the bound value for arg1 + json={ + "arg1": "random_val", + "opt_arg": 88, + "req_kwarg": req_kwarg_val, + }, ) mock_session.post.return_value.__aenter__.return_value.json.assert_awaited_once() - - @pytest.mark.asyncio - async def test_call_invalid_arguments_type_error( - self, tool: ToolboxTool, mock_session: MagicMock - ): - with pytest.raises(TypeError): - await tool("val1", 2, 3) - - with pytest.raises(TypeError): - await tool("val1", non_existent_arg="bad") - - with pytest.raises(TypeError): - await tool(opt_arg=500) - - mock_session.post.assert_not_called()