diff --git a/packages/toolbox-core/src/toolbox_core/client.py b/packages/toolbox-core/src/toolbox_core/client.py index a534e706..cc4a6256 100644 --- a/packages/toolbox-core/src/toolbox_core/client.py +++ b/packages/toolbox-core/src/toolbox_core/client.py @@ -13,13 +13,12 @@ # limitations under the License. -import types -from typing import Any, Callable, Mapping, Optional, Union +from typing import Any, Callable, Mapping, Optional, Sequence, Union from aiohttp import ClientSession -from .protocol import ManifestSchema, ToolSchema -from .tool import ToolboxTool, identify_required_authn_params +from .protocol import ManifestSchema, ParameterSchema, ToolSchema +from .tool import ToolboxTool class ToolboxClient: @@ -59,24 +58,26 @@ def __parse_tool( self, name: str, schema: ToolSchema, - auth_token_getters: dict[str, Callable[[], str]], + auth_token_getters: Mapping[str, Callable[[], str]], all_bound_params: Mapping[str, Union[Callable[[], Any], Any]], + strict: bool, ) -> ToolboxTool: - """Internal helper to create a callable tool from its schema.""" - # sort into reg, authn, and bound params - params = [] - authn_params: dict[str, list[str]] = {} - bound_params: dict[str, Callable[[], str]] = {} - for p in schema.parameters: - if p.authSources: # authn parameter - authn_params[p.name] = p.authSources - elif p.name in all_bound_params: # bound parameter - bound_params[p.name] = all_bound_params[p.name] - else: # regular parameter - params.append(p) - - authn_params = identify_required_authn_params( - authn_params, auth_token_getters.keys() + """ + Internal helper to create a callable ToolboxTool from its schema. + + Args: + name: The name of the tool. + schema: The ToolSchema defining the tool. + auth_token_getters: Mapping of auth service names to token getters. + all_bound_params: Mapping of all initially bound parameter names to values/callables. + strict: The strictness setting for the created ToolboxTool instance. + + Returns: + An initialized ToolboxTool instance. + """ + + params: Sequence[ParameterSchema] = ( + schema.parameters if schema.parameters is not None else [] ) tool = ToolboxTool( @@ -85,10 +86,9 @@ def __parse_tool( name=name, description=schema.description, params=params, - # create a read-only values for the maps to prevent mutation - required_authn_params=types.MappingProxyType(authn_params), - auth_service_token_getters=types.MappingProxyType(auth_token_getters), - bound_params=types.MappingProxyType(bound_params), + auth_service_token_getters=auth_token_getters, + bound_params=all_bound_params, + strict=strict, ) return tool @@ -127,8 +127,9 @@ async def close(self): async def load_tool( self, name: str, - auth_token_getters: dict[str, Callable[[], str]] = {}, + auth_token_getters: Mapping[str, Callable[[], str]] = {}, bound_params: Mapping[str, Union[Callable[[], Any], Any]] = {}, + strict: bool = True, ) -> ToolboxTool: """ Asynchronously loads a tool from the server. @@ -143,8 +144,8 @@ async def load_tool( callables that return the corresponding authentication token. bound_params: A mapping of parameter names to bind to specific values or callables that are called to produce values as needed. - - + strict: If True (default), the loaded tool instance will operate in + strict validation mode. If False, it will be non-strict. Returns: ToolboxTool: A callable object representing the loaded tool, ready @@ -161,10 +162,11 @@ async def load_tool( # parse the provided definition to a tool if name not in manifest.tools: - # TODO: Better exception - raise Exception(f"Tool '{name}' not found!") + raise Exception( + f"Tool '{name}' not found in the manifest received from {url}" + ) tool = self.__parse_tool( - name, manifest.tools[name], auth_token_getters, bound_params + name, manifest.tools[name], auth_token_getters, bound_params, strict ) return tool @@ -172,24 +174,26 @@ async def load_tool( async def load_toolset( self, name: Optional[str] = None, - auth_token_getters: dict[str, Callable[[], str]] = {}, + auth_token_getters: Mapping[str, Callable[[], str]] = {}, bound_params: Mapping[str, Union[Callable[[], Any], Any]] = {}, + strict: bool = True, ) -> list[ToolboxTool]: """ Asynchronously fetches a toolset and loads all tools defined within it. Args: - name: Name of the toolset to load tools. + name: Optional name of the toolset to load. If None, attempts to load + the default toolset. auth_token_getters: A mapping of authentication service names to callables that return the corresponding authentication token. bound_params: A mapping of parameter names to bind to specific values or callables that are called to produce values as needed. - - + strict: If True (default), all loaded tool instances will operate in + strict validation mode. If False, they will be non-strict. Returns: - list[ToolboxTool]: A list of callables, one for each tool defined - in the toolset. + list[ToolboxTool]: A list of callables, one for each tool defined in + the toolset. """ # Request the definition of the tool from the server url = f"{self.__base_url}/api/toolset/{name or ''}" @@ -199,7 +203,7 @@ async def load_toolset( # parse each tools name and schema into a list of ToolboxTools tools = [ - self.__parse_tool(n, s, auth_token_getters, bound_params) + self.__parse_tool(n, s, auth_token_getters, bound_params, strict) for n, s in manifest.tools.items() ] return tools diff --git a/packages/toolbox-core/src/toolbox_core/tool.py b/packages/toolbox-core/src/toolbox_core/tool.py index b9f5b8df..999962eb 100644 --- a/packages/toolbox-core/src/toolbox_core/tool.py +++ b/packages/toolbox-core/src/toolbox_core/tool.py @@ -14,17 +14,13 @@ import types +from copy import deepcopy from inspect import Signature -from typing import ( - Any, - Callable, - Mapping, - Optional, - Sequence, - Union, -) +from typing import Any, Callable, Mapping, Optional, Sequence, Union +from warnings import warn from aiohttp import ClientSession +from pydantic import ValidationError from toolbox_core.protocol import ParameterSchema @@ -56,9 +52,11 @@ def __init__( name: str, description: str, params: Sequence[ParameterSchema], - required_authn_params: Mapping[str, list[str]], auth_service_token_getters: Mapping[str, Callable[[], str]], bound_params: Mapping[str, Union[Callable[[], Any], Any]], + strict: bool = True, + __original_params: Optional[Sequence[ParameterSchema]] = None, + __original_required_authn_params: Optional[Mapping[str, list[str]]] = None, ): """ Initializes a callable that will trigger the tool invocation through the @@ -69,40 +67,115 @@ def __init__( base_url: The base URL of the Toolbox server API. name: The name of the remote tool. description: The description of the remote tool. - params: The args of the tool. - required_authn_params: A dict of required authenticated parameters to a list - of services that provide values for them. + params: The *complete* original parameter list for the tool. auth_service_token_getters: A dict of authService -> token (or callables that produce a token) bound_params: A mapping of parameter names to bind to specific values or callables that are called to produce values as needed. + strict: If True (default), raises ValueError during initialization or + binding if parameters are missing, already bound, or require + authentication. If False, issues a warning for auth conflicts + instead (missing/duplicate bindings still raise errors). """ - # used to invoke the toolbox API self.__session: ClientSession = session self.__base_url: str = base_url - self.__url = f"{base_url}/api/tool/{name}/invoke" + self.__name__ = name self.__description = description - self.__params = params - self.__pydantic_model = params_to_pydantic_model(name, self.__params) + self.__strict = strict + + self.__original_params = deepcopy( + __original_params if __original_params is not None else params + ) + self.__original_required_authn_params = ( + __original_required_authn_params + if __original_required_authn_params is not None + else identify_required_authn_params(self.__original_params) + ) - inspect_type_params = [param.to_param() for param in self.__params] + # Validate initial bound_params against original schema before setting state + self._validate_binding(bound_params, check_already_bound=False) - # the following properties are set to help anyone that might inspect it determine usage - self.__name__ = name - self.__doc__ = create_func_docstring(self.__description, self.__params) + # Initialize internal state based on current bindings + self.__auth_service_token_getters = types.MappingProxyType( + dict(auth_service_token_getters) + ) + self.__bound_parameters = types.MappingProxyType( + dict(bound_params) + ) + + # Filter original params to get current (unbound) params + self.__params = tuple( + p for p in self.__original_params if p.name not in self.__bound_parameters + ) + + # Setup for invocation and introspection based on *current* params + self.__url = f"{self.__base_url}/api/tool/{self.__name__}/invoke" + self.__pydantic_model = params_to_pydantic_model(self.__name__, self.__params) + + inspect_type_params = [ + param.to_param() for param in self.__params + ] + + self.__doc__ = create_func_docstring( + self.__description, self.__params + ) self.__signature__ = Signature( parameters=inspect_type_params, return_annotation=str ) - - self.__annotations__ = {p.name: p.annotation for p in inspect_type_params} + self.__annotations__ = { + p.name: p.annotation for p in inspect_type_params + } self.__qualname__ = f"{self.__class__.__qualname__}.{self.__name__}" - # map of parameter name to auth service required by it - self.__required_authn_params = required_authn_params - # map of authService -> token_getter - self.__auth_service_token_getters = auth_service_token_getters - # map of parameter name to value (or callable that produces that value) - self.__bound_parameters = bound_params + + def _validate_binding( + self, + params_to_validate: Mapping[str, Union[Callable[[], Any], Any]], + check_already_bound: bool = True, + ) -> None: + """ + Validates parameters intended for binding against the original schema + and authentication requirements, respecting the instance's strict mode. + """ + auth_bound_params: list[str] = [] + missing_bound_params: list[str] = [] + already_bound_params: list[str] = [] + + original_param_names = {p.name for p in self.__original_params} + + for param_name in params_to_validate: + # Check if already bound (if requested) + if check_already_bound and param_name in self.__bound_parameters: + already_bound_params.append(param_name) + continue + + # Check if missing from original schema + if param_name not in original_param_names: + missing_bound_params.append(param_name) + continue + + # Check if requires authentication + if param_name in self.__original_required_authn_params: + auth_bound_params.append(param_name) + + if already_bound_params: + raise ValueError( + f"Parameter(s) `{', '.join(already_bound_params)}` already bound in tool `{self.__name__}`." + ) + + messages: list[str] = [] + if missing_bound_params: + messages.append( + f"Parameter(s) `{', '.join(missing_bound_params)}` not found in tool schema and cannot be bound." + ) + raise ValueError("\n".join(messages)) + + # Check auth conflicts separately + if auth_bound_params: + auth_message = f"Parameter(s) `{', '.join(auth_bound_params)}` require authentication and cannot be bound." + if self.__strict: + raise ValueError(auth_message) + warn(auth_message) def __copy( self, @@ -111,9 +184,11 @@ def __copy( name: Optional[str] = None, description: Optional[str] = None, params: Optional[Sequence[ParameterSchema]] = None, - required_authn_params: Optional[Mapping[str, list[str]]] = None, auth_service_token_getters: Optional[Mapping[str, Callable[[], str]]] = None, bound_params: Optional[Mapping[str, Union[Callable[[], Any], Any]]] = None, + strict: Optional[bool] = None, + original_params: Optional[Sequence[ParameterSchema]] = None, + original_required_authn_params: Optional[Mapping[str, list[str]]] = None, ) -> "ToolboxTool": """ Creates a copy of the ToolboxTool, overriding specific fields. @@ -124,101 +199,199 @@ def __copy( name: The name of the remote tool. description: The description of the remote tool. params: The args of the tool. - required_authn_params: A dict of required authenticated parameters that need - an auth_service_token_getter set for them yet. auth_service_token_getters: A dict of authService -> token (or callables that produce a token) bound_params: A mapping of parameter names to bind to specific values or callables that are called to produce values as needed. - + strict: The strictness setting of the tool. + original_params: + original_required_authn_params: """ check = lambda val, default: val if val is not None else default + + # Ensure original state and strictness are passed correctly using current values as default + new_strict = check(strict, self.__strict) + new_original_params = check( + original_params, self.__original_params + ) + new_original_required_authn_params = check( + original_required_authn_params, self.__original_required_authn_params + ) + + # The 'params' arg here should be the *new* set of *current* (unbound) parameters + # determined by the calling method (e.g., bind_parameters derives this) + current_params = check( + params, self.__params + ) # This holds the filtered list for the new instance + + # Use current values as defaults for other potentially changed state + new_auth_getters = check( + auth_service_token_getters, self.__auth_service_token_getters + ) + new_bound_params = check(bound_params, self.__bound_parameters) + + # Re-call constructor. Note: This will re-run validation in __init__ if applicable, + # but _validate_binding should handle it correctly based on the passed bound_params. + # Pass the original state explicitly using the internal args. return ToolboxTool( session=check(session, self.__session), base_url=check(base_url, self.__base_url), name=check(name, self.__name__), description=check(description, self.__description), - params=check(params, self.__params), - required_authn_params=check( - required_authn_params, self.__required_authn_params - ), - auth_service_token_getters=check( - auth_service_token_getters, self.__auth_service_token_getters - ), - bound_params=check(bound_params, self.__bound_parameters), + params=current_params, + auth_service_token_getters=new_auth_getters, + bound_params=new_bound_params, + strict=new_strict, + __original_params=new_original_params, + __original_required_authn_params=new_original_required_authn_params, ) - async def __call__(self, *args: Any, **kwargs: Any) -> str: + def _check_invocation_auth(self) -> None: """ - Asynchronously calls the remote tool with the provided arguments. - - Validates arguments against the tool's signature, then sends them - as a JSON payload in a POST request to the tool's invoke URL. - - Args: - *args: Positional arguments for the tool. - **kwargs: Keyword arguments for the tool. - - Returns: - The string result returned by the remote tool execution. + Verifies that all parameters requiring authentication have a registered + token getter before invocation. Internal helper for __call__. """ + missing_auth_params: dict[str, list[str]] = ( + {} + ) - # check if any auth services need to be specified yet - if len(self.__required_authn_params) > 0: - # Gather all the required auth services into a set - req_auth_services = set() - for s in self.__required_authn_params.values(): - req_auth_services.update(s) - raise Exception( - f"One or more of the following authn services are required to invoke this tool: {','.join(req_auth_services)}" + # Check against original requirements for parameters *not currently bound* + for ( + param_name, + required_sources, + ) in self.__original_required_authn_params.items(): + if param_name not in self.__bound_parameters: + has_auth = False + if required_sources: + for source in required_sources: + if source in self.__auth_service_token_getters: + has_auth = True + break + if not has_auth: + missing_auth_params[param_name] = required_sources + + if missing_auth_params: + param_details = [ + f"'{name}' (requires one of: {', '.join(srcs)})" + for name, srcs in missing_auth_params.items() + ] + available_sources = list(self.__auth_service_token_getters.keys()) + raise PermissionError( + f"Tool '{self.__name__}' requires authentication for parameter(s) " + f"{', '.join(param_details)} which is not configured. " + f"Available authentication sources: {available_sources or 'None'}." ) + # TODO: Add check for tool-level auth here (ie. authRequired). - # validate inputs to this call using the signature - all_args = self.__signature__.bind(*args, **kwargs) - all_args.apply_defaults() # Include default values if not provided - payload = all_args.arguments - - # Perform argument type validations using pydantic - self.__pydantic_model.model_validate(payload) - - # apply bounded parameters + async def __call__(self, *args: Any, **kwargs: Any) -> str: + """ + Asynchronously calls the remote tool with the provided arguments. + """ + # 1. Check if all required authentications are satisfied for unbound parameters + self._check_invocation_auth() + + # 2. Bind provided arguments to signature (for current/unbound params) + try: + bound_call_args = self.__signature__.bind(*args, **kwargs) + bound_call_args.apply_defaults() + payload = bound_call_args.arguments + except TypeError as e: + raise TypeError(f"Argument mismatch for tool '{self.__name__}': {e}") from e + + # 3. Validate argument types using pydantic model (for current/unbound params) + try: + # Pydantic model validation ensures correct types for unbound args + validated_payload = self.__pydantic_model.model_validate(payload) + # Use validated data (handles defaults, conversions, etc.) + payload_for_api = validated_payload.model_dump() + except ValidationError as e: + raise ValidationError( + f"Invalid argument types for tool '{self.__name__}':\n{e}" + ) from e + + # 4. Apply statically bound parameters (resolve callables) + resolved_bound_params: dict[str, Any] = {} for param, value in self.__bound_parameters.items(): - payload[param] = await resolve_value(value) + resolved_bound_params[param] = await resolve_value(value) - # create headers for auth services - headers = {} - for auth_service, token_getter in self.__auth_service_token_getters.items(): - headers[f"{auth_service}_token"] = await resolve_value(token_getter) + # 5. Merge provided validated args with resolved bound args + # Bound parameters take precedence if somehow passed in kwargs as well (shouldn't happen via signature) + final_payload = {**payload_for_api, **resolved_bound_params} + # 6. Create headers for auth services + headers: dict[str, str] = {} + for auth_service, token_getter in self.__auth_service_token_getters.items(): + # Include all registered tokens. Server side might ignore unused ones. + try: + token = await resolve_value(token_getter) + if not isinstance(token, str): + warn( + f"Token getter for auth service '{auth_service}' did not return a string.", + UserWarning, + ) + token = str(token) # Attempt conversion + headers[f"{auth_service}_token"] = token + except Exception as e: + # Fail invocation if a token getter fails + raise RuntimeError( + f"Failed to retrieve token for auth service '{auth_service}': {e}" + ) from e + + # 7. Make the API call async with self.__session.post( self.__url, - json=payload, + json=final_payload, headers=headers, + # Consider adding timeout? + # timeout=aiohttp.ClientTimeout(total=...) ) as resp: - body = await resp.json() - if resp.status < 200 or resp.status >= 300: - err = body.get("error", f"unexpected status from server: {resp.status}") - raise Exception(err) - return body.get("result", body) + try: + # Check content type before assuming JSON + content_type = resp.headers.get("Content-Type", "") + if "application/json" in content_type: + body = await resp.json() + else: + # Handle non-JSON response as text + text_body = await resp.text() + # Log or handle text body appropriately + # We still need to check status code below + body = { + "error": f"Non-JSON response received (Content-Type: {content_type}). Body: {text_body[:500]}..." + } + + except Exception as json_error: # Includes JSONDecodeError + # Handle cases where response is not valid JSON even if header suggests it + body = await resp.text() + body = { + "error": f"Failed to decode JSON response (status {resp.status}): {json_error}. Body: {body[:500]}..." + } + + # Check status code *after* trying to read body + if not (200 <= resp.status < 300): + err_msg = f"Error calling tool '{self.__name__}' (status {resp.status})" + if isinstance(body, dict) and "error" in body: + # Use error from JSON payload if available + err_msg += f": {body['error']}" + # No need to add body again if it was already included in body['error'] above + raise Exception(err_msg) # Or a more specific HTTPError subclass + + # 8. Return result (assuming successful 2xx response) + if isinstance(body, dict): + # Prefer 'result' field if present, otherwise return stringified dict + return str(body.get("result", body)) + else: + # Should not happen if status check passed and JSON was decoded, but as fallback + return str(body) def add_auth_token_getters( self, auth_token_getters: Mapping[str, Callable[[], str]], ) -> "ToolboxTool": """ - Registers an auth token getter function that is used for AuthServices when tools - are invoked. - - Args: - auth_token_getters: A mapping of authentication service names to - callables that return the corresponding authentication token. - - Returns: - A new ToolboxTool instance with the specified authentication token - getters registered. + Registers auth token getter functions for specified authentication services. + Creates and returns a *new* tool instance. """ - - # throw an error if the authentication source is already registered + # Check for duplicates against current getters existing_services = self.__auth_service_token_getters.keys() incoming_services = auth_token_getters.keys() duplicates = existing_services & incoming_services @@ -227,48 +400,45 @@ def add_auth_token_getters( f"Authentication source(s) `{', '.join(duplicates)}` already registered in tool `{self.__name__}`." ) - # create a read-only updated value for new_getters + # Create updated map of getters new_getters = types.MappingProxyType( - dict(self.__auth_service_token_getters, **auth_token_getters) - ) - # create a read-only updated for params that are still required - new_req_authn_params = types.MappingProxyType( - identify_required_authn_params( - self.__required_authn_params, auth_token_getters.keys() - ) + {**self.__auth_service_token_getters, **auth_token_getters} ) + # Return a new instance using __copy, passing the new getters return self.__copy( auth_service_token_getters=new_getters, - required_authn_params=new_req_authn_params, + # Other state (params, bound_params, originals, strict) remains the same ) def bind_parameters( self, bound_params: Mapping[str, Union[Callable[[], Any], Any]] ) -> "ToolboxTool": """ - Binds parameters to values or callables that produce values. + Binds parameters to specific values or callables. + Creates and returns a *new* tool instance. Validation uses the + instance's `strict` mode. + """ + if not bound_params: + return self # Return self if no parameters are being bound - Args: - bound_params: A mapping of parameter names to values or callables that - produce values. + # 1. Validate the new bindings against original schema & current state + self._validate_binding(bound_params, check_already_bound=True) - Returns: - A new ToolboxTool instance with the specified parameters bound. - """ - param_names = set(p.name for p in self.__params) - for name in bound_params.keys(): - if name not in param_names: - raise Exception(f"unable to bind parameters: no parameter named {name}") - - new_params = [] - for p in self.__params: - if p.name not in bound_params: - new_params.append(p) - all_bound_params = dict(self.__bound_parameters) - all_bound_params.update(bound_params) + # 2. Create the new state + # New combined dictionary of all bound parameters + all_bound_params = types.MappingProxyType( + {**self.__bound_parameters, **bound_params} + ) + + # New list of *current* (unbound) parameters for the copied instance + new_current_params = tuple( + p for p in self.__original_params if p.name not in all_bound_params + ) + # 3. Create the new instance via __copy return self.__copy( - params=new_params, - bound_params=types.MappingProxyType(all_bound_params), + params=new_current_params, # Pass the filtered list as the new 'current' params + bound_params=all_bound_params, + # Other state (auth_getters, originals, strict) remains the same ) diff --git a/packages/toolbox-core/src/toolbox_core/utils.py b/packages/toolbox-core/src/toolbox_core/utils.py index 4c4ec5a7..e88d62ce 100644 --- a/packages/toolbox-core/src/toolbox_core/utils.py +++ b/packages/toolbox-core/src/toolbox_core/utils.py @@ -18,13 +18,13 @@ Any, Awaitable, Callable, - Iterable, Mapping, Sequence, Type, Union, cast, ) +from types import MappingProxyType from pydantic import BaseModel, Field, create_model @@ -45,30 +45,14 @@ def create_func_docstring(description: str, params: Sequence[ParameterSchema]) - def identify_required_authn_params( - req_authn_params: Mapping[str, list[str]], auth_service_names: Iterable[str] -) -> dict[str, list[str]]: - """ - Identifies authentication parameters that are still required; because they - are not covered by the provided `auth_service_names`. - - Args: - req_authn_params: A mapping of parameter names to sets of required - authentication services. - auth_service_names: An iterable of authentication service names for which - token getters are available. - - Returns: - A new dictionary representing the subset of required authentication parameters - that are not covered by the provided `auth_services`. - """ - required_params = {} # params that are still required with provided auth_services - for param, services in req_authn_params.items(): - # if we don't have a token_getter for any of the services required by the param, - # the param is still required - required = not any(s in services for s in auth_service_names) - if required: - required_params[param] = services - return required_params + params: Sequence[ParameterSchema], +) -> Mapping[str, list[str]]: + """Helper to extract auth requirements from parameter schemas.""" + req_auth: dict[str, list[str]] = {} + for p in params: + if hasattr(p, "authSources") and p.authSources: + req_auth[p.name] = p.authSources + return MappingProxyType(req_auth) def params_to_pydantic_model( diff --git a/packages/toolbox-langchain/src/toolbox_langchain/async_client.py b/packages/toolbox-langchain/src/toolbox_langchain/async_client.py index b65c8ccf..704ec422 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/async_client.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/async_client.py @@ -16,6 +16,7 @@ from warnings import warn from aiohttp import ClientSession +from toolbox_core import ToolboxClient as ToolboxCoreClient from .tools import AsyncToolboxTool from .utils import ManifestSchema, _load_manifest @@ -38,8 +39,7 @@ def __init__( url: The base URL of the Toolbox service. session: An HTTP client session. """ - self.__url = url - self.__session = session + self.__core_client = ToolboxCoreClient(url, session) async def aload_tool( self, @@ -79,18 +79,10 @@ async def aload_tool( ) auth_tokens = auth_headers - url = f"{self.__url}/api/tool/{tool_name}" - manifest: ManifestSchema = await _load_manifest(url, self.__session) - - return AsyncToolboxTool( - tool_name, - manifest.tools[tool_name], - self.__url, - self.__session, - auth_tokens, - bound_params, - strict, + core_tool = await self.__core_client.load_tool( + tool_name, auth_tokens, bound_params ) + return AsyncToolboxTool(core_tool) async def aload_toolset( self, @@ -132,23 +124,10 @@ async def aload_toolset( ) auth_tokens = auth_headers - url = f"{self.__url}/api/toolset/{toolset_name or ''}" - manifest: ManifestSchema = await _load_manifest(url, self.__session) - tools: list[AsyncToolboxTool] = [] - - for tool_name, tool_schema in manifest.tools.items(): - tools.append( - AsyncToolboxTool( - tool_name, - tool_schema, - self.__url, - self.__session, - auth_tokens, - bound_params, - strict, - ) - ) - return tools + core_tools = await self.__core_client.load_toolset( + toolset_name, auth_tokens, bound_params + ) + return [AsyncToolboxTool(core_tool) for core_tool in core_tools] def load_tool( self, diff --git a/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py b/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py index c7aafc12..ab5d7f3a 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py @@ -12,22 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from copy import deepcopy -from typing import Any, Callable, TypeVar, Union -from warnings import warn +from typing import Any, Callable, Mapping, Union -from aiohttp import ClientSession from langchain_core.tools import BaseTool - -from .utils import ( - ToolSchema, - _find_auth_params, - _find_bound_params, - _invoke_tool, - _schema_to_model, -) - -T = TypeVar("T") +from toolbox_core.tool import ToolboxTool as ToolboxCoreTool # This class is an internal implementation detail and is not exposed to the @@ -37,113 +25,39 @@ class AsyncToolboxTool(BaseTool): """ A subclass of LangChain's BaseTool that supports features specific to Toolbox, like bound parameters and authenticated tools. + + It proxies core functionalities like invocation, adding authentication, and + binding parameters to the underlying toolbox_core.ToolboxTool, adapting it + for use within the LangChain ecosystem. """ def __init__( self, - name: str, - schema: ToolSchema, - url: str, - session: ClientSession, - auth_tokens: dict[str, Callable[[], str]] = {}, - bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, - strict: bool = True, + core_tool: ToolboxCoreTool, ) -> None: """ - Initializes an AsyncToolboxTool instance. + Initializes an AsyncToolboxTool instance wrapping the provided core tool. Args: - name: The name of the tool. - schema: The tool schema. - url: The base URL of the Toolbox service. - session: The HTTP client session. - auth_tokens: A mapping of authentication source names to functions - that retrieve ID tokens. - bound_params: A mapping of parameter names to their bound - values. - strict: If True, raises a ValueError if any of the given bound - parameters is missing from the schema or requires - authentication. If False, only issues a warning. + core_tool: An instance of toolbox_core.ToolboxTool. """ - - # If the schema is not already a ToolSchema instance, we create one from - # its attributes. This allows flexibility in how the schema is provided, - # accepting both a ToolSchema object and a dictionary of schema - # attributes. - if not isinstance(schema, ToolSchema): - schema = ToolSchema(**schema) - - auth_params, non_auth_params = _find_auth_params(schema.parameters) - non_auth_bound_params, non_auth_non_bound_params = _find_bound_params( - non_auth_params, list(bound_params) - ) - - # Check if the user is trying to bind a param that is authenticated or - # is missing from the given schema. - auth_bound_params: list[str] = [] - missing_bound_params: list[str] = [] - for bound_param in bound_params: - if bound_param in [param.name for param in auth_params]: - auth_bound_params.append(bound_param) - elif bound_param not in [param.name for param in non_auth_params]: - missing_bound_params.append(bound_param) - - # Create error messages for any params that are found to be - # authenticated or missing. - messages: list[str] = [] - if auth_bound_params: - messages.append( - f"Parameter(s) {', '.join(auth_bound_params)} already authenticated and cannot be bound." + if not isinstance(core_tool, ToolboxCoreTool): + raise TypeError( + f"Expected core_tool to be an instance of ToolboxCoreTool, got {type(core_tool)}" ) - if missing_bound_params: - messages.append( - f"Parameter(s) {', '.join(missing_bound_params)} missing and cannot be bound." - ) - - # Join any error messages and raise them as an error or warning, - # depending on the value of the strict flag. - if messages: - message = "\n\n".join(messages) - if strict: - raise ValueError(message) - warn(message) - # Bind values for parameters present in the schema that don't require - # authentication. - bound_params = { - param_name: param_value - for param_name, param_value in bound_params.items() - if param_name in [param.name for param in non_auth_bound_params] - } + self.__core_tool = core_tool - # Update the tools schema to validate only the presence of parameters - # that neither require authentication nor are bound. - schema.parameters = non_auth_non_bound_params - - # Due to how pydantic works, we must initialize the underlying - # BaseTool class before assigning values to member variables. super().__init__( - name=name, - description=schema.description, - args_schema=_schema_to_model(model_name=name, schema=schema.parameters), + name=self.__core_tool.__name__, + description=self.__core_tool._ToolboxTool__description, + args_schema=self.__core_tool._ToolboxTool__pydantic_model, ) - self.__name = name - self.__schema = schema - self.__url = url - self.__session = session - self.__auth_tokens = auth_tokens - self.__auth_params = auth_params - self.__bound_params = bound_params - - # Warn users about any missing authentication so they can add it before - # tool invocation. - self.__validate_auth(strict=False) - - def _run(self, **kwargs: Any) -> dict[str, Any]: + def _run(self, **kwargs: Any) -> str: raise NotImplementedError("Synchronous methods not supported by async tools.") - async def _arun(self, **kwargs: Any) -> dict[str, Any]: + async def _arun(self, **kwargs: Any) -> str: """ The coroutine that invokes the tool with the given arguments. @@ -151,153 +65,25 @@ async def _arun(self, **kwargs: Any) -> dict[str, Any]: **kwargs: The arguments to the tool. Returns: - A dictionary containing the parsed JSON response from the tool - invocation. - """ - - # If the tool had parameters that require authentication, then right - # before invoking that tool, we check whether all these required - # authentication sources have been registered or not. - self.__validate_auth() - - # Evaluate dynamic parameter values if any - evaluated_params = {} - for param_name, param_value in self.__bound_params.items(): - if callable(param_value): - evaluated_params[param_name] = param_value() - else: - evaluated_params[param_name] = param_value - - # Merge bound parameters with the provided arguments - kwargs.update(evaluated_params) - - return await _invoke_tool( - self.__url, self.__session, self.__name, kwargs, self.__auth_tokens - ) - - def __validate_auth(self, strict: bool = True) -> None: - """ - Checks if a tool meets the authentication requirements. - - A tool is considered authenticated if all of its parameters meet at - least one of the following conditions: - - * The parameter has at least one registered authentication source. - * The parameter requires no authentication. - - Args: - strict: If True, raises a PermissionError if any required - authentication sources are not registered. If False, only issues - a warning. + The string result from the core tool invocation. Raises: - PermissionError: If strict is True and any required authentication - sources are not registered. - """ - is_authenticated: bool = not self.__schema.authRequired - params_missing_auth: list[str] = [] - - # Check tool for at least 1 required auth source - for src in self.__schema.authRequired: - if src in self.__auth_tokens: - is_authenticated = True - break - - # Check each parameter for at least 1 required auth source - for param in self.__auth_params: - if not param.authSources: - raise ValueError("Auth sources cannot be None.") - has_auth = False - for src in param.authSources: - - # Find first auth source that is specified - if src in self.__auth_tokens: - has_auth = True - break - if not has_auth: - params_missing_auth.append(param.name) - - messages: list[str] = [] - - if not is_authenticated: - messages.append( - f"Tool {self.__name} requires authentication, but no valid authentication sources are registered. Please register the required sources before use." - ) - - if params_missing_auth: - messages.append( - f"Parameter(s) `{', '.join(params_missing_auth)}` of tool {self.__name} require authentication, but no valid authentication sources are registered. Please register the required sources before use." - ) - - if messages: - message = "\n\n".join(messages) - if strict: - raise PermissionError(message) - warn(message) - - def __create_copy( - self, - *, - auth_tokens: dict[str, Callable[[], str]] = {}, - bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, - strict: bool, - ) -> "AsyncToolboxTool": - """ - Creates a copy of the current AsyncToolboxTool instance, allowing for - modification of auth tokens and bound params. - - This method enables the creation of new tool instances with inherited - properties from the current instance, while optionally updating the auth - tokens and bound params. This is useful for creating variations of the - tool with additional auth tokens or bound params without modifying the - original instance, ensuring immutability. - - Args: - auth_tokens: A dictionary of auth source names to functions that - retrieve ID tokens. These tokens will be merged with the - existing auth tokens. - bound_params: A dictionary of parameter names to their - bound values or functions to retrieve the values. These params - will be merged with the existing bound params. - strict: If True, raises a ValueError if any of the given bound - parameters is missing from the schema or requires - authentication. If False, only issues a warning. - - Returns: - A new AsyncToolboxTool instance that is a deep copy of the current - instance, with added auth tokens or bound params. + PermissionError: If required authentication is missing. + ValidationError: If provided arguments are invalid. + Exception: For API errors or other issues during invocation. """ - new_schema = deepcopy(self.__schema) - - # Reconstruct the complete parameter schema by merging the auth - # parameters back with the non-auth parameters. This is necessary to - # accurately validate the new combination of auth tokens and bound - # params in the constructor of the new AsyncToolboxTool instance, ensuring - # that any overlaps or conflicts are correctly identified and reported - # as errors or warnings, depending on the given `strict` flag. - new_schema.parameters += self.__auth_params - return AsyncToolboxTool( - name=self.__name, - schema=new_schema, - url=self.__url, - session=self.__session, - auth_tokens={**self.__auth_tokens, **auth_tokens}, - bound_params={**self.__bound_params, **bound_params}, - strict=strict, - ) + return await self.__core_tool(**kwargs) def add_auth_tokens( - self, auth_tokens: dict[str, Callable[[], str]], strict: bool = True + self, auth_tokens: Mapping[str, Callable[[], str]] ) -> "AsyncToolboxTool": """ Registers functions to retrieve ID tokens for the corresponding authentication sources. Args: - auth_tokens: A dictionary of authentication source names to the - functions that return corresponding ID token. - strict: If True, a ValueError is raised if any of the provided auth - parameters is already bound. If False, only a warning is issued. + auth_tokens: A mapping of authentication source names to functions + that return the corresponding ID token. Returns: A new AsyncToolboxTool instance that is a deep copy of the current @@ -310,22 +96,11 @@ def add_auth_tokens( and strict is True. """ - - # Check if the authentication source is already registered. - dupe_tokens: list[str] = [] - for auth_token, _ in auth_tokens.items(): - if auth_token in self.__auth_tokens: - dupe_tokens.append(auth_token) - - if dupe_tokens: - raise ValueError( - f"Authentication source(s) `{', '.join(dupe_tokens)}` already registered in tool `{self.__name}`." - ) - - return self.__create_copy(auth_tokens=auth_tokens, strict=strict) + new_core_tool = self.__core_tool.add_auth_token_getters(auth_tokens) + return self.__class__(new_core_tool) def add_auth_token( - self, auth_source: str, get_id_token: Callable[[], str], strict: bool = True + self, auth_source: str, get_id_token: Callable[[], str] ) -> "AsyncToolboxTool": """ Registers a function to retrieve an ID token for a given authentication @@ -334,11 +109,9 @@ def add_auth_token( Args: auth_source: The name of the authentication source. get_id_token: A function that returns the ID token. - strict: If True, a ValueError is raised if the provided auth - parameter is already bound. If False, only a warning is issued. Returns: - A new ToolboxTool instance that is a deep copy of the current + A new AsyncToolboxTool instance that is a deep copy of the current instance, with added auth token. Raises: @@ -346,73 +119,53 @@ def add_auth_token( ValueError: If the provided auth parameter is already bound and strict is True. """ - return self.add_auth_tokens({auth_source: get_id_token}, strict=strict) + return self.add_auth_tokens({auth_source: get_id_token}) def bind_params( - self, - bound_params: dict[str, Union[Any, Callable[[], Any]]], - strict: bool = True, + self, bound_params: Mapping[str, Union[Any, Callable[[], Any]]] ) -> "AsyncToolboxTool": """ Registers values or functions to retrieve the value for the corresponding bound parameters. Args: - bound_params: A dictionary of the bound parameter name to the - value or function of the bound value. - strict: If True, a ValueError is raised if any of the provided bound - params is not defined in the tool's schema, or requires - authentication. If False, only a warning is issued. + bound_params: A mapping of parameter names to their bound + values or functions to retrieve the values dynamically. Returns: A new AsyncToolboxTool instance that is a deep copy of the current instance, with added bound params. Raises: - ValueError: If any of the provided bound params is already bound. - ValueError: if any of the provided bound params is not defined in - the tool's schema, or requires authentication, and strict is - True. + ValueError: If any provided parameter name is already bound. + ValueError: If `strict` is True and any parameter being bound requires + authentication or doesn't exist in the original schema. + Exception: If a parameter name doesn't exist. """ - - # Check if the parameter is already bound. - dupe_params: list[str] = [] - for param_name, _ in bound_params.items(): - if param_name in self.__bound_params: - dupe_params.append(param_name) - - if dupe_params: - raise ValueError( - f"Parameter(s) `{', '.join(dupe_params)}` already bound in tool `{self.__name}`." - ) - - return self.__create_copy(bound_params=bound_params, strict=strict) + new_core_tool = self.__core_tool.bind_parameters(bound_params) + return self.__class__(new_core_tool) def bind_param( self, param_name: str, param_value: Union[Any, Callable[[], Any]], - strict: bool = True, ) -> "AsyncToolboxTool": """ Registers a value or a function to retrieve the value for a given bound parameter. Args: - param_name: The name of the bound parameter. - param_value: The value of the bound parameter, or a callable that - returns the value. - strict: If True, a ValueError is raised if the provided bound param - is not defined in the tool's schema, or requires authentication. - If False, only a warning is issued. + param_name: The name of the parameter to bind. + param_value: The value or function for the bound parameter. Returns: - A new ToolboxTool instance that is a deep copy of the current + A new AsyncToolboxTool instance that is a deep copy of the current instance, with added bound param. Raises: ValueError: If the provided bound param is already bound. ValueError: if the provided bound param is not defined in the tool's schema, or requires authentication, and strict is True. + Exception: If the parameter name doesn't exist. """ - return self.bind_params({param_name: param_value}, strict) + return self.bind_params({param_name: param_value})