diff --git a/src/toolbox_langchain/async_tools.py b/src/toolbox_langchain/async_tools.py index 0593213c..faf011c1 100644 --- a/src/toolbox_langchain/async_tools.py +++ b/src/toolbox_langchain/async_tools.py @@ -12,31 +12,31 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect from copy import deepcopy -from typing import Any, Callable, TypeVar, Union +from typing import Any, Callable, Union from warnings import warn from aiohttp import ClientSession -from langchain_core.tools import BaseTool from .utils import ( ToolSchema, _find_auth_params, _find_bound_params, _invoke_tool, + _parse_type, + _schema_to_docstring, _schema_to_model, ) -T = TypeVar("T") - # This class is an internal implementation detail and is not exposed to the # end-user. It should not be used directly by external code. Changes to this # class will not be considered breaking changes to the public API. -class AsyncToolboxTool(BaseTool): +class AsyncToolboxTool: """ - A subclass of LangChain's BaseTool that supports features specific to - Toolbox, like bound parameters and authenticated tools. + A class that supports features specific to Toolbox, like bound parameters + and authenticated tools. """ def __init__( @@ -110,7 +110,7 @@ def __init__( # Bind values for parameters present in the schema that don't require # authentication. - bound_params = { + __bound_params = { param_name: param_value for param_name, param_value in bound_params.items() if param_name in [param.name for param in non_auth_bound_params] @@ -118,43 +118,61 @@ def __init__( # Update the tools schema to validate only the presence of parameters # that neither require authentication nor are bound. - schema.parameters = non_auth_non_bound_params - - # Due to how pydantic works, we must initialize the underlying - # BaseTool class before assigning values to member variables. - super().__init__( - name=name, - description=schema.description, - args_schema=_schema_to_model(model_name=name, schema=schema.parameters), - ) + __updated_schema = deepcopy(schema) + __updated_schema.parameters = non_auth_non_bound_params - self.__name = name - self.__schema = schema + self.__schema = __updated_schema + self.__model = _schema_to_model(name, self.__schema.parameters) self.__url = url self.__session = session self.__auth_tokens = auth_tokens self.__auth_params = auth_params - self.__bound_params = bound_params + self.__bound_params = __bound_params # Warn users about any missing authentication so they can add it before # tool invocation. self.__validate_auth(strict=False) - def _run(self, **kwargs: Any) -> dict[str, Any]: - raise NotImplementedError("Synchronous methods not supported by async tools.") + # Store parameter definitions for the function signature and annotations + sig_params = [] + annotations = {} + for param in self.__schema.parameters: + param_type = _parse_type(param) + annotations[param.name] = param_type + sig_params.append( + inspect.Parameter( + param.name, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + annotation=param_type, + ) + ) + + # Set function name, docstring, signature and annotations + self.__name__ = name + self.__qualname__ = name + self.__doc__ = _schema_to_docstring(self.__schema) + self.__signature__ = inspect.Signature( + parameters=sig_params, return_annotation=dict[str, Any] + ) + self.__annotations__ = annotations - async def _arun(self, **kwargs: Any) -> dict[str, Any]: + async def __call__(self, *args: Any, **kwargs: Any) -> dict[str, Any]: """ The coroutine that invokes the tool with the given arguments. Args: - **kwargs: The arguments to the tool. + **args: The positional arguments to the tool. + **kwargs: The keyword arguments to the tool. Returns: A dictionary containing the parsed JSON response from the tool invocation. """ + # Validate arguments + validated_args = self.__signature__.bind(*args, **kwargs).arguments + self.__model.model_validate(validated_args) + # If the tool had parameters that require authentication, then right # before invoking that tool, we check whether all these required # authentication sources have been registered or not. @@ -169,10 +187,14 @@ async def _arun(self, **kwargs: Any) -> dict[str, Any]: evaluated_params[param_name] = param_value # Merge bound parameters with the provided arguments - kwargs.update(evaluated_params) + validated_args.update(evaluated_params) return await _invoke_tool( - self.__url, self.__session, self.__name, kwargs, self.__auth_tokens + self.__url, + self.__session, + self.__name__, + validated_args, + self.__auth_tokens, ) def __validate_auth(self, strict: bool = True) -> None: @@ -221,12 +243,12 @@ def __validate_auth(self, strict: bool = True) -> None: if not is_authenticated: messages.append( - f"Tool {self.__name} requires authentication, but no valid authentication sources are registered. Please register the required sources before use." + f"Tool {self.__name__} requires authentication, but no valid authentication sources are registered. Please register the required sources before use." ) if params_missing_auth: messages.append( - f"Parameter(s) `{', '.join(params_missing_auth)}` of tool {self.__name} require authentication, but no valid authentication sources are registered. Please register the required sources before use." + f"Parameter(s) `{', '.join(params_missing_auth)}` of tool {self.__name__} require authentication, but no valid authentication sources are registered. Please register the required sources before use." ) if messages: @@ -277,7 +299,7 @@ def __create_copy( # as errors or warnings, depending on the given `strict` flag. new_schema.parameters += self.__auth_params return AsyncToolboxTool( - name=self.__name, + name=self.__name__, schema=new_schema, url=self.__url, session=self.__session, @@ -317,7 +339,7 @@ def add_auth_tokens( if dupe_tokens: raise ValueError( - f"Authentication source(s) `{', '.join(dupe_tokens)}` already registered in tool `{self.__name}`." + f"Authentication source(s) `{', '.join(dupe_tokens)}` already registered in tool `{self.__name__}`." ) return self.__create_copy(auth_tokens=auth_tokens, strict=strict) @@ -380,7 +402,7 @@ def bind_params( if dupe_params: raise ValueError( - f"Parameter(s) `{', '.join(dupe_params)}` already bound in tool `{self.__name}`." + f"Parameter(s) `{', '.join(dupe_params)}` already bound in tool `{self.__name__}`." ) return self.__create_copy(bound_params=bound_params, strict=strict) diff --git a/src/toolbox_langchain/tools.py b/src/toolbox_langchain/tools.py index f19b3d61..8705823e 100644 --- a/src/toolbox_langchain/tools.py +++ b/src/toolbox_langchain/tools.py @@ -17,17 +17,15 @@ from threading import Thread from typing import Any, Awaitable, Callable, TypeVar, Union -from langchain_core.tools import BaseTool - from .async_tools import AsyncToolboxTool T = TypeVar("T") -class ToolboxTool(BaseTool): +class ToolboxTool: """ - A subclass of LangChain's BaseTool that supports features specific to - Toolbox, like bound parameters and authenticated tools. + A class that supports features specific to Toolbox, like bound parameters + and authenticated tools. """ def __init__( @@ -45,14 +43,6 @@ def __init__( thread: The thread to run blocking operations in. """ - # Due to how pydantic works, we must initialize the underlying - # BaseTool class before assigning values to member variables. - super().__init__( - name=async_tool.name, - description=async_tool.description, - args_schema=async_tool.args_schema, - ) - self.__async_tool = async_tool self.__loop = loop self.__thread = thread @@ -77,11 +67,8 @@ async def __run_as_async(self, coro: Awaitable[T]) -> T: asyncio.run_coroutine_threadsafe(coro, self.__loop) ) - def _run(self, **kwargs: Any) -> dict[str, Any]: - return self.__run_as_sync(self.__async_tool._arun(**kwargs)) - - async def _arun(self, **kwargs: Any) -> dict[str, Any]: - return await self.__run_as_async(self.__async_tool._arun(**kwargs)) + def __call__(self, *args: Any, **kwargs: Any) -> dict[str, Any]: + return self.__run_as_sync(self.__async_tool(*args, **kwargs)) def add_auth_tokens( self, auth_tokens: dict[str, Callable[[], str]], strict: bool = True diff --git a/src/toolbox_langchain/utils.py b/src/toolbox_langchain/utils.py index 985c7bfe..299b630b 100644 --- a/src/toolbox_langchain/utils.py +++ b/src/toolbox_langchain/utils.py @@ -266,3 +266,29 @@ def _find_bound_params( _non_bound_params.append(param) return (_bound_params, _non_bound_params) + + +def _schema_to_docstring(tool_schema: ToolSchema) -> str: + """Generates a Google Style docstring from a ToolSchema object. + + If the schema has parameters, the docstring includes an 'Args:' section + detailing each parameter's name, type, and description. If no parameters are + present, only the tool's description is returned. + + Args: + tool_schema: The schema object defining the tool's interface, + including its description and parameters. + + Returns: + str: A Google Style formatted docstring. + """ + + if not tool_schema.parameters: + return tool_schema.description + + docstring = f"{tool_schema.description}\n\nArgs:" + for param in tool_schema.parameters: + docstring += ( + f"\n {param.name} ({_parse_type(param).__name__}): {param.description}" + ) + return docstring