diff --git a/src/toolbox_langchain/async_tools.py b/src/toolbox_langchain/async_tools.py index 0593213c..514b6f17 100644 --- a/src/toolbox_langchain/async_tools.py +++ b/src/toolbox_langchain/async_tools.py @@ -13,31 +13,24 @@ # limitations under the License. from copy import deepcopy -from typing import Any, Callable, TypeVar, Union +from typing import Any, Callable, TypeVar, Union, Dict, List, Tuple, Type from warnings import warn +import inspect +import asyncio from aiohttp import ClientSession -from langchain_core.tools import BaseTool +from pydantic import BaseModel, Field, create_model -from .utils import ( - ToolSchema, - _find_auth_params, - _find_bound_params, - _invoke_tool, - _schema_to_model, -) -T = TypeVar("T") - - -# This class is an internal implementation detail and is not exposed to the -# end-user. It should not be used directly by external code. Changes to this -# class will not be considered breaking changes to the public API. -class AsyncToolboxTool(BaseTool): - """ - A subclass of LangChain's BaseTool that supports features specific to - Toolbox, like bound parameters and authenticated tools. - """ +class AsyncToolboxTool(): + __name: str + __schema: ToolSchema + __model: Type[BaseModel] + __url: str + __session: ClientSession + __auth_tokens: Dict[str, Callable[[], str]] + __auth_params: List[ParameterSchema] + __bound_params: Dict[str, Union[Any, Callable[[], Any]]] def __init__( self, @@ -45,34 +38,10 @@ def __init__( schema: ToolSchema, url: str, session: ClientSession, - auth_tokens: dict[str, Callable[[], str]] = {}, - bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, + auth_tokens: Dict[str, Callable[[], str]] = {}, + bound_params: Dict[str, Union[Any, Callable[[], Any]]] = {}, strict: bool = True, ) -> None: - """ - Initializes an AsyncToolboxTool instance. - - Args: - name: The name of the tool. - schema: The tool schema. - url: The base URL of the Toolbox service. - session: The HTTP client session. - auth_tokens: A mapping of authentication source names to functions - that retrieve ID tokens. - bound_params: A mapping of parameter names to their bound - values. - strict: If True, raises a ValueError if any of the given bound - parameters are missing from the schema or require - authentication. If False, only issues a warning. - """ - - # If the schema is not already a ToolSchema instance, we create one from - # its attributes. This allows flexibility in how the schema is provided, - # accepting both a ToolSchema object and a dictionary of schema - # attributes. - if not isinstance(schema, ToolSchema): - schema = ToolSchema(**schema) - auth_params, non_auth_params = _find_auth_params(schema.parameters) non_auth_bound_params, non_auth_non_bound_params = _find_bound_params( non_auth_params, list(bound_params) @@ -80,8 +49,8 @@ def __init__( # Check if the user is trying to bind a param that is authenticated or # is missing from the given schema. - auth_bound_params: list[str] = [] - missing_bound_params: list[str] = [] + auth_bound_params: List[str] = [] + missing_bound_params: List[str] = [] for bound_param in bound_params: if bound_param in [param.name for param in auth_params]: auth_bound_params.append(bound_param) @@ -90,7 +59,7 @@ def __init__( # Create error messages for any params that are found to be # authenticated or missing. - messages: list[str] = [] + messages: List[str] = [] if auth_bound_params: messages.append( f"Parameter(s) {', '.join(auth_bound_params)} already authenticated and cannot be bound." @@ -110,7 +79,7 @@ def __init__( # Bind values for parameters present in the schema that don't require # authentication. - bound_params = { + _bound_params = { param_name: param_value for param_name, param_value in bound_params.items() if param_name in [param.name for param in non_auth_bound_params] @@ -118,92 +87,58 @@ def __init__( # Update the tools schema to validate only the presence of parameters # that neither require authentication nor are bound. - schema.parameters = non_auth_non_bound_params - - # Due to how pydantic works, we must initialize the underlying - # BaseTool class before assigning values to member variables. - super().__init__( - name=name, - description=schema.description, - args_schema=_schema_to_model(model_name=name, schema=schema.parameters), - ) + _updated_schema = deepcopy(schema) + _updated_schema.parameters = non_auth_non_bound_params self.__name = name - self.__schema = schema + self.__schema = _updated_schema + self.__model = _schema_to_model(self.__name, self.__schema.parameters) self.__url = url self.__session = session self.__auth_tokens = auth_tokens self.__auth_params = auth_params - self.__bound_params = bound_params + self.__bound_params = _bound_params # Warn users about any missing authentication so they can add it before # tool invocation. self.__validate_auth(strict=False) - def _run(self, **kwargs: Any) -> dict[str, Any]: - raise NotImplementedError("Synchronous methods not supported by async tools.") - - async def _arun(self, **kwargs: Any) -> dict[str, Any]: - """ - The coroutine that invokes the tool with the given arguments. - - Args: - **kwargs: The arguments to the tool. - - Returns: - A dictionary containing the parsed JSON response from the tool - invocation. - """ - - # If the tool had parameters that require authentication, then right - # before invoking that tool, we check whether all these required - # authentication sources have been registered or not. - self.__validate_auth() - - # Evaluate dynamic parameter values if any - evaluated_params = {} - for param_name, param_value in self.__bound_params.items(): - if callable(param_value): - evaluated_params[param_name] = param_value() - else: - evaluated_params[param_name] = param_value + # Make the tool instance directly callable. + docstring = schema_to_docstring(self.__schema, self.__bound_params) + + # Create a list to store parameter definitions for the function signature + sig_params = [] + for param in self.__schema.parameters: + # TODO: Change to _parse_type(param) post latest SDK release. + param_type = _parse_type(param.type) + sig_params.append( + inspect.Parameter( + param.name, inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=param_type + ) + ) - # Merge bound parameters with the provided arguments - kwargs.update(evaluated_params) + # Create the function signature + sig = inspect.Signature(parameters=sig_params, return_annotation=Dict[str, Any]) + self.__signature__ = sig + self.__doc__ = docstring + self.__name__ = self.__name - return await _invoke_tool( - self.__url, self.__session, self.__name, kwargs, self.__auth_tokens - ) def __validate_auth(self, strict: bool = True) -> None: - """ - Checks if a tool meets the authentication requirements. - - A tool is considered authenticated if all of its parameters meet at - least one of the following conditions: - - * The parameter has at least one registered authentication source. - * The parameter requires no authentication. - - Args: - strict: If True, raises a PermissionError if any required - authentication sources are not registered. If False, only issues - a warning. - - Raises: - PermissionError: If strict is True and any required authentication - sources are not registered. - """ - is_authenticated: bool = not self.__schema.authRequired - params_missing_auth: list[str] = [] - - # Check tool for at least 1 required auth source - for src in self.__schema.authRequired: - if src in self.__auth_tokens: - is_authenticated = True - break + # TODO: Add this once we release the latest SDK code. + # is_authenticated: bool = not self.__schema.authRequired + # # Check tool for at least 1 required auth source + # for src in self.__schema.authRequired: + # if src in self.__auth_tokens: + # is_authenticated = True + # break + # if not is_authenticated: + # messages.append( + # f"Tool {self.__name} requires authentication, but no valid authentication sources are registered. Please register the required sources before use." + # ) # Check each parameter for at least 1 required auth source + params_missing_auth: List[str] = [] for param in self.__auth_params: if not param.authSources: raise ValueError("Auth sources cannot be None.") @@ -217,12 +152,7 @@ def __validate_auth(self, strict: bool = True) -> None: if not has_auth: params_missing_auth.append(param.name) - messages: list[str] = [] - - if not is_authenticated: - messages.append( - f"Tool {self.__name} requires authentication, but no valid authentication sources are registered. Please register the required sources before use." - ) + messages: List[str] = [] if params_missing_auth: messages.append( @@ -238,35 +168,11 @@ def __validate_auth(self, strict: bool = True) -> None: def __create_copy( self, *, - auth_tokens: dict[str, Callable[[], str]] = {}, - bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, + auth_tokens: Dict[str, Callable[[], str]] = {}, + bound_params: Dict[str, Union[Any, Callable[[], Any]]] = {}, strict: bool, ) -> "AsyncToolboxTool": - """ - Creates a copy of the current AsyncToolboxTool instance, allowing for - modification of auth tokens and bound params. - - This method enables the creation of new tool instances with inherited - properties from the current instance, while optionally updating the auth - tokens and bound params. This is useful for creating variations of the - tool with additional auth tokens or bound params without modifying the - original instance, ensuring immutability. - - Args: - auth_tokens: A dictionary of auth source names to functions that - retrieve ID tokens. These tokens will be merged with the - existing auth tokens. - bound_params: A dictionary of parameter names to their - bound values or functions to retrieve the values. These params - will be merged with the existing bound params. - strict: If True, raises a ValueError if any of the given bound - parameters are missing from the schema or require - authentication. If False, only issues a warning. - - Returns: - A new AsyncToolboxTool instance that is a deep copy of the current - instance, with added auth tokens or bound params. - """ + new_schema = deepcopy(self.__schema) # Reconstruct the complete parameter schema by merging the auth @@ -287,30 +193,11 @@ def __create_copy( ) def add_auth_tokens( - self, auth_tokens: dict[str, Callable[[], str]], strict: bool = True + self, auth_tokens: Dict[str, Callable[[], str]], strict: bool = True ) -> "AsyncToolboxTool": - """ - Registers functions to retrieve ID tokens for the corresponding - authentication sources. - - Args: - auth_tokens: A dictionary of authentication source names to the - functions that return corresponding ID token. - strict: If True, a ValueError is raised if any of the provided auth - tokens are already bound. If False, only a warning is issued. - - Returns: - A new AsyncToolboxTool instance that is a deep copy of the current - instance, with added auth tokens. - - Raises: - ValueError: If the provided auth tokens are already registered. - ValueError: If the provided auth tokens are already bound and strict - is True. - """ # Check if the authentication source is already registered. - dupe_tokens: list[str] = [] + dupe_tokens: List[str] = [] for auth_token, _ in auth_tokens.items(): if auth_token in self.__auth_tokens: dupe_tokens.append(auth_token) @@ -325,55 +212,16 @@ def add_auth_tokens( def add_auth_token( self, auth_source: str, get_id_token: Callable[[], str], strict: bool = True ) -> "AsyncToolboxTool": - """ - Registers a function to retrieve an ID token for a given authentication - source. - - Args: - auth_source: The name of the authentication source. - get_id_token: A function that returns the ID token. - strict: If True, a ValueError is raised if any of the provided auth - token is already bound. If False, only a warning is issued. - - Returns: - A new ToolboxTool instance that is a deep copy of the current - instance, with added auth token. - - Raises: - ValueError: If the provided auth token is already registered. - ValueError: If the provided auth token is already bound and strict - is True. - """ return self.add_auth_tokens({auth_source: get_id_token}, strict=strict) def bind_params( self, - bound_params: dict[str, Union[Any, Callable[[], Any]]], + bound_params: Dict[str, Union[Any, Callable[[], Any]]], strict: bool = True, ) -> "AsyncToolboxTool": - """ - Registers values or functions to retrieve the value for the - corresponding bound parameters. - - Args: - bound_params: A dictionary of the bound parameter name to the - value or function of the bound value. - strict: If True, a ValueError is raised if any of the provided bound - params are not defined in the tool's schema, or require - authentication. If False, only a warning is issued. - - Returns: - A new AsyncToolboxTool instance that is a deep copy of the current - instance, with added bound params. - - Raises: - ValueError: If the provided bound params are already bound. - ValueError: if the provided bound params are not defined in the tool's schema, or require - authentication, and strict is True. - """ # Check if the parameter is already bound. - dupe_params: list[str] = [] + dupe_params: List[str] = [] for param_name, _ in bound_params.items(): if param_name in self.__bound_params: dupe_params.append(param_name) @@ -391,25 +239,28 @@ def bind_param( param_value: Union[Any, Callable[[], Any]], strict: bool = True, ) -> "AsyncToolboxTool": - """ - Registers a value or a function to retrieve the value for a given bound - parameter. - - Args: - param_name: The name of the bound parameter. - param_value: The value of the bound parameter, or a callable that - returns the value. - strict: If True, a ValueError is raised if any of the provided bound - params is not defined in the tool's schema, or requires - authentication. If False, only a warning is issued. - - Returns: - A new ToolboxTool instance that is a deep copy of the current - instance, with added bound param. - - Raises: - ValueError: If the provided bound param is already bound. - ValueError: if the provided bound param is not defined in the tool's - schema, or requires authentication, and strict is True. - """ - return self.bind_params({param_name: param_value}, strict) + return self.bind_params({param_name: param_value}, strict=strict) + + async def __call__(self, *args, **kwargs) -> Any: + call_args = self.__signature__.bind(*args, **kwargs).arguments + self.__model.model_validate(call_args) + + # If the tool had parameters that require authentication, then right + # before invoking that tool, we check whether all these required + # authentication sources have been registered or not. + self.__validate_auth() + + # Evaluate dynamic parameter values if any + evaluated_params = {} + for param_name, param_value in self.__bound_params.items(): + if callable(param_value): + evaluated_params[param_name] = param_value() + else: + evaluated_params[param_name] = param_value + + # Merge bound parameters with the provided arguments + call_args.update(evaluated_params) + + return await _invoke_tool( + self.__url, self.__session, self.__name, call_args, self.__auth_tokens + )