diff --git a/src/toolbox_langchain_sdk/async_client.py b/src/toolbox_langchain_sdk/async_client.py new file mode 100644 index 00000000..b65c8ccf --- /dev/null +++ b/src/toolbox_langchain_sdk/async_client.py @@ -0,0 +1,171 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Callable, Optional, Union +from warnings import warn + +from aiohttp import ClientSession + +from .tools import AsyncToolboxTool +from .utils import ManifestSchema, _load_manifest + + +# This class is an internal implementation detail and is not exposed to the +# end-user. It should not be used directly by external code. Changes to this +# class will not be considered breaking changes to the public API. +class AsyncToolboxClient: + + def __init__( + self, + url: str, + session: ClientSession, + ): + """ + Initializes the AsyncToolboxClient for the Toolbox service at the given URL. + + Args: + url: The base URL of the Toolbox service. + session: An HTTP client session. + """ + self.__url = url + self.__session = session + + async def aload_tool( + self, + tool_name: str, + auth_tokens: dict[str, Callable[[], str]] = {}, + auth_headers: Optional[dict[str, Callable[[], str]]] = None, + bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, + strict: bool = True, + ) -> AsyncToolboxTool: + """ + Loads the tool with the given tool name from the Toolbox service. + + Args: + tool_name: The name of the tool to load. + auth_tokens: An optional mapping of authentication source names to + functions that retrieve ID tokens. + auth_headers: Deprecated. Use `auth_tokens` instead. + bound_params: An optional mapping of parameter names to their + bound values. + strict: If True, raises a ValueError if any of the given bound + parameters are missing from the schema or require + authentication. If False, only issues a warning. + + Returns: + A tool loaded from the Toolbox. + """ + if auth_headers: + if auth_tokens: + warn( + "Both `auth_tokens` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_tokens` will be used.", + DeprecationWarning, + ) + else: + warn( + "Argument `auth_headers` is deprecated. Use `auth_tokens` instead.", + DeprecationWarning, + ) + auth_tokens = auth_headers + + url = f"{self.__url}/api/tool/{tool_name}" + manifest: ManifestSchema = await _load_manifest(url, self.__session) + + return AsyncToolboxTool( + tool_name, + manifest.tools[tool_name], + self.__url, + self.__session, + auth_tokens, + bound_params, + strict, + ) + + async def aload_toolset( + self, + toolset_name: Optional[str] = None, + auth_tokens: dict[str, Callable[[], str]] = {}, + auth_headers: Optional[dict[str, Callable[[], str]]] = None, + bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, + strict: bool = True, + ) -> list[AsyncToolboxTool]: + """ + Loads tools from the Toolbox service, optionally filtered by toolset + name. + + Args: + toolset_name: The name of the toolset to load. If not provided, + all tools are loaded. + auth_tokens: An optional mapping of authentication source names to + functions that retrieve ID tokens. + auth_headers: Deprecated. Use `auth_tokens` instead. + bound_params: An optional mapping of parameter names to their + bound values. + strict: If True, raises a ValueError if any of the given bound + parameters are missing from the schema or require + authentication. If False, only issues a warning. + + Returns: + A list of all tools loaded from the Toolbox. + """ + if auth_headers: + if auth_tokens: + warn( + "Both `auth_tokens` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_tokens` will be used.", + DeprecationWarning, + ) + else: + warn( + "Argument `auth_headers` is deprecated. Use `auth_tokens` instead.", + DeprecationWarning, + ) + auth_tokens = auth_headers + + url = f"{self.__url}/api/toolset/{toolset_name or ''}" + manifest: ManifestSchema = await _load_manifest(url, self.__session) + tools: list[AsyncToolboxTool] = [] + + for tool_name, tool_schema in manifest.tools.items(): + tools.append( + AsyncToolboxTool( + tool_name, + tool_schema, + self.__url, + self.__session, + auth_tokens, + bound_params, + strict, + ) + ) + return tools + + def load_tool( + self, + tool_name: str, + auth_tokens: dict[str, Callable[[], str]] = {}, + auth_headers: Optional[dict[str, Callable[[], str]]] = None, + bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, + strict: bool = True, + ) -> AsyncToolboxTool: + raise NotImplementedError("Synchronous methods not supported by async client.") + + def load_toolset( + self, + toolset_name: Optional[str] = None, + auth_tokens: dict[str, Callable[[], str]] = {}, + auth_headers: Optional[dict[str, Callable[[], str]]] = None, + bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, + strict: bool = True, + ) -> list[AsyncToolboxTool]: + raise NotImplementedError("Synchronous methods not supported by async client.") diff --git a/src/toolbox_langchain_sdk/async_tools.py b/src/toolbox_langchain_sdk/async_tools.py new file mode 100644 index 00000000..f1e1364a --- /dev/null +++ b/src/toolbox_langchain_sdk/async_tools.py @@ -0,0 +1,397 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from copy import deepcopy +from typing import Any, Callable, TypeVar, Union +from warnings import warn + +from aiohttp import ClientSession +from langchain_core.tools import BaseTool + +from .utils import ( + ToolSchema, + _find_auth_params, + _find_bound_params, + _invoke_tool, + _schema_to_model, +) + +T = TypeVar("T") + + +# This class is an internal implementation detail and is not exposed to the +# end-user. It should not be used directly by external code. Changes to this +# class will not be considered breaking changes to the public API. +class AsyncToolboxTool(BaseTool): + """ + A subclass of LangChain's BaseTool that supports features specific to + Toolbox, like bound parameters and authenticated tools. + """ + + def __init__( + self, + name: str, + schema: ToolSchema, + url: str, + session: ClientSession, + auth_tokens: dict[str, Callable[[], str]] = {}, + bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, + strict: bool = True, + ) -> None: + """ + Initializes an AsyncToolboxTool instance. + + Args: + name: The name of the tool. + schema: The tool schema. + url: The base URL of the Toolbox service. + session: The HTTP client session. + auth_tokens: A mapping of authentication source names to functions + that retrieve ID tokens. + bound_params: A mapping of parameter names to their bound + values. + strict: If True, raises a ValueError if any of the given bound + parameters are missing from the schema or require + authentication. If False, only issues a warning. + """ + + # If the schema is not already a ToolSchema instance, we create one from + # its attributes. This allows flexibility in how the schema is provided, + # accepting both a ToolSchema object and a dictionary of schema + # attributes. + if not isinstance(schema, ToolSchema): + schema = ToolSchema(**schema) + + auth_params, non_auth_params = _find_auth_params(schema.parameters) + non_auth_bound_params, non_auth_non_bound_params = _find_bound_params( + non_auth_params, list(bound_params) + ) + + # Check if the user is trying to bind a param that is authenticated or + # is missing from the given schema. + auth_bound_params: list[str] = [] + missing_bound_params: list[str] = [] + for bound_param in bound_params: + if bound_param in [param.name for param in auth_params]: + auth_bound_params.append(bound_param) + elif bound_param not in [param.name for param in non_auth_params]: + missing_bound_params.append(bound_param) + + # Create error messages for any params that are found to be + # authenticated or missing. + messages: list[str] = [] + if auth_bound_params: + messages.append( + f"Parameter(s) {', '.join(auth_bound_params)} already authenticated and cannot be bound." + ) + if missing_bound_params: + messages.append( + f"Parameter(s) {', '.join(missing_bound_params)} missing and cannot be bound." + ) + + # Join any error messages and raise them as an error or warning, + # depending on the value of the strict flag. + if messages: + message = "\n\n".join(messages) + if strict: + raise ValueError(message) + warn(message) + + # Bind values for parameters present in the schema that don't require + # authentication. + bound_params = { + param_name: param_value + for param_name, param_value in bound_params.items() + if param_name in [param.name for param in non_auth_bound_params] + } + + # Update the tools schema to validate only the presence of parameters + # that neither require authentication nor are bound. + schema.parameters = non_auth_non_bound_params + + # Due to how pydantic works, we must initialize the underlying + # StructuredTool class before assigning values to member variables. + super().__init__( + name=name, + description=schema.description, + args_schema=_schema_to_model(model_name=name, schema=schema.parameters), + ) + + self.__name = name + self.__schema = schema + self.__url = url + self.__session = session + self.__auth_tokens = auth_tokens + self.__auth_params = auth_params + self.__bound_params = bound_params + + # Warn users about any missing authentication so they can add it before + # tool invocation. + self.__validate_auth(strict=False) + + def _run(self, **kwargs: Any) -> dict[str, Any]: + raise NotImplementedError("Synchronous methods not supported by async tools.") + + async def _arun(self, **kwargs: Any) -> dict[str, Any]: + """ + The coroutine that invokes the tool with the given arguments. + + Args: + **kwargs: The arguments to the tool. + + Returns: + A dictionary containing the parsed JSON response from the tool + invocation. + """ + + # If the tool had parameters that require authentication, then right + # before invoking that tool, we check whether all these required + # authentication sources have been registered or not. + self.__validate_auth() + + # Evaluate dynamic parameter values if any + evaluated_params = {} + for param_name, param_value in self.__bound_params.items(): + if callable(param_value): + evaluated_params[param_name] = param_value() + else: + evaluated_params[param_name] = param_value + + # Merge bound parameters with the provided arguments + kwargs.update(evaluated_params) + + return await _invoke_tool( + self.__url, self.__session, self.__name, kwargs, self.__auth_tokens + ) + + def __validate_auth(self, strict: bool = True) -> None: + """ + Checks if a tool meets the authentication requirements. + + A tool is considered authenticated if all of its parameters meet at + least one of the following conditions: + + * The parameter has at least one registered authentication source. + * The parameter requires no authentication. + + Args: + strict: If True, raises a PermissionError if any required + authentication sources are not registered. If False, only issues + a warning. + + Raises: + PermissionError: If strict is True and any required authentication + sources are not registered. + """ + params_missing_auth: list[str] = [] + + # Check each parameter for at least 1 required auth source + for param in self.__auth_params: + if not param.authSources: + raise ValueError("Auth sources cannot be None.") + has_auth = False + for src in param.authSources: + + # Find first auth source that is specified + if src in self.__auth_tokens: + has_auth = True + break + if not has_auth: + params_missing_auth.append(param.name) + + if params_missing_auth: + message = f"Parameter(s) `{', '.join(params_missing_auth)}` of tool {self.__name} require authentication, but no valid authentication sources are registered. Please register the required sources before use." + + if strict: + raise PermissionError(message) + warn(message) + + def __create_copy( + self, + *, + auth_tokens: dict[str, Callable[[], str]] = {}, + bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, + strict: bool, + ) -> "AsyncToolboxTool": + """ + Creates a copy of the current AsyncToolboxTool instance, allowing for + modification of auth tokens and bound params. + + This method enables the creation of new tool instances with inherited + properties from the current instance, while optionally updating the auth + tokens and bound params. This is useful for creating variations of the + tool with additional auth tokens or bound params without modifying the + original instance, ensuring immutability. + + Args: + auth_tokens: A dictionary of auth source names to functions that + retrieve ID tokens. These tokens will be merged with the + existing auth tokens. + bound_params: A dictionary of parameter names to their + bound values or functions to retrieve the values. These params + will be merged with the existing bound params. + strict: If True, raises a ValueError if any of the given bound + parameters are missing from the schema or require + authentication. If False, only issues a warning. + + Returns: + A new AsyncToolboxTool instance that is a deep copy of the current + instance, with added auth tokens or bound params. + """ + new_schema = deepcopy(self.__schema) + + # Reconstruct the complete parameter schema by merging the auth + # parameters back with the non-auth parameters. This is necessary to + # accurately validate the new combination of auth tokens and bound + # params in the constructor of the new AsyncToolboxTool instance, ensuring + # that any overlaps or conflicts are correctly identified and reported + # as errors or warnings, depending on the given `strict` flag. + new_schema.parameters += self.__auth_params + return AsyncToolboxTool( + name=self.__name, + schema=new_schema, + url=self.__url, + session=self.__session, + auth_tokens={**self.__auth_tokens, **auth_tokens}, + bound_params={**self.__bound_params, **bound_params}, + strict=strict, + ) + + def add_auth_tokens( + self, auth_tokens: dict[str, Callable[[], str]], strict: bool = True + ) -> "AsyncToolboxTool": + """ + Registers functions to retrieve ID tokens for the corresponding + authentication sources. + + Args: + auth_tokens: A dictionary of authentication source names to the + functions that return corresponding ID token. + strict: If True, a ValueError is raised if any of the provided auth + tokens are already bound. If False, only a warning is issued. + + Returns: + A new AsyncToolboxTool instance that is a deep copy of the current + instance, with added auth tokens. + + Raises: + ValueError: If the provided auth tokens are already registered. + ValueError: If the provided auth tokens are already bound and strict + is True. + """ + + # Check if the authentication source is already registered. + dupe_tokens: list[str] = [] + for auth_token, _ in auth_tokens.items(): + if auth_token in self.__auth_tokens: + dupe_tokens.append(auth_token) + + if dupe_tokens: + raise ValueError( + f"Authentication source(s) `{', '.join(dupe_tokens)}` already registered in tool `{self.__name}`." + ) + + return self.__create_copy(auth_tokens=auth_tokens, strict=strict) + + def add_auth_token( + self, auth_source: str, get_id_token: Callable[[], str], strict: bool = True + ) -> "AsyncToolboxTool": + """ + Registers a function to retrieve an ID token for a given authentication + source. + + Args: + auth_source: The name of the authentication source. + get_id_token: A function that returns the ID token. + strict: If True, a ValueError is raised if any of the provided auth + token is already bound. If False, only a warning is issued. + + Returns: + A new ToolboxTool instance that is a deep copy of the current + instance, with added auth token. + + Raises: + ValueError: If the provided auth token is already registered. + ValueError: If the provided auth token is already bound and strict + is True. + """ + return self.add_auth_tokens({auth_source: get_id_token}, strict=strict) + + def bind_params( + self, + bound_params: dict[str, Union[Any, Callable[[], Any]]], + strict: bool = True, + ) -> "AsyncToolboxTool": + """ + Registers values or functions to retrieve the value for the + corresponding bound parameters. + + Args: + bound_params: A dictionary of the bound parameter name to the + value or function of the bound value. + strict: If True, a ValueError is raised if any of the provided bound + params are not defined in the tool's schema, or require + authentication. If False, only a warning is issued. + + Returns: + A new AsyncToolboxTool instance that is a deep copy of the current + instance, with added bound params. + + Raises: + ValueError: If the provided bound params are already bound. + ValueError: if the provided bound params are not defined in the tool's schema, or require + authentication, and strict is True. + """ + + # Check if the parameter is already bound. + dupe_params: list[str] = [] + for param_name, _ in bound_params.items(): + if param_name in self.__bound_params: + dupe_params.append(param_name) + + if dupe_params: + raise ValueError( + f"Parameter(s) `{', '.join(dupe_params)}` already bound in tool `{self.__name}`." + ) + + return self.__create_copy(bound_params=bound_params, strict=strict) + + def bind_param( + self, + param_name: str, + param_value: Union[Any, Callable[[], Any]], + strict: bool = True, + ) -> "AsyncToolboxTool": + """ + Registers a value or a function to retrieve the value for a given bound + parameter. + + Args: + param_name: The name of the bound parameter. param_value: The value + of the bound parameter, or a callable that + returns the value. + strict: If True, a ValueError is raised if any of the provided bound + params is not defined in the tool's schema, or requires + authentication. If False, only a warning is issued. + + Returns: + A new ToolboxTool instance that is a deep copy of the current + instance, with added bound param. + + Raises: + ValueError: If the provided bound param is already bound. + ValueError: if the provided bound param is not defined in the tool's + schema, or requires authentication, and strict is True. + """ + return self.bind_params({param_name: param_value}, strict) diff --git a/src/toolbox_langchain_sdk/client.py b/src/toolbox_langchain_sdk/client.py index 3ae43911..f30d5766 100644 --- a/src/toolbox_langchain_sdk/client.py +++ b/src/toolbox_langchain_sdk/client.py @@ -13,88 +13,155 @@ # limitations under the License. import asyncio -from typing import Any, Callable, Optional, Union -from warnings import warn +from threading import Thread +from typing import Any, Awaitable, Callable, Optional, TypeVar, Union from aiohttp import ClientSession +from .async_client import AsyncToolboxClient from .tools import ToolboxTool -from .utils import ManifestSchema, _load_manifest + +T = TypeVar("T") class ToolboxClient: - def __init__(self, url: str, session: Optional[ClientSession] = None): + __session: Optional[ClientSession] = None + __loop: Optional[asyncio.AbstractEventLoop] = None + __thread: Optional[Thread] = None + + def __init__( + self, + url: str, + ) -> None: """ Initializes the ToolboxClient for the Toolbox service at the given URL. Args: url: The base URL of the Toolbox service. - session: An optional HTTP client session. If not provided, a new - session will be created. """ - self._url: str = url - self._should_close_session: bool = session is None - self._session: ClientSession = session or ClientSession() - async def close(self) -> None: - """ - Closes the HTTP client session if it was created by this client. - """ - # We check whether _should_close_session is set or not since we do not - # want to close the session in case the user had passed their own - # ClientSession object, since then we expect the user to be owning its - # lifecycle. - if self._session and self._should_close_session: - await self._session.close() - - def __del__(self): - """ - Ensures the HTTP client session is closed when the client is garbage - collected. - """ - try: - loop = asyncio.get_event_loop() - if loop.is_running(): - loop.create_task(self.close()) - else: - loop.run_until_complete(self.close()) - except Exception: - # We "pass" assuming that the exception is thrown because the event - # loop is no longer running, but at that point the Session should - # have been closed already anyway. - pass - - async def _load_tool_manifest(self, tool_name: str) -> ManifestSchema: + # Running a loop in a background thread allows us to support async + # methods from non-async environments. + if ToolboxClient.__loop is None: + loop = asyncio.new_event_loop() + thread = Thread(target=loop.run_forever, daemon=True) + thread.start() + ToolboxClient.__thread = thread + ToolboxClient.__loop = loop + + async def __start_session() -> None: + + # Use a default session if none is provided. This leverages connection + # pooling for better performance by reusing a single session throughout + # the application's lifetime. + if ToolboxClient.__session is None: + ToolboxClient.__session = ClientSession() + + coro = __start_session() + + asyncio.run_coroutine_threadsafe(coro, ToolboxClient.__loop).result() + + if not ToolboxClient.__session: + raise ValueError("Session cannot be None.") + self.__async_client = AsyncToolboxClient(url, ToolboxClient.__session) + + def __run_as_sync(self, coro: Awaitable[T]) -> T: + """Run an async coroutine synchronously""" + if not self.__loop: + raise Exception( + "Cannot call synchronous methods before the background loop is initialized." + ) + return asyncio.run_coroutine_threadsafe(coro, self.__loop).result() + + async def __run_as_async(self, coro: Awaitable[T]) -> T: + """Run an async coroutine asynchronously""" + + # If a loop has not been provided, attempt to run in current thread. + if not self.__loop: + return await coro + + # Otherwise, run in the background thread. + return await asyncio.wrap_future( + asyncio.run_coroutine_threadsafe(coro, self.__loop) + ) + + async def aload_tool( + self, + tool_name: str, + auth_tokens: dict[str, Callable[[], str]] = {}, + auth_headers: Optional[dict[str, Callable[[], str]]] = None, + bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, + strict: bool = True, + ) -> ToolboxTool: """ - Fetches and parses the manifest schema for the given tool from the - Toolbox service. + Loads the tool with the given tool name from the Toolbox service. Args: tool_name: The name of the tool to load. + auth_tokens: An optional mapping of authentication source names to + functions that retrieve ID tokens. + auth_headers: Deprecated. Use `auth_tokens` instead. + bound_params: An optional mapping of parameter names to their + bound values. + strict: If True, raises a ValueError if any of the given bound + parameters are missing from the schema or require + authentication. If False, only issues a warning. Returns: - The parsed Toolbox manifest. + A tool loaded from the Toolbox. """ - url = f"{self._url}/api/tool/{tool_name}" - return await _load_manifest(url, self._session) + async_tool = await self.__run_as_async( + self.__async_client.aload_tool( + tool_name, auth_tokens, auth_headers, bound_params, strict + ) + ) - async def _load_toolset_manifest( - self, toolset_name: Optional[str] = None - ) -> ManifestSchema: + if not self.__loop or not self.__thread: + raise ValueError("Background loop or thread cannot be None.") + return ToolboxTool(async_tool, self.__loop, self.__thread) + + async def aload_toolset( + self, + toolset_name: Optional[str] = None, + auth_tokens: dict[str, Callable[[], str]] = {}, + auth_headers: Optional[dict[str, Callable[[], str]]] = None, + bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, + strict: bool = True, + ) -> list[ToolboxTool]: """ - Fetches and parses the manifest schema from the Toolbox service. + Loads tools from the Toolbox service, optionally filtered by toolset + name. Args: toolset_name: The name of the toolset to load. If not provided, - the manifest for all available tools is loaded. + all tools are loaded. + auth_tokens: An optional mapping of authentication source names to + functions that retrieve ID tokens. + auth_headers: Deprecated. Use `auth_tokens` instead. + bound_params: An optional mapping of parameter names to their + bound values. + strict: If True, raises a ValueError if any of the given bound + parameters are missing from the schema or require + authentication. If False, only issues a warning. Returns: - The parsed Toolbox manifest. + A list of all tools loaded from the Toolbox. """ - url = f"{self._url}/api/toolset/{toolset_name or ''}" - return await _load_manifest(url, self._session) + async_tools = await self.__run_as_async( + self.__async_client.aload_toolset( + toolset_name, auth_tokens, auth_headers, bound_params, strict + ) + ) - async def load_tool( + tools: list[ToolboxTool] = [] + + if not self.__loop or not self.__thread: + raise ValueError("Background loop or thread cannot be None.") + for async_tool in async_tools: + tools.append(ToolboxTool(async_tool, self.__loop, self.__thread)) + return tools + + def load_tool( self, tool_name: str, auth_tokens: dict[str, Callable[[], str]] = {}, @@ -119,31 +186,17 @@ async def load_tool( Returns: A tool loaded from the Toolbox. """ - if auth_headers: - if auth_tokens: - warn( - "Both `auth_tokens` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_tokens` will be used.", - DeprecationWarning, - ) - else: - warn( - "Argument `auth_headers` is deprecated. Use `auth_tokens` instead.", - DeprecationWarning, - ) - auth_tokens = auth_headers - - manifest: ManifestSchema = await self._load_tool_manifest(tool_name) - return ToolboxTool( - tool_name, - manifest.tools[tool_name], - self._url, - self._session, - auth_tokens, - bound_params, - strict, + async_tool = self.__run_as_sync( + self.__async_client.aload_tool( + tool_name, auth_tokens, auth_headers, bound_params, strict + ) ) - async def load_toolset( + if not self.__loop or not self.__thread: + raise ValueError("Background loop or thread cannot be None.") + return ToolboxTool(async_tool, self.__loop, self.__thread) + + def load_toolset( self, toolset_name: Optional[str] = None, auth_tokens: dict[str, Callable[[], str]] = {}, @@ -170,32 +223,15 @@ async def load_toolset( Returns: A list of all tools loaded from the Toolbox. """ - if auth_headers: - if auth_tokens: - warn( - "Both `auth_tokens` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_tokens` will be used.", - DeprecationWarning, - ) - else: - warn( - "Argument `auth_headers` is deprecated. Use `auth_tokens` instead.", - DeprecationWarning, - ) - auth_tokens = auth_headers + async_tools = self.__run_as_sync( + self.__async_client.aload_toolset( + toolset_name, auth_tokens, auth_headers, bound_params, strict + ) + ) + if not self.__loop or not self.__thread: + raise ValueError("Background loop or thread cannot be None.") tools: list[ToolboxTool] = [] - manifest: ManifestSchema = await self._load_toolset_manifest(toolset_name) - - for tool_name, tool_schema in manifest.tools.items(): - tools.append( - ToolboxTool( - tool_name, - tool_schema, - self._url, - self._session, - auth_tokens, - bound_params, - strict, - ) - ) + for async_tool in async_tools: + tools.append(ToolboxTool(async_tool, self.__loop, self.__thread)) return tools diff --git a/src/toolbox_langchain_sdk/tools.py b/src/toolbox_langchain_sdk/tools.py index 0c2a7396..c62ab1a9 100644 --- a/src/toolbox_langchain_sdk/tools.py +++ b/src/toolbox_langchain_sdk/tools.py @@ -12,260 +12,80 @@ # See the License for the specific language governing permissions and # limitations under the License. -from copy import deepcopy -from typing import Any, Callable, Union -from warnings import warn +import asyncio +from asyncio import AbstractEventLoop +from threading import Thread +from typing import Any, Awaitable, Callable, TypeVar, Union -from aiohttp import ClientSession from langchain_core.tools import BaseTool -from typing_extensions import Self -from .utils import ( - ParameterSchema, - ToolSchema, - _find_auth_params, - _find_bound_params, - _invoke_tool, - _schema_to_model, -) +from .async_tools import AsyncToolboxTool + +T = TypeVar("T") class ToolboxTool(BaseTool): """ - A subclass of LangChain's StructuredTool that supports features specific to + A subclass of LangChain's BaseTool that supports features specific to Toolbox, like bound parameters and authenticated tools. """ def __init__( self, - name: str, - schema: ToolSchema, - url: str, - session: ClientSession, - auth_tokens: dict[str, Callable[[], str]] = {}, - bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, - strict: bool = True, + async_tool: AsyncToolboxTool, + loop: AbstractEventLoop, + thread: Thread, ) -> None: """ Initializes a ToolboxTool instance. Args: - name: The name of the tool. - schema: The tool schema. - url: The base URL of the Toolbox service. - session: The HTTP client session. - auth_tokens: A mapping of authentication source names to functions - that retrieve ID tokens. - bound_params: A mapping of parameter names to their bound - values. - strict: If True, raises a ValueError if any of the given bound - parameters are missing from the schema or require - authentication. If False, only issues a warning. + async_tool: The underlying AsyncToolboxTool instance. + loop: The event loop used to run asynchronous tasks. + thread: The thread to run blocking operations in. """ - # If the schema is not already a ToolSchema instance, we create one from - # its attributes. This allows flexibility in how the schema is provided, - # accepting both a ToolSchema object and a dictionary of schema - # attributes. - if not isinstance(schema, ToolSchema): - schema = ToolSchema(**schema) - - auth_params, non_auth_params = _find_auth_params(schema.parameters) - non_auth_bound_params, non_auth_non_bound_params = _find_bound_params( - non_auth_params, list(bound_params) - ) - - # Check if the user is trying to bind a param that is authenticated or - # is missing from the given schema. - auth_bound_params: list[str] = [] - missing_bound_params: list[str] = [] - for bound_param in bound_params: - if bound_param in [param.name for param in auth_params]: - auth_bound_params.append(bound_param) - elif bound_param not in [param.name for param in non_auth_params]: - missing_bound_params.append(bound_param) - - # Create error messages for any params that are found to be - # authenticated or missing. - messages: list[str] = [] - if auth_bound_params: - messages.append( - f"Parameter(s) {', '.join(auth_bound_params)} already authenticated and cannot be bound." - ) - if missing_bound_params: - messages.append( - f"Parameter(s) {', '.join(missing_bound_params)} missing and cannot be bound." - ) - - # Join any error messages and raise them as an error or warning, - # depending on the value of the strict flag. - if messages: - message = "\n\n".join(messages) - if strict: - raise ValueError(message) - warn(message) - - # Bind values for parameters present in the schema that don't require - # authentication. - bound_params = { - param_name: param_value - for param_name, param_value in bound_params.items() - if param_name in [param.name for param in non_auth_bound_params] - } - - # Update the tools schema to validate only the presence of parameters - # that neither require authentication nor are bound. - schema.parameters = non_auth_non_bound_params - # Due to how pydantic works, we must initialize the underlying # StructuredTool class before assigning values to member variables. super().__init__( - name=name, - description=schema.description, - args_schema=_schema_to_model(model_name=name, schema=schema.parameters), + name=async_tool.name, + description=async_tool.description, + args_schema=async_tool.args_schema, ) - self._name: str = name - self._schema: ToolSchema = schema - self._url: str = url - self._session: ClientSession = session - self._auth_tokens: dict[str, Callable[[], str]] = auth_tokens - self._auth_params: list[ParameterSchema] = auth_params - self._bound_params: dict[str, Union[Any, Callable[[], Any]]] = bound_params - - # Warn users about any missing authentication so they can add it before - # tool invocation. - self.__validate_auth(strict=False) - - async def _arun(self, **kwargs: Any) -> dict[str, Any]: - """ - The coroutine that invokes the tool with the given arguments. - - Args: - **kwargs: The arguments to the tool. - - Returns: - A dictionary containing the parsed JSON response from the tool - invocation. - """ + self.__async_tool = async_tool + self.__loop = loop + self.__thread = thread - # If the tool had parameters that require authentication, then right - # before invoking that tool, we check whether all these required - # authentication sources have been registered or not. - self.__validate_auth() + def __run_as_sync(self, coro: Awaitable[T]) -> T: + """Run an async coroutine synchronously""" + if not self.__loop: + raise Exception( + "Cannot call synchronous methods before the background loop is initialized." + ) + return asyncio.run_coroutine_threadsafe(coro, self.__loop).result() - # Evaluate dynamic parameter values if any - evaluated_params = {} - for param_name, param_value in self._bound_params.items(): - if callable(param_value): - evaluated_params[param_name] = param_value() - else: - evaluated_params[param_name] = param_value + async def __run_as_async(self, coro: Awaitable[T]) -> T: + """Run an async coroutine asynchronously""" - # Merge bound parameters with the provided arguments - kwargs.update(evaluated_params) + # If a loop has not been provided, attempt to run in current thread. + if not self.__loop: + return await coro - return await _invoke_tool( - self._url, self._session, self._name, kwargs, self._auth_tokens + # Otherwise, run in the background thread. + return await asyncio.wrap_future( + asyncio.run_coroutine_threadsafe(coro, self.__loop) ) def _run(self, **kwargs: Any) -> dict[str, Any]: - raise NotImplementedError("Sync tool calls not supported yet.") - - def __validate_auth(self, strict: bool = True) -> None: - """ - Checks if a tool meets the authentication requirements. - - A tool is considered authenticated if all of its parameters meet at - least one of the following conditions: - - * The parameter has at least one registered authentication source. - * The parameter requires no authentication. - - Args: - strict: If True, raises a PermissionError if any required - authentication sources are not registered. If False, only issues - a warning. - - Raises: - PermissionError: If strict is True and any required authentication - sources are not registered. - """ - params_missing_auth: list[str] = [] - - # Check each parameter for at least 1 required auth source - for param in self._auth_params: - assert param.authSources is not None - has_auth = False - for src in param.authSources: - # Find first auth source that is specified - if src in self._auth_tokens: - has_auth = True - break - if not has_auth: - params_missing_auth.append(param.name) - - if params_missing_auth: - message = f"Parameter(s) `{', '.join(params_missing_auth)}` of tool {self._name} require authentication, but no valid authentication sources are registered. Please register the required sources before use." - - if strict: - raise PermissionError(message) - warn(message) - - def __create_copy( - self, - *, - auth_tokens: dict[str, Callable[[], str]] = {}, - bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, - strict: bool, - ) -> Self: - """ - Creates a deep copy of the current ToolboxTool instance, allowing for - modification of auth tokens and bound params. + return self.__run_as_sync(self.__async_tool._arun(**kwargs)) - This method enables the creation of new tool instances with inherited - properties from the current instance, while optionally updating the auth - tokens and bound params. This is useful for creating variations of the - tool with additional auth tokens or bound params without modifying the - original instance, ensuring immutability. - - Args: - auth_tokens: A dictionary of auth source names to functions that - retrieve ID tokens. These tokens will be merged with the - existing auth tokens. - bound_params: A dictionary of parameter names to their - bound values or functions to retrieve the values. These params - will be merged with the existing bound params. - strict: If True, raises a ValueError if any of the given bound - parameters are missing from the schema or require - authentication. If False, only issues a warning. - - Returns: - A new ToolboxTool instance that is a deep copy of the current - instance, with added auth tokens or bound params. - """ - new_schema = deepcopy(self._schema) - - # Reconstruct the complete parameter schema by merging the auth - # parameters back with the non-auth parameters. This is necessary to - # accurately validate the new combination of auth tokens and bound - # params in the constructor of the new ToolboxTool instance, ensuring - # that any overlaps or conflicts are correctly identified and reported - # as errors or warnings, depending on the given `strict` flag. - new_schema.parameters += self._auth_params - return type(self)( - name=self._name, - schema=new_schema, - url=self._url, - session=self._session, - auth_tokens={**self._auth_tokens, **auth_tokens}, - bound_params={**self._bound_params, **bound_params}, - strict=strict, - ) + async def _arun(self, **kwargs: Any) -> dict[str, Any]: + return await self.__run_as_async(self.__async_tool._arun(**kwargs)) def add_auth_tokens( self, auth_tokens: dict[str, Callable[[], str]], strict: bool = True - ) -> Self: + ) -> "ToolboxTool": """ Registers functions to retrieve ID tokens for the corresponding authentication sources. @@ -274,30 +94,26 @@ def add_auth_tokens( auth_tokens: A dictionary of authentication source names to the functions that return corresponding ID token. strict: If True, a ValueError is raised if any of the provided auth - tokens are already registered, or are already bound. If False, - only a warning is issued. + tokens are already bound. If False, only a warning is issued. Returns: A new ToolboxTool instance that is a deep copy of the current instance, with added auth tokens. - """ - - # Check if the authentication source is already registered. - dupe_tokens: list[str] = [] - for auth_token, _ in auth_tokens.items(): - if auth_token in self._auth_tokens: - dupe_tokens.append(auth_token) - - if dupe_tokens: - raise ValueError( - f"Authentication source(s) `{', '.join(dupe_tokens)}` already registered in tool `{self._name}`." - ) - return self.__create_copy(auth_tokens=auth_tokens, strict=strict) + Raises: + ValueError: If the provided auth tokens are already registered. + ValueError: If the provided auth tokens are already bound and strict + is True. + """ + return ToolboxTool( + self.__async_tool.add_auth_tokens(auth_tokens, strict), + self.__loop, + self.__thread, + ) def add_auth_token( self, auth_source: str, get_id_token: Callable[[], str], strict: bool = True - ) -> Self: + ) -> "ToolboxTool": """ Registers a function to retrieve an ID token for a given authentication source. @@ -306,20 +122,28 @@ def add_auth_token( auth_source: The name of the authentication source. get_id_token: A function that returns the ID token. strict: If True, a ValueError is raised if any of the provided auth - tokens are already registered, or are already bound. If False, - only a warning is issued. + token is already bound. If False, only a warning is issued. Returns: A new ToolboxTool instance that is a deep copy of the current - instance, with added auth tokens. - """ - return self.add_auth_tokens({auth_source: get_id_token}, strict=strict) + instance, with added auth token. + + Raises: + ValueError: If the provided auth token is already registered. + ValueError: If the provided auth token is already bound and strict + is True. + """ + return ToolboxTool( + self.__async_tool.add_auth_token(auth_source, get_id_token, strict), + self.__loop, + self.__thread, + ) def bind_params( self, bound_params: dict[str, Union[Any, Callable[[], Any]]], strict: bool = True, - ) -> Self: + ) -> "ToolboxTool": """ Registers values or functions to retrieve the value for the corresponding bound parameters. @@ -328,47 +152,53 @@ def bind_params( bound_params: A dictionary of the bound parameter name to the value or function of the bound value. strict: If True, a ValueError is raised if any of the provided bound - params are already bound, not defined in the tool's schema, or - require authentication. If False, only a warning is issued. + params are not defined in the tool's schema, or require + authentication. If False, only a warning is issued. Returns: A new ToolboxTool instance that is a deep copy of the current instance, with added bound params. - """ - - # Check if the parameter is already bound. - dupe_params: list[str] = [] - for param_name, _ in bound_params.items(): - if param_name in self._bound_params: - dupe_params.append(param_name) - - if dupe_params: - raise ValueError( - f"Parameter(s) `{', '.join(dupe_params)}` already bound in tool `{self._name}`." - ) - return self.__create_copy(bound_params=bound_params, strict=strict) + Raises: + ValueError: If the provided bound params are already bound. + ValueError: if the provided bound params are not defined in the tool's schema, or require + authentication, and strict is True. + """ + return ToolboxTool( + self.__async_tool.bind_params(bound_params, strict), + self.__loop, + self.__thread, + ) def bind_param( self, param_name: str, param_value: Union[Any, Callable[[], Any]], strict: bool = True, - ) -> Self: + ) -> "ToolboxTool": """ - Registers a value or a function to retrieve the value for a given - bound parameter. + Registers a value or a function to retrieve the value for a given bound + parameter. Args: - param_name: The name of the bound parameter. - param_value: The value of the bound parameter, or a callable - that returns the value. + param_name: The name of the bound parameter. param_value: The value + of the bound parameter, or a callable that + returns the value. strict: If True, a ValueError is raised if any of the provided bound - params are already bound, not defined in the tool's schema, or - require authentication. If False, only a warning is issued. + params is not defined in the tool's schema, or requires + authentication. If False, only a warning is issued. Returns: A new ToolboxTool instance that is a deep copy of the current - instance, with added bound params. - """ - return self.bind_params({param_name: param_value}, strict) + instance, with added bound param. + + Raises: + ValueError: If the provided bound param is already bound. + ValueError: if the provided bound param is not defined in the tool's + schema, or requires authentication, and strict is True. + """ + return ToolboxTool( + self.__async_tool.bind_param(param_name, param_value, strict), + self.__loop, + self.__thread, + ) diff --git a/src/toolbox_langchain_sdk/utils.py b/src/toolbox_langchain_sdk/utils.py index 036624d6..ab3eb349 100644 --- a/src/toolbox_langchain_sdk/utils.py +++ b/src/toolbox_langchain_sdk/utils.py @@ -67,8 +67,10 @@ async def _load_manifest(url: str, session: ClientSession) -> ManifestSchema: ValueError: If the response is not a valid manifest. """ async with session.get(url) as response: + # TODO: Remove as it masks error messages. response.raise_for_status() try: + # TODO: Simply use response.json() parsed_json = json.loads(await response.text()) except json.JSONDecodeError as e: raise json.JSONDecodeError( @@ -198,6 +200,7 @@ async def _invoke_tool( json=_convert_none_to_empty_string(data), headers=auth_tokens, ) as response: + # TODO: Remove as it masks error messages. response.raise_for_status() return await response.json() @@ -228,6 +231,17 @@ def _convert_none_to_empty_string(input_dict): def _find_auth_params( params: list[ParameterSchema], ) -> tuple[list[ParameterSchema], list[ParameterSchema]]: + """ + Separates parameters into those that are authenticated and those that are not. + + Args: + params: A list of ParameterSchema objects. + + Returns: + A tuple containing two lists: + - auth_params: A list of ParameterSchema objects that require authentication. + - non_auth_params: A list of ParameterSchema objects that do not require authentication. + """ _auth_params: list[ParameterSchema] = [] _non_auth_params: list[ParameterSchema] = [] @@ -243,6 +257,19 @@ def _find_auth_params( def _find_bound_params( params: list[ParameterSchema], bound_params: list[str] ) -> tuple[list[ParameterSchema], list[ParameterSchema]]: + """ + Separates parameters into those that are bound and those that are not. + + Args: + params: A list of ParameterSchema objects. + bound_params: A list of parameter names that are bound. + + Returns: + A tuple containing two lists: + - bound_params: A list of ParameterSchema objects whose names are in the bound_params list. + - non_bound_params: A list of ParameterSchema objects whose names are not in the bound_params list. + """ + _bound_params: list[ParameterSchema] = [] _non_bound_params: list[ParameterSchema] = [] diff --git a/tests/test_async_client.py b/tests/test_async_client.py new file mode 100644 index 00000000..18957cc5 --- /dev/null +++ b/tests/test_async_client.py @@ -0,0 +1,194 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from unittest.mock import AsyncMock, patch +from warnings import catch_warnings, simplefilter + +import pytest +from aiohttp import ClientSession + +from toolbox_langchain_sdk.async_client import AsyncToolboxClient +from toolbox_langchain_sdk.async_tools import AsyncToolboxTool +from toolbox_langchain_sdk.utils import ManifestSchema + +URL = "http://test_url" +MANIFEST_JSON = { + "serverVersion": "1.0.0", + "tools": { + "test_tool_1": { + "description": "Test Tool 1 Description", + "parameters": [ + { + "name": "param1", + "type": "string", + "description": "Param 1", + } + ], + }, + "test_tool_2": { + "description": "Test Tool 2 Description", + "parameters": [ + { + "name": "param2", + "type": "integer", + "description": "Param 2", + } + ], + }, + }, +} + + +@pytest.mark.asyncio +class TestAsyncToolboxClient: + @pytest.fixture() + def manifest_schema(self): + return ManifestSchema(**MANIFEST_JSON) + + @pytest.fixture() + def mock_session(self): + return AsyncMock(spec=ClientSession) + + @pytest.fixture() + def mock_client(self, mock_session): + return AsyncToolboxClient(URL, session=mock_session) + + async def test_create_with_existing_session(self, mock_client, mock_session): + assert mock_client._AsyncToolboxClient__session == mock_session + + @patch("toolbox_langchain_sdk.async_client._load_manifest") + async def test_aload_tool( + self, mock_load_manifest, mock_client, mock_session, manifest_schema + ): + tool_name = "test_tool_1" + mock_load_manifest.return_value = manifest_schema + + tool = await mock_client.aload_tool(tool_name) + + mock_load_manifest.assert_called_once_with( + f"{URL}/api/tool/{tool_name}", mock_session + ) + assert isinstance(tool, AsyncToolboxTool) + assert tool.name == tool_name + + @patch("toolbox_langchain_sdk.async_client._load_manifest") + async def test_aload_tool_auth_headers_deprecated( + self, mock_load_manifest, mock_client, manifest_schema + ): + tool_name = "test_tool_1" + mock_manifest = manifest_schema + mock_load_manifest.return_value = mock_manifest + with catch_warnings(record=True) as w: + simplefilter("always") + await mock_client.aload_tool( + tool_name, auth_headers={"Authorization": lambda: "Bearer token"} + ) + assert len(w) == 1 + assert issubclass(w[-1].category, DeprecationWarning) + assert "auth_headers" in str(w[-1].message) + + @patch("toolbox_langchain_sdk.async_client._load_manifest") + async def test_aload_tool_auth_headers_and_tokens( + self, mock_load_manifest, mock_client, manifest_schema + ): + tool_name = "test_tool_1" + mock_manifest = manifest_schema + mock_load_manifest.return_value = mock_manifest + with catch_warnings(record=True) as w: + simplefilter("always") + await mock_client.aload_tool( + tool_name, + auth_headers={"Authorization": lambda: "Bearer token"}, + auth_tokens={"test": lambda: "token"}, + ) + assert len(w) == 1 + assert issubclass(w[-1].category, DeprecationWarning) + assert "auth_headers" in str(w[-1].message) + + @patch("toolbox_langchain_sdk.async_client._load_manifest") + async def test_aload_toolset( + self, mock_load_manifest, mock_client, mock_session, manifest_schema + ): + mock_manifest = manifest_schema + mock_load_manifest.return_value = mock_manifest + tools = await mock_client.aload_toolset() + + mock_load_manifest.assert_called_once_with(f"{URL}/api/toolset/", mock_session) + assert len(tools) == 2 + for tool in tools: + assert isinstance(tool, AsyncToolboxTool) + assert tool.name in ["test_tool_1", "test_tool_2"] + + @patch("toolbox_langchain_sdk.async_client._load_manifest") + async def test_aload_toolset_with_toolset_name( + self, mock_load_manifest, mock_client, mock_session, manifest_schema + ): + toolset_name = "test_toolset_1" + mock_manifest = manifest_schema + mock_load_manifest.return_value = mock_manifest + tools = await mock_client.aload_toolset(toolset_name=toolset_name) + + mock_load_manifest.assert_called_once_with( + f"{URL}/api/toolset/{toolset_name}", mock_session + ) + assert len(tools) == 2 + for tool in tools: + assert isinstance(tool, AsyncToolboxTool) + assert tool.name in ["test_tool_1", "test_tool_2"] + + @patch("toolbox_langchain_sdk.async_client._load_manifest") + async def test_aload_toolset_auth_headers_deprecated( + self, mock_load_manifest, mock_client, manifest_schema + ): + mock_manifest = manifest_schema + mock_load_manifest.return_value = mock_manifest + with catch_warnings(record=True) as w: + simplefilter("always") + await mock_client.aload_toolset( + auth_headers={"Authorization": lambda: "Bearer token"} + ) + assert len(w) == 1 + assert issubclass(w[-1].category, DeprecationWarning) + assert "auth_headers" in str(w[-1].message) + + @patch("toolbox_langchain_sdk.async_client._load_manifest") + async def test_aload_toolset_auth_headers_and_tokens( + self, mock_load_manifest, mock_client, manifest_schema + ): + mock_manifest = manifest_schema + mock_load_manifest.return_value = mock_manifest + with catch_warnings(record=True) as w: + simplefilter("always") + await mock_client.aload_toolset( + auth_headers={"Authorization": lambda: "Bearer token"}, + auth_tokens={"test": lambda: "token"}, + ) + assert len(w) == 1 + assert issubclass(w[-1].category, DeprecationWarning) + assert "auth_headers" in str(w[-1].message) + + async def test_load_tool_not_implemented(self, mock_client): + with pytest.raises(NotImplementedError) as excinfo: + mock_client.load_tool("test_tool") + assert "Synchronous methods not supported by async client." in str( + excinfo.value + ) + + async def test_load_toolset_not_implemented(self, mock_client): + with pytest.raises(NotImplementedError) as excinfo: + mock_client.load_toolset() + assert "Synchronous methods not supported by async client." in str( + excinfo.value + ) diff --git a/tests/test_async_tools.py b/tests/test_async_tools.py new file mode 100644 index 00000000..7e98bcb7 --- /dev/null +++ b/tests/test_async_tools.py @@ -0,0 +1,270 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import AsyncMock, Mock, patch + +import pytest +import pytest_asyncio +from pydantic import ValidationError + +from toolbox_langchain_sdk.async_tools import AsyncToolboxTool + + +@pytest.mark.asyncio +class TestAsyncToolboxTool: + @pytest.fixture + def tool_schema(self): + return { + "description": "Test Tool Description", + "parameters": [ + {"name": "param1", "type": "string", "description": "Param 1"}, + {"name": "param2", "type": "integer", "description": "Param 2"}, + ], + } + + @pytest.fixture + def auth_tool_schema(self): + return { + "description": "Test Tool Description", + "parameters": [ + { + "name": "param1", + "type": "string", + "description": "Param 1", + "authSources": ["test-auth-source"], + }, + {"name": "param2", "type": "integer", "description": "Param 2"}, + ], + } + + @pytest_asyncio.fixture + @patch("aiohttp.ClientSession") + async def toolbox_tool(self, MockClientSession, tool_schema): + mock_session = MockClientSession.return_value + mock_session.post.return_value.__aenter__.return_value.raise_for_status = Mock() + mock_session.post.return_value.__aenter__.return_value.json = AsyncMock( + return_value={"result": "test-result"} + ) + tool = AsyncToolboxTool( + name="test_tool", + schema=tool_schema, + url="http://test_url", + session=mock_session, + ) + return tool + + @pytest_asyncio.fixture + @patch("aiohttp.ClientSession") + async def auth_toolbox_tool(self, MockClientSession, auth_tool_schema): + mock_session = MockClientSession.return_value + mock_session.post.return_value.__aenter__.return_value.raise_for_status = Mock() + mock_session.post.return_value.__aenter__.return_value.json = AsyncMock( + return_value={"result": "test-result"} + ) + with pytest.warns( + UserWarning, + match=r"Parameter\(s\) `param1` of tool test_tool require authentication", + ): + tool = AsyncToolboxTool( + name="test_tool", + schema=auth_tool_schema, + url="https://test-url", + session=mock_session, + ) + return tool + + @patch("aiohttp.ClientSession") + async def test_toolbox_tool_init(self, MockClientSession, tool_schema): + mock_session = MockClientSession.return_value + tool = AsyncToolboxTool( + name="test_tool", + schema=tool_schema, + url="https://test-url", + session=mock_session, + ) + assert tool.name == "test_tool" + assert tool.description == "Test Tool Description" + + @pytest.mark.parametrize( + "params, expected_bound_params", + [ + ({"param1": "bound-value"}, {"param1": "bound-value"}), + ({"param1": lambda: "bound-value"}, {"param1": lambda: "bound-value"}), + ( + {"param1": "bound-value", "param2": 123}, + {"param1": "bound-value", "param2": 123}, + ), + ], + ) + async def test_toolbox_tool_bind_params( + self, toolbox_tool, params, expected_bound_params + ): + tool = toolbox_tool.bind_params(params) + for key, value in expected_bound_params.items(): + if callable(value): + assert value() == tool._AsyncToolboxTool__bound_params[key]() + else: + assert value == tool._AsyncToolboxTool__bound_params[key] + + @pytest.mark.parametrize("strict", [True, False]) + async def test_toolbox_tool_bind_params_invalid(self, toolbox_tool, strict): + if strict: + with pytest.raises(ValueError) as e: + tool = toolbox_tool.bind_params( + {"param3": "bound-value"}, strict=strict + ) + assert "Parameter(s) param3 missing and cannot be bound." in str(e.value) + else: + with pytest.warns(UserWarning) as record: + tool = toolbox_tool.bind_params( + {"param3": "bound-value"}, strict=strict + ) + assert len(record) == 1 + assert "Parameter(s) param3 missing and cannot be bound." in str( + record[0].message + ) + + async def test_toolbox_tool_bind_params_duplicate(self, toolbox_tool): + tool = toolbox_tool.bind_params({"param1": "bound-value"}) + with pytest.raises(ValueError) as e: + tool = tool.bind_params({"param1": "bound-value"}) + assert "Parameter(s) `param1` already bound in tool `test_tool`." in str( + e.value + ) + + async def test_toolbox_tool_bind_params_invalid_params(self, auth_toolbox_tool): + with pytest.raises(ValueError) as e: + auth_toolbox_tool.bind_params({"param1": "bound-value"}) + assert "Parameter(s) param1 already authenticated and cannot be bound." in str( + e.value + ) + + @pytest.mark.parametrize( + "auth_tokens, expected_auth_tokens", + [ + ( + {"test-auth-source": lambda: "test-token"}, + {"test-auth-source": lambda: "test-token"}, + ), + ( + { + "test-auth-source": lambda: "test-token", + "another-auth-source": lambda: "another-token", + }, + { + "test-auth-source": lambda: "test-token", + "another-auth-source": lambda: "another-token", + }, + ), + ], + ) + async def test_toolbox_tool_add_auth_tokens( + self, auth_toolbox_tool, auth_tokens, expected_auth_tokens + ): + tool = auth_toolbox_tool.add_auth_tokens(auth_tokens) + for source, getter in expected_auth_tokens.items(): + assert tool._AsyncToolboxTool__auth_tokens[source]() == getter() + + async def test_toolbox_tool_add_auth_tokens_duplicate(self, auth_toolbox_tool): + tool = auth_toolbox_tool.add_auth_tokens( + {"test-auth-source": lambda: "test-token"} + ) + with pytest.raises(ValueError) as e: + tool = tool.add_auth_tokens({"test-auth-source": lambda: "test-token"}) + assert ( + "Authentication source(s) `test-auth-source` already registered in tool `test_tool`." + in str(e.value) + ) + + async def test_toolbox_tool_validate_auth_strict(self, auth_toolbox_tool): + with pytest.raises(PermissionError) as e: + auth_toolbox_tool._AsyncToolboxTool__validate_auth(strict=True) + assert "Parameter(s) `param1` of tool test_tool require authentication" in str( + e.value + ) + + async def test_toolbox_tool_call(self, toolbox_tool): + result = await toolbox_tool.ainvoke({"param1": "test-value", "param2": 123}) + assert result == {"result": "test-result"} + toolbox_tool._AsyncToolboxTool__session.post.assert_called_once_with( + "http://test_url/api/tool/test_tool/invoke", + json={"param1": "test-value", "param2": 123}, + headers={}, + ) + + @pytest.mark.parametrize( + "bound_param, expected_value", + [ + ({"param1": "bound-value"}, "bound-value"), + ({"param1": lambda: "dynamic-value"}, "dynamic-value"), + ], + ) + async def test_toolbox_tool_call_with_bound_params( + self, toolbox_tool, bound_param, expected_value + ): + tool = toolbox_tool.bind_params(bound_param) + result = await tool.ainvoke({"param2": 123}) + assert result == {"result": "test-result"} + toolbox_tool._AsyncToolboxTool__session.post.assert_called_once_with( + "http://test_url/api/tool/test_tool/invoke", + json={"param1": expected_value, "param2": 123}, + headers={}, + ) + + async def test_toolbox_tool_call_with_auth_tokens(self, auth_toolbox_tool): + tool = auth_toolbox_tool.add_auth_tokens( + {"test-auth-source": lambda: "test-token"} + ) + result = await tool.ainvoke({"param2": 123}) + assert result == {"result": "test-result"} + auth_toolbox_tool._AsyncToolboxTool__session.post.assert_called_once_with( + "https://test-url/api/tool/test_tool/invoke", + json={"param2": 123}, + headers={"test-auth-source_token": "test-token"}, + ) + + async def test_toolbox_tool_call_with_auth_tokens_insecure(self, auth_toolbox_tool): + with pytest.warns( + UserWarning, + match="Sending ID token over HTTP. User data may be exposed. Use HTTPS for secure communication.", + ): + auth_toolbox_tool._AsyncToolboxTool__url = "http://test-url" + tool = auth_toolbox_tool.add_auth_tokens( + {"test-auth-source": lambda: "test-token"} + ) + result = await tool.ainvoke({"param2": 123}) + assert result == {"result": "test-result"} + auth_toolbox_tool._AsyncToolboxTool__session.post.assert_called_once_with( + "http://test-url/api/tool/test_tool/invoke", + json={"param2": 123}, + headers={"test-auth-source_token": "test-token"}, + ) + + async def test_toolbox_tool_call_with_invalid_input(self, toolbox_tool): + with pytest.raises(ValidationError) as e: + await toolbox_tool.ainvoke({"param1": 123, "param2": "invalid"}) + assert "2 validation errors for test_tool" in str(e.value) + assert "param1\n Input should be a valid string" in str(e.value) + assert "param2\n Input should be a valid integer" in str(e.value) + + async def test_toolbox_tool_call_with_empty_input(self, toolbox_tool): + with pytest.raises(ValidationError) as e: + await toolbox_tool.ainvoke({}) + assert "2 validation errors for test_tool" in str(e.value) + assert "param1\n Field required" in str(e.value) + assert "param2\n Field required" in str(e.value) + + async def test_toolbox_tool_run_not_implemented(self, toolbox_tool): + with pytest.raises(NotImplementedError): + toolbox_tool._run() diff --git a/tests/test_client.py b/tests/test_client.py index bb3d9e17..184afc43 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -12,428 +12,212 @@ # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import AsyncMock, Mock, patch +from unittest.mock import Mock, patch import pytest -from aiohttp import ClientSession +from pydantic import BaseModel from toolbox_langchain_sdk.client import ToolboxClient -from toolbox_langchain_sdk.utils import ManifestSchema - - -@pytest.fixture -def manifest_schema(): - return ManifestSchema( - **{ - "serverVersion": "1.0.0", - "tools": { - "test_tool_1": { - "description": "Test Tool 1 Description", - "parameters": [ - {"name": "param1", "type": "string", "description": "Param 1"} - ], - }, - "test_tool_2": { - "description": "Test Tool 2 Description", - "parameters": [ - {"name": "param2", "type": "integer", "description": "Param 2"} - ], - }, - }, - } - ) - - -@pytest.fixture -def mock_auth_tokens(): - return {"test-auth-source": lambda: "test-token"} - - -@pytest.fixture -def mock_bound_params(): - return {"param1": "bound-value"} - - -@pytest.mark.asyncio -@patch("toolbox_langchain_sdk.client.ClientSession") -async def test_toolbox_client_init(mock_client): - client = ToolboxClient(url="https://test-url", session=mock_client) - assert client._url == "https://test-url" - assert client._session == mock_client - - -@pytest.fixture(params=[True, False]) -@patch("toolbox_langchain_sdk.client.ClientSession") -def toolbox_client(MockClientSession, request): - """ - Fixture to provide a ToolboxClient with and without a provided session. - """ - if request.param: - # Client with a provided session - session = MockClientSession.return_value - client = ToolboxClient(url="https://test-url", session=session) - yield client - else: - # Client that creates its own session - client = ToolboxClient(url="https://test-url") - yield client - - -@pytest.mark.asyncio -@patch("toolbox_langchain_sdk.client.ClientSession") -async def test_toolbox_client_close(MockClientSession, toolbox_client): - MockClientSession.return_value.close = AsyncMock() - for client in toolbox_client: - assert not client._session.close.called - await client.close() - if client._should_close_session: - # Assert session is closed only if it was created by the client - assert client._session.closed - else: - # Assert session is NOT closed if it was provided - assert not client._session.close.called - - -@pytest.mark.asyncio -@patch("toolbox_langchain_sdk.client.ClientSession") -async def test_toolbox_client_del(MockClientSession, toolbox_client): - MockClientSession.return_value.close = AsyncMock() - for client in toolbox_client: - client_session = client._session - assert not client_session.close.called - client.__del__() - assert not client_session.close.called - - -@pytest.mark.asyncio -@patch("toolbox_langchain_sdk.client._load_manifest") -async def test_toolbox_client_load_tool_manifest(mock_load_manifest): - mock_load_manifest.return_value = AsyncMock( - return_value={"tools": {"test_tool": {"description": "Test Tool Description"}}} - ) - async with ClientSession() as session: - client = ToolboxClient(url="https://test-url", session=session) - manifest = await client._load_tool_manifest("test_tool") - assert manifest == ( # Call the mock object to get its return value - mock_load_manifest.return_value # This will return the dictionary +from toolbox_langchain_sdk.tools import ToolboxTool + +URL = "http://test_url" + + +class TestToolboxClient: + @pytest.fixture() + def toolbox_client(self): + client = ToolboxClient(URL) + assert isinstance(client, ToolboxClient) + assert client._ToolboxClient__async_client is not None + + # Check that the background loop was created and started + assert client._ToolboxClient__loop is not None + assert client._ToolboxClient__loop.is_running() + + return client + + @patch("toolbox_langchain_sdk.client.AsyncToolboxClient.aload_tool") + def test_load_tool(self, mock_aload_tool, toolbox_client): + mock_tool = Mock(spec=ToolboxTool) + mock_tool.name = "mock-tool" + mock_tool.description = "mock description" + mock_tool.args_schema = BaseModel + mock_aload_tool.return_value = mock_tool + + tool = toolbox_client.load_tool("test_tool") + + assert tool.name == mock_tool.name + assert tool.description == mock_tool.description + assert tool.args_schema == mock_tool.args_schema + mock_aload_tool.assert_called_once_with("test_tool", {}, None, {}, True) + + @patch("toolbox_langchain_sdk.client.AsyncToolboxClient.aload_toolset") + def test_load_toolset(self, mock_aload_toolset, toolbox_client): + mock_tools = [Mock(spec=ToolboxTool), Mock(spec=ToolboxTool)] + mock_tools[0].name = "mock-tool-0" + mock_tools[0].description = "mock description 0" + mock_tools[0].args_schema = BaseModel + mock_tools[1].name = "mock-tool-1" + mock_tools[1].description = "mock description 1" + mock_tools[1].args_schema = BaseModel + mock_aload_toolset.return_value = mock_tools + + tools = toolbox_client.load_toolset() + assert len(tools) == len(mock_tools) + assert all( + a.name == b.name + and a.description == b.description + and a.args_schema == b.args_schema + for a, b in zip(tools, mock_tools) ) - mock_load_manifest.assert_called_once_with( - "https://test-url/api/tool/test_tool", session + mock_aload_toolset.assert_called_once_with(None, {}, None, {}, True) + + @pytest.mark.asyncio + @patch("toolbox_langchain_sdk.client.AsyncToolboxClient.aload_tool") + async def test_aload_tool(self, mock_aload_tool, toolbox_client): + mock_tool = Mock(spec=ToolboxTool) + mock_tool.name = "mock-tool" + mock_tool.description = "mock description" + mock_tool.args_schema = BaseModel + mock_aload_tool.return_value = mock_tool + + tool = await toolbox_client.aload_tool("test_tool") + assert tool.name == mock_tool.name + assert tool.description == mock_tool.description + assert tool.args_schema == mock_tool.args_schema + mock_aload_tool.assert_called_once_with("test_tool", {}, None, {}, True) + + @pytest.mark.asyncio + @patch("toolbox_langchain_sdk.client.AsyncToolboxClient.aload_toolset") + async def test_aload_toolset(self, mock_aload_toolset, toolbox_client): + mock_tools = [Mock(spec=ToolboxTool), Mock(spec=ToolboxTool)] + mock_tools[0].name = "mock-tool-0" + mock_tools[0].description = "mock description 0" + mock_tools[0].args_schema = BaseModel + mock_tools[1].name = "mock-tool-1" + mock_tools[1].description = "mock description 1" + mock_tools[1].args_schema = BaseModel + mock_aload_toolset.return_value = mock_tools + + tools = await toolbox_client.aload_toolset() + assert len(tools) == len(mock_tools) + assert all( + a.name == b.name + and a.description == b.description + and a.args_schema == b.args_schema + for a, b in zip(tools, mock_tools) ) - - -@pytest.mark.asyncio -@patch("toolbox_langchain_sdk.client._load_manifest") -async def test_toolbox_client_load_toolset_manifest(mock_load_manifest): - mock_load_manifest.return_value = AsyncMock( - return_value={"tools": {"test_tool": {"description": "Test Tool Description"}}} - ) - async with ClientSession() as session: - client = ToolboxClient(url="https://test-url", session=session) - manifest = await client._load_toolset_manifest("test_toolset") - assert manifest == ( # Call the mock object to get its return value - mock_load_manifest.return_value # This will return the dictionary + mock_aload_toolset.assert_called_once_with(None, {}, None, {}, True) + + @patch("toolbox_langchain_sdk.client.AsyncToolboxClient.aload_tool") + def test_load_tool_with_args(self, mock_aload_tool, toolbox_client): + mock_tool = Mock(spec=ToolboxTool) + mock_tool.name = "mock-tool" + mock_tool.description = "mock description" + mock_tool.args_schema = BaseModel + mock_aload_tool.return_value = mock_tool + auth_tokens = {"token1": lambda: "value1"} + auth_headers = {"header1": lambda: "value2"} + bound_params = {"param1": "value3"} + + tool = toolbox_client.load_tool( + "test_tool_name", + auth_tokens=auth_tokens, + auth_headers=auth_headers, + bound_params=bound_params, + strict=False, ) - mock_load_manifest.assert_called_once_with( - "https://test-url/api/toolset/test_toolset", session + + assert tool.name == mock_tool.name + assert tool.description == mock_tool.description + assert tool.args_schema == mock_tool.args_schema + mock_aload_tool.assert_called_once_with( + "test_tool_name", auth_tokens, auth_headers, bound_params, False ) + @patch("toolbox_langchain_sdk.client.AsyncToolboxClient.aload_toolset") + def test_load_toolset_with_args(self, mock_aload_toolset, toolbox_client): + mock_tools = [Mock(spec=ToolboxTool), Mock(spec=ToolboxTool)] + mock_tools[0].name = "mock-tool-0" + mock_tools[0].description = "mock description 0" + mock_tools[0].args_schema = BaseModel + mock_tools[1].name = "mock-tool-1" + mock_tools[1].description = "mock description 1" + mock_tools[1].args_schema = BaseModel + mock_aload_toolset.return_value = mock_tools + + auth_tokens = {"token1": lambda: "value1"} + auth_headers = {"header1": lambda: "value2"} + bound_params = {"param1": "value3"} + + tools = toolbox_client.load_toolset( + toolset_name="my_toolset", + auth_tokens=auth_tokens, + auth_headers=auth_headers, + bound_params=bound_params, + strict=False, + ) -@pytest.mark.asyncio -@patch("toolbox_langchain_sdk.client._load_manifest") -async def test_toolbox_client_load_toolset_manifest_no_toolset(mock_load_manifest): - mock_load_manifest.return_value = AsyncMock( - return_value={"tools": {"test_tool": {"description": "Test Tool Description"}}} - ) - async with ClientSession() as session: - client = ToolboxClient(url="https://test-url", session=session) - manifest = await client._load_toolset_manifest() - assert manifest == ( # Call the mock object to get its return value - mock_load_manifest.return_value # This will return the dictionary + assert len(tools) == len(mock_tools) + assert all( + a.name == b.name + and a.description == b.description + and a.args_schema == b.args_schema + for a, b in zip(tools, mock_tools) ) - mock_load_manifest.assert_called_once_with( - "https://test-url/api/toolset/", session + mock_aload_toolset.assert_called_once_with( + "my_toolset", auth_tokens, auth_headers, bound_params, False ) - -@pytest.mark.asyncio -@patch("toolbox_langchain_sdk.client.ToolboxTool") -@patch("toolbox_langchain_sdk.client._load_manifest") -async def test_toolbox_client_load_tool(mock_load_manifest, MockToolboxTool): - mock_load_manifest.return_value = AsyncMock( - return_value={"tools": {"test_tool": {"description": "Test Tool Description"}}} - ) - async with ClientSession() as session: - client = ToolboxClient(url="https://test-url", session=session) - tool = await client.load_tool("test_tool") - assert tool == MockToolboxTool.return_value - MockToolboxTool.assert_called_once_with( - "test_tool", - mock_load_manifest.return_value.tools.__getitem__( - "test_tool" - ), # Correctly access the tool schema - "https://test-url", - session, - {}, - {}, - True, + @pytest.mark.asyncio + @patch("toolbox_langchain_sdk.client.AsyncToolboxClient.aload_tool") + async def test_aload_tool_with_args(self, mock_aload_tool, toolbox_client): + mock_tool = Mock(spec=ToolboxTool) + mock_tool.name = "mock-tool" + mock_tool.description = "mock description" + mock_tool.args_schema = BaseModel + mock_aload_tool.return_value = mock_tool + + auth_tokens = {"token1": lambda: "value1"} + auth_headers = {"header1": lambda: "value2"} + bound_params = {"param1": "value3"} + + tool = await toolbox_client.aload_tool( + "test_tool", auth_tokens, auth_headers, bound_params, False ) - - -@pytest.mark.asyncio -@patch("toolbox_langchain_sdk.client.ToolboxTool") -@patch("toolbox_langchain_sdk.client._load_manifest") -async def test_toolbox_client_load_tool_with_auth( - mock_load_manifest, MockToolboxTool, mock_auth_tokens -): - mock_load_manifest.return_value = AsyncMock( - return_value={"tools": {"test_tool": {"description": "Test Tool Description"}}} - ) - async with ClientSession() as session: - client = ToolboxClient(url="https://test-url", session=session) - tool = await client.load_tool("test_tool", auth_tokens=mock_auth_tokens) - assert tool == MockToolboxTool.return_value - MockToolboxTool.assert_called_once_with( - "test_tool", - mock_load_manifest.return_value.tools.__getitem__("test_tool"), - "https://test-url", - session, - mock_auth_tokens, - {}, - True, + assert tool.name == mock_tool.name + assert tool.description == mock_tool.description + assert tool.args_schema == mock_tool.args_schema + mock_aload_tool.assert_called_once_with( + "test_tool", auth_tokens, auth_headers, bound_params, False ) - -@pytest.mark.asyncio -@patch("toolbox_langchain_sdk.client.ToolboxTool") -@patch("toolbox_langchain_sdk.client._load_manifest") -async def test_toolbox_client_load_tool_with_auth_headers( - mock_load_manifest, MockToolboxTool, mock_auth_tokens -): - mock_load_manifest.return_value = AsyncMock( - return_value={"tools": {"test_tool": {"description": "Test Tool Description"}}} - ) - async with ClientSession() as session: - client = ToolboxClient(url="https://test-url", session=session) - with pytest.warns( - DeprecationWarning, - match="Argument `auth_headers` is deprecated. Use `auth_tokens` instead.", - ): - tool = await client.load_tool("test_tool", auth_headers=mock_auth_tokens) - assert tool == MockToolboxTool.return_value - MockToolboxTool.assert_called_once_with( - "test_tool", - mock_load_manifest.return_value.tools.__getitem__("test_tool"), - "https://test-url", - session, - mock_auth_tokens, - {}, - True, + @pytest.mark.asyncio + @patch("toolbox_langchain_sdk.client.AsyncToolboxClient.aload_toolset") + async def test_aload_toolset_with_args(self, mock_aload_toolset, toolbox_client): + mock_tools = [Mock(spec=ToolboxTool), Mock(spec=ToolboxTool)] + mock_tools[0].name = "mock-tool-0" + mock_tools[0].description = "mock description 0" + mock_tools[0].args_schema = BaseModel + mock_tools[1].name = "mock-tool-1" + mock_tools[1].description = "mock description 1" + mock_tools[1].args_schema = BaseModel + mock_aload_toolset.return_value = mock_tools + + auth_tokens = {"token1": lambda: "value1"} + auth_headers = {"header1": lambda: "value2"} + bound_params = {"param1": "value3"} + + tools = await toolbox_client.aload_toolset( + "my_toolset", auth_tokens, auth_headers, bound_params, False ) - - -@pytest.mark.asyncio -@patch("toolbox_langchain_sdk.client.ToolboxTool") -@patch("toolbox_langchain_sdk.client._load_manifest") -async def test_toolbox_client_load_tool_with_auth_and_headers( - mock_load_manifest, MockToolboxTool, mock_auth_tokens -): - mock_load_manifest.return_value = AsyncMock( - return_value={"tools": {"test_tool": {"description": "Test Tool Description"}}} - ) - async with ClientSession() as session: - client = ToolboxClient(url="https://test-url", session=session) - with pytest.warns( - DeprecationWarning, - match="Both `auth_tokens` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_tokens` will be used.", - ): - tool = await client.load_tool( - "test_tool", auth_tokens=mock_auth_tokens, auth_headers=mock_auth_tokens - ) - assert tool == MockToolboxTool.return_value - MockToolboxTool.assert_called_once_with( - "test_tool", - mock_load_manifest.return_value.tools.__getitem__("test_tool"), - "https://test-url", - session, - mock_auth_tokens, - {}, - True, + assert len(tools) == len(mock_tools) + assert all( + a.name == b.name + and a.description == b.description + and a.args_schema == b.args_schema + for a, b in zip(tools, mock_tools) ) - - -@pytest.mark.asyncio -@patch("toolbox_langchain_sdk.client.ToolboxTool") -@patch("toolbox_langchain_sdk.client._load_manifest") -async def test_toolbox_client_load_tool_with_bound_params( - mock_load_manifest, MockToolboxTool, mock_bound_params -): - mock_load_manifest.return_value = AsyncMock( - return_value={"tools": {"test_tool": {"description": "Test Tool Description"}}} - ) - async with ClientSession() as session: - client = ToolboxClient(url="https://test-url", session=session) - tool = await client.load_tool("test_tool", bound_params=mock_bound_params) - assert tool == MockToolboxTool.return_value - MockToolboxTool.assert_called_once_with( - "test_tool", - mock_load_manifest.return_value.tools.__getitem__("test_tool"), - "https://test-url", - session, - {}, - mock_bound_params, - True, + mock_aload_toolset.assert_called_once_with( + "my_toolset", auth_tokens, auth_headers, bound_params, False ) - - -@pytest.mark.asyncio -@patch("toolbox_langchain_sdk.client._load_manifest") -async def test_toolbox_client_load_toolset( - mock_load_manifest, toolbox_client, manifest_schema -): - mock_load_manifest.return_value = manifest_schema - for client in toolbox_client: - tools = await client.load_toolset() - assert [tool._schema for tool in tools] == list(manifest_schema.tools.values()) - - -@pytest.mark.asyncio -@patch("toolbox_langchain_sdk.client.ToolboxTool") -@patch("toolbox_langchain_sdk.client._load_manifest") -async def test_toolbox_client_load_toolset_with_auth( - mock_load_manifest, - mock_toolbox_tool, - toolbox_client, - manifest_schema, - mock_auth_tokens, -): - mock_load_manifest.return_value = manifest_schema - for client in toolbox_client: - tools = await client.load_toolset(auth_tokens=mock_auth_tokens) - - for i, (tool_name, tool_schema) in enumerate(manifest_schema.tools.items()): - call_args, _ = mock_toolbox_tool.call_args_list[i] - assert call_args[0] == tool_name - assert call_args[1] == tool_schema - assert call_args[2] == client._url - assert call_args[3] == client._session - assert call_args[4] == mock_auth_tokens - assert call_args[5] == {} - - assert len(tools) == len(manifest_schema.tools) - - -@pytest.mark.asyncio -@patch("toolbox_langchain_sdk.client.ToolboxTool") -@patch("toolbox_langchain_sdk.client._load_manifest") -async def test_toolbox_client_load_toolset_with_auth_headers( - mock_load_manifest, - mock_toolbox_tool, - toolbox_client, - manifest_schema, - mock_auth_tokens, -): - mock_load_manifest.return_value = manifest_schema - for client in toolbox_client: - with pytest.warns( - DeprecationWarning, - match="Argument `auth_headers` is deprecated. Use `auth_tokens` instead.", - ): - tools = await client.load_toolset(auth_headers=mock_auth_tokens) - - for i, (tool_name, tool_schema) in enumerate(manifest_schema.tools.items()): - call_args, _ = mock_toolbox_tool.call_args_list[i] - assert call_args[0] == tool_name - assert call_args[1] == tool_schema - assert call_args[2] == client._url - assert call_args[3] == client._session - assert call_args[4] == mock_auth_tokens - assert call_args[5] == {} - - assert len(tools) == len(manifest_schema.tools) - - -@pytest.mark.asyncio -@patch("toolbox_langchain_sdk.client.ToolboxTool") -@patch("toolbox_langchain_sdk.client._load_manifest") -async def test_toolbox_client_load_toolset_with_auth_and_headers( - mock_load_manifest, - mock_toolbox_tool, - toolbox_client, - manifest_schema, - mock_auth_tokens, -): - mock_load_manifest.return_value = manifest_schema - for client in toolbox_client: - with pytest.warns( - DeprecationWarning, - match="Both `auth_tokens` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_tokens` will be used.", - ): - tools = await client.load_toolset( - auth_tokens=mock_auth_tokens, auth_headers=mock_auth_tokens - ) - - for i, (tool_name, tool_schema) in enumerate(manifest_schema.tools.items()): - call_args, _ = mock_toolbox_tool.call_args_list[i] - assert call_args[0] == tool_name - assert call_args[1] == tool_schema - assert call_args[2] == client._url - assert call_args[3] == client._session - assert call_args[4] == mock_auth_tokens - assert call_args[5] == {} - - assert len(tools) == len(manifest_schema.tools) - - -@pytest.mark.asyncio -@patch("toolbox_langchain_sdk.client.ToolboxTool") -@patch("toolbox_langchain_sdk.client._load_manifest") -async def test_toolbox_client_load_toolset_with_bound_params( - mock_load_manifest, - mock_toolbox_tool, - toolbox_client, - manifest_schema, - mock_bound_params, -): - mock_load_manifest.return_value = manifest_schema - for client in toolbox_client: - tools = await client.load_toolset(bound_params=mock_bound_params) - - for i, (tool_name, tool_schema) in enumerate(manifest_schema.tools.items()): - call_args, _ = mock_toolbox_tool.call_args_list[i] - assert call_args[0] == tool_name - assert call_args[1] == tool_schema - assert call_args[2] == client._url - assert call_args[3] == client._session - assert call_args[4] == {} - assert call_args[5] == mock_bound_params - - assert len(tools) == len(manifest_schema.tools) - - -@pytest.mark.asyncio -async def test_toolbox_client_del_loop_not_running(): - """Test __del__ when the loop is not running.""" - mock_loop = Mock() - mock_loop.is_running.return_value = False - mock_close = Mock(spec=ToolboxClient.close) - - with patch("asyncio.get_event_loop", return_value=mock_loop): - client = ToolboxClient(url="https://test-url") - client.close = mock_close - client.__del__() - - -@pytest.mark.asyncio -async def test_toolbox_client_del_exception(): - """Test __del__ when an exception occurs.""" - mock_loop = Mock() - mock_loop.is_running.return_value = True - mock_loop.create_task.side_effect = Exception("Test Exception") - - with patch("asyncio.get_event_loop", return_value=mock_loop): - client = ToolboxClient(url="https://test-url") - client.__del__() - - # Assert that create_task was called (despite the exception) - mock_loop.create_task.assert_called_once() diff --git a/tests/test_e2e.py b/tests/test_e2e.py index e5e63785..946b12d2 100644 --- a/tests/test_e2e.py +++ b/tests/test_e2e.py @@ -44,40 +44,38 @@ @pytest.mark.asyncio @pytest.mark.usefixtures("toolbox_server") -class TestE2EClient: - @pytest_asyncio.fixture(scope="function") - async def toolbox(self): +class TestE2EClientAsync: + @pytest.fixture(scope="function") + def toolbox(self): """Provides a ToolboxClient instance for each test.""" toolbox = ToolboxClient("http://localhost:5000") - yield toolbox - await toolbox.close() - - #### Basic e2e tests - @pytest.mark.asyncio - async def test_load_tool(self, toolbox): - tool = await toolbox.load_tool("get-n-rows") - response = await tool.ainvoke({"num_rows": "2"}) - result = response["result"] + return toolbox - assert "row1" in result - assert "row2" in result - assert "row3" not in result - - @pytest.mark.asyncio - async def test_load_toolset_specific(self, toolbox): - toolset = await toolbox.load_toolset("my-toolset") - assert len(toolset) == 1 - assert toolset[0].name == "get-row-by-id" + @pytest_asyncio.fixture(scope="function") + async def get_n_rows_tool(self, toolbox): + tool = await toolbox.aload_tool("get-n-rows") + assert tool._ToolboxTool__async_tool._AsyncToolboxTool__name == "get-n-rows" + return tool - toolset = await toolbox.load_toolset("my-toolset-2") - assert len(toolset) == 2 - tool_names = ["get-n-rows", "get-row-by-id"] - assert toolset[0].name in tool_names - assert toolset[1].name in tool_names + #### Basic e2e tests + @pytest.mark.parametrize( + "toolset_name, expected_length, expected_tools", + [ + ("my-toolset", 1, ["get-row-by-id"]), + ("my-toolset-2", 2, ["get-n-rows", "get-row-by-id"]), + ], + ) + async def test_aload_toolset_specific( + self, toolbox, toolset_name, expected_length, expected_tools + ): + toolset = await toolbox.aload_toolset(toolset_name) + assert len(toolset) == expected_length + for tool in toolset: + name = tool._ToolboxTool__async_tool._AsyncToolboxTool__name + assert name in expected_tools - @pytest.mark.asyncio - async def test_load_toolset_all(self, toolbox): - toolset = await toolbox.load_toolset() + async def test_aload_toolset_all(self, toolbox): + toolset = await toolbox.aload_toolset() assert len(toolset) == 5 tool_names = [ "get-n-rows", @@ -87,44 +85,54 @@ async def test_load_toolset_all(self, toolbox): "get-row-by-content-auth", ] for tool in toolset: - assert tool.name in tool_names + name = tool._ToolboxTool__async_tool._AsyncToolboxTool__name + assert name in tool_names - @pytest.mark.asyncio - async def test_run_tool_missing_params(self, toolbox): - tool = await toolbox.load_tool("get-n-rows") + async def test_run_tool_async(self, get_n_rows_tool): + response = await get_n_rows_tool.ainvoke({"num_rows": "2"}) + result = response["result"] + + assert "row1" in result + assert "row2" in result + assert "row3" not in result + + async def test_run_tool_sync(self, get_n_rows_tool): + response = get_n_rows_tool.invoke({"num_rows": "2"}) + result = response["result"] + + assert "row1" in result + assert "row2" in result + assert "row3" not in result + + async def test_run_tool_missing_params(self, get_n_rows_tool): with pytest.raises(ValidationError, match="Field required"): - await tool.ainvoke({}) + await get_n_rows_tool.ainvoke({}) - @pytest.mark.asyncio - async def test_run_tool_wrong_param_type(self, toolbox): - tool = await toolbox.load_tool("get-n-rows") + async def test_run_tool_wrong_param_type(self, get_n_rows_tool): with pytest.raises(ValidationError, match="Input should be a valid string"): - await tool.ainvoke({"num_rows": 2}) + await get_n_rows_tool.ainvoke({"num_rows": 2}) ##### Auth tests - @pytest.mark.asyncio @pytest.mark.skip(reason="b/389574566") async def test_run_tool_unauth_with_auth(self, toolbox, auth_token2): """Tests running a tool that doesn't require auth, with auth provided.""" - tool = await toolbox.load_tool( + tool = await toolbox.aload_tool( "get-row-by-id", auth_tokens={"my-test-auth": lambda: auth_token2} ) response = await tool.ainvoke({"id": "2"}) assert "row2" in response["result"] - @pytest.mark.asyncio async def test_run_tool_no_auth(self, toolbox): """Tests running a tool requiring auth without providing auth.""" - tool = await toolbox.load_tool( + tool = await toolbox.aload_tool( "get-row-by-id-auth", ) with pytest.raises(ClientResponseError, match="401, message='Unauthorized'"): await tool.ainvoke({"id": "2"}) - @pytest.mark.asyncio async def test_run_tool_wrong_auth(self, toolbox, auth_token2): """Tests running a tool with incorrect auth.""" - tool = await toolbox.load_tool( + tool = await toolbox.aload_tool( "get-row-by-id-auth", ) auth_tool = tool.add_auth_token("my-test-auth", lambda: auth_token2) @@ -132,30 +140,27 @@ async def test_run_tool_wrong_auth(self, toolbox, auth_token2): with pytest.raises(ClientResponseError, match="400, message='Bad Request'"): await auth_tool.ainvoke({"id": "2"}) - @pytest.mark.asyncio async def test_run_tool_auth(self, toolbox, auth_token1): """Tests running a tool with correct auth.""" - tool = await toolbox.load_tool( + tool = await toolbox.aload_tool( "get-row-by-id-auth", ) auth_tool = tool.add_auth_token("my-test-auth", lambda: auth_token1) response = await auth_tool.ainvoke({"id": "2"}) assert "row2" in response["result"] - @pytest.mark.asyncio async def test_run_tool_param_auth_no_auth(self, toolbox): """Tests running a tool with a param requiring auth, without auth.""" - tool = await toolbox.load_tool("get-row-by-email-auth") + tool = await toolbox.aload_tool("get-row-by-email-auth") with pytest.raises( PermissionError, match="Parameter\(s\) `email` of tool get-row-by-email-auth require authentication\, but no valid authentication sources are registered\. Please register the required sources before use\.", ): - await tool.ainvoke({}) + await tool.ainvoke({"email": ""}) - @pytest.mark.asyncio async def test_run_tool_param_auth(self, toolbox, auth_token1): """Tests running a tool with a param requiring auth, with correct auth.""" - tool = await toolbox.load_tool( + tool = await toolbox.aload_tool( "get-row-by-email-auth", auth_tokens={"my-test-auth": lambda: auth_token1} ) response = await tool.ainvoke({}) @@ -164,11 +169,146 @@ async def test_run_tool_param_auth(self, toolbox, auth_token1): assert "row5" in result assert "row6" in result - @pytest.mark.asyncio async def test_run_tool_param_auth_no_field(self, toolbox, auth_token1): """Tests running a tool with a param requiring auth, with insufficient auth.""" - tool = await toolbox.load_tool( + tool = await toolbox.aload_tool( "get-row-by-content-auth", auth_tokens={"my-test-auth": lambda: auth_token1} ) with pytest.raises(ClientResponseError, match="400, message='Bad Request'"): await tool.ainvoke({}) + + +@pytest.mark.usefixtures("toolbox_server") +class TestE2EClientSync: + @pytest.fixture(scope="session") + def toolbox(self): + """Provides a ToolboxClient instance for each test.""" + toolbox = ToolboxClient("http://localhost:5000") + return toolbox + + @pytest.fixture(scope="function") + def get_n_rows_tool(self, toolbox): + tool = toolbox.load_tool("get-n-rows") + assert tool._ToolboxTool__async_tool._AsyncToolboxTool__name == "get-n-rows" + return tool + + #### Basic e2e tests + @pytest.mark.parametrize( + "toolset_name, expected_length, expected_tools", + [ + ("my-toolset", 1, ["get-row-by-id"]), + ("my-toolset-2", 2, ["get-n-rows", "get-row-by-id"]), + ], + ) + def test_load_toolset_specific( + self, toolbox, toolset_name, expected_length, expected_tools + ): + toolset = toolbox.load_toolset(toolset_name) + assert len(toolset) == expected_length + for tool in toolset: + name = tool._ToolboxTool__async_tool._AsyncToolboxTool__name + assert name in expected_tools + + def test_aload_toolset_all(self, toolbox): + toolset = toolbox.load_toolset() + assert len(toolset) == 5 + tool_names = [ + "get-n-rows", + "get-row-by-id", + "get-row-by-id-auth", + "get-row-by-email-auth", + "get-row-by-content-auth", + ] + for tool in toolset: + name = tool._ToolboxTool__async_tool._AsyncToolboxTool__name + assert name in tool_names + + @pytest.mark.asyncio + async def test_run_tool_async(self, get_n_rows_tool): + response = await get_n_rows_tool.ainvoke({"num_rows": "2"}) + result = response["result"] + + assert "row1" in result + assert "row2" in result + assert "row3" not in result + + def test_run_tool_sync(self, get_n_rows_tool): + response = get_n_rows_tool.invoke({"num_rows": "2"}) + result = response["result"] + + assert "row1" in result + assert "row2" in result + assert "row3" not in result + + def test_run_tool_missing_params(self, get_n_rows_tool): + with pytest.raises(ValidationError, match="Field required"): + get_n_rows_tool.invoke({}) + + def test_run_tool_wrong_param_type(self, get_n_rows_tool): + with pytest.raises(ValidationError, match="Input should be a valid string"): + get_n_rows_tool.invoke({"num_rows": 2}) + + #### Auth tests + @pytest.mark.skip(reason="b/389574566") + def test_run_tool_unauth_with_auth(self, toolbox, auth_token2): + """Tests running a tool that doesn't require auth, with auth provided.""" + tool = toolbox.load_tool( + "get-row-by-id", auth_tokens={"my-test-auth": lambda: auth_token2} + ) + response = tool.invoke({"id": "2"}) + assert "row2" in response["result"] + + def test_run_tool_no_auth(self, toolbox): + """Tests running a tool requiring auth without providing auth.""" + tool = toolbox.load_tool( + "get-row-by-id-auth", + ) + with pytest.raises(ClientResponseError, match="401, message='Unauthorized'"): + tool.invoke({"id": "2"}) + + def test_run_tool_wrong_auth(self, toolbox, auth_token2): + """Tests running a tool with incorrect auth.""" + tool = toolbox.load_tool( + "get-row-by-id-auth", + ) + auth_tool = tool.add_auth_token("my-test-auth", lambda: auth_token2) + # TODO: Fix error message (b/389577313) + with pytest.raises(ClientResponseError, match="400, message='Bad Request'"): + auth_tool.invoke({"id": "2"}) + + def test_run_tool_auth(self, toolbox, auth_token1): + """Tests running a tool with correct auth.""" + tool = toolbox.load_tool( + "get-row-by-id-auth", + ) + auth_tool = tool.add_auth_token("my-test-auth", lambda: auth_token1) + response = auth_tool.invoke({"id": "2"}) + assert "row2" in response["result"] + + def test_run_tool_param_auth_no_auth(self, toolbox): + """Tests running a tool with a param requiring auth, without auth.""" + tool = toolbox.load_tool("get-row-by-email-auth") + with pytest.raises( + PermissionError, + match="Parameter\(s\) `email` of tool get-row-by-email-auth require authentication\, but no valid authentication sources are registered\. Please register the required sources before use\.", + ): + tool.invoke({"email": ""}) + + def test_run_tool_param_auth(self, toolbox, auth_token1): + """Tests running a tool with a param requiring auth, with correct auth.""" + tool = toolbox.load_tool( + "get-row-by-email-auth", auth_tokens={"my-test-auth": lambda: auth_token1} + ) + response = tool.invoke({}) + result = response["result"] + assert "row4" in result + assert "row5" in result + assert "row6" in result + + def test_run_tool_param_auth_no_field(self, toolbox, auth_token1): + """Tests running a tool with a param requiring auth, with insufficient auth.""" + tool = toolbox.load_tool( + "get-row-by-content-auth", auth_tokens={"my-test-auth": lambda: auth_token1} + ) + with pytest.raises(ClientResponseError, match="400, message='Bad Request'"): + tool.invoke({}) diff --git a/tests/test_tools.py b/tests/test_tools.py index 4709afe9..1ecd61e6 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -1,4 +1,4 @@ -# Copyright 2024 Google LLC +# Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,311 +15,218 @@ from unittest.mock import AsyncMock, Mock, patch import pytest -from pydantic import ValidationError +from pydantic import BaseModel +from toolbox_langchain_sdk.async_tools import AsyncToolboxTool from toolbox_langchain_sdk.tools import ToolboxTool -@pytest.fixture -def tool_schema(): - return { - "description": "Test Tool Description", - "parameters": [ - {"name": "param1", "type": "string", "description": "Param 1"}, - {"name": "param2", "type": "integer", "description": "Param 2"}, - ], - } - - -@pytest.fixture -def auth_tool_schema(): - return { - "description": "Test Tool Description", - "parameters": [ - { - "name": "param1", - "type": "string", - "description": "Param 1", - "authSources": ["test-auth-source"], - }, - {"name": "param2", "type": "integer", "description": "Param 2"}, - ], - } - - -@pytest.fixture -@patch("aiohttp.ClientSession") -async def toolbox_tool(MockClientSession, tool_schema): - mock_session = MockClientSession.return_value - mock_session.post.return_value.__aenter__.return_value.raise_for_status = Mock() - mock_session.post.return_value.__aenter__.return_value.json = AsyncMock( - return_value={"result": "test-result"} - ) - tool = ToolboxTool( - name="test_tool", - schema=tool_schema, - url="https://test-url", - session=mock_session, - ) - yield tool +class TestToolboxTool: + @pytest.fixture + def tool_schema(self): + return { + "description": "Test Tool Description", + "name": "test_tool", + "parameters": [ + {"name": "param1", "type": "string", "description": "Param 1"}, + {"name": "param2", "type": "integer", "description": "Param 2"}, + ], + } + + @pytest.fixture + def auth_tool_schema(self): + return { + "description": "Test Tool Description", + "name": "test_tool", + "parameters": [ + { + "name": "param1", + "type": "string", + "description": "Param 1", + "authSources": ["test-auth-source"], + }, + {"name": "param2", "type": "integer", "description": "Param 2"}, + ], + } + + @pytest.fixture(scope="function") + def mock_async_tool(self, tool_schema): + mock_async_tool = Mock(spec=AsyncToolboxTool) + mock_async_tool.name = "test_tool" + mock_async_tool.description = "test description" + mock_async_tool.args_schema = BaseModel + mock_async_tool._AsyncToolboxTool__name = "test_tool" + mock_async_tool._AsyncToolboxTool__schema = tool_schema + mock_async_tool._AsyncToolboxTool__url = "http://test_url" + mock_async_tool._AsyncToolboxTool__session = Mock() + mock_async_tool._AsyncToolboxTool__auth_tokens = {} + mock_async_tool._AsyncToolboxTool__bound_params = {} + return mock_async_tool + + @pytest.fixture(scope="function") + def mock_async_auth_tool(self, auth_tool_schema): + mock_async_tool = Mock(spec=AsyncToolboxTool) + mock_async_tool.name = "test_tool" + mock_async_tool.description = "test description" + mock_async_tool.args_schema = BaseModel + mock_async_tool._AsyncToolboxTool__name = "test_tool" + mock_async_tool._AsyncToolboxTool__schema = auth_tool_schema + mock_async_tool._AsyncToolboxTool__url = "http://test_url" + mock_async_tool._AsyncToolboxTool__session = Mock() + mock_async_tool._AsyncToolboxTool__auth_tokens = {} + mock_async_tool._AsyncToolboxTool__bound_params = {} + return mock_async_tool + + @pytest.fixture + def toolbox_tool(self, mock_async_tool): + return ToolboxTool( + async_tool=mock_async_tool, + loop=Mock(), + thread=Mock(), + ) + @pytest.fixture + def auth_toolbox_tool(self, mock_async_auth_tool): + return ToolboxTool( + async_tool=mock_async_auth_tool, + loop=Mock(), + thread=Mock(), + ) -@pytest.fixture -@patch("aiohttp.ClientSession") -async def auth_toolbox_tool(MockClientSession, auth_tool_schema): - mock_session = MockClientSession.return_value - mock_session.post.return_value.__aenter__.return_value.raise_for_status = Mock() - mock_session.post.return_value.__aenter__.return_value.json = AsyncMock( - return_value={"result": "test-result"} - ) - with pytest.warns( - UserWarning, - match="Parameter\(s\) \`param1\` of tool test_tool require authentication\, but no valid authentication sources are registered\. Please register the required sources before use\.", - ): + def test_toolbox_tool_init(self, mock_async_tool): tool = ToolboxTool( - name="test_tool", - schema=auth_tool_schema, - url="https://test-url", - session=mock_session, + async_tool=mock_async_tool, + loop=Mock(), + thread=Mock(), ) - yield tool - - -@pytest.mark.asyncio -@patch("toolbox_langchain_sdk.client.ClientSession") -async def test_toolbox_tool_init(MockClientSession, tool_schema): - mock_session = MockClientSession.return_value - tool = ToolboxTool( - name="test_tool", - schema=tool_schema, - url="https://test-url", - session=mock_session, + async_tool = tool._ToolboxTool__async_tool + assert async_tool.name == mock_async_tool.name + assert async_tool.description == mock_async_tool.description + assert async_tool.args_schema == mock_async_tool.args_schema + + @pytest.mark.parametrize( + "params, expected_bound_params", + [ + ({"param1": "bound-value"}, {"param1": "bound-value"}), + ({"param1": lambda: "bound-value"}, {"param1": lambda: "bound-value"}), + ( + {"param1": "bound-value", "param2": 123}, + {"param1": "bound-value", "param2": 123}, + ), + ], ) - assert tool.name == "test_tool" - assert tool.description == "Test Tool Description" + def test_toolbox_tool_bind_params( + self, + params, + expected_bound_params, + toolbox_tool, + mock_async_tool, + ): + mock_async_tool._AsyncToolboxTool__bound_params = expected_bound_params + mock_async_tool.bind_params.return_value = mock_async_tool + tool = toolbox_tool.bind_params(params) + mock_async_tool.bind_params.assert_called_once_with(params, True) + assert isinstance(tool, ToolboxTool) -@pytest.mark.asyncio -@pytest.mark.parametrize( - "params, expected_bound_params", - [ - ({"param1": "bound-value"}, {"param1": "bound-value"}), - ({"param1": lambda: "bound-value"}, {"param1": lambda: "bound-value"}), - ( - {"param1": "bound-value", "param2": 123}, - {"param1": "bound-value", "param2": 123}, - ), - ], -) -async def test_toolbox_tool_bind_params(toolbox_tool, params, expected_bound_params): - async for tool in toolbox_tool: - tool = tool.bind_params(params) for key, value in expected_bound_params.items(): + async_tool_bound_param_val = ( + tool._ToolboxTool__async_tool._AsyncToolboxTool__bound_params[key] + ) if callable(value): - assert value() == tool._bound_params[key]() + assert value() == async_tool_bound_param_val() else: - assert value == tool._bound_params[key] - - -@pytest.mark.asyncio -@pytest.mark.parametrize("strict", [True, False]) -async def test_toolbox_tool_bind_params_invalid(toolbox_tool, strict): - async for tool in toolbox_tool: - if strict: - with pytest.raises(ValueError) as e: - tool = tool.bind_params({"param3": "bound-value"}, strict=strict) - assert "Parameter(s) param3 missing and cannot be bound." in str(e.value) - else: - with pytest.warns(UserWarning) as record: - tool = tool.bind_params({"param3": "bound-value"}, strict=strict) - assert len(record) == 1 - assert "Parameter(s) param3 missing and cannot be bound." in str( - record[0].message - ) + assert value == async_tool_bound_param_val + def test_toolbox_tool_bind_param(self, mock_async_tool, toolbox_tool): + expected_bound_param = {"param1": "bound-value"} + mock_async_tool._AsyncToolboxTool__bound_params = expected_bound_param + mock_async_tool.bind_param.return_value = mock_async_tool -@pytest.mark.asyncio -async def test_toolbox_tool_bind_params_duplicate(toolbox_tool): - async for tool in toolbox_tool: - tool = tool.bind_params({"param1": "bound-value"}) - with pytest.raises(ValueError) as e: - tool = tool.bind_params({"param1": "bound-value"}) - assert "Parameter(s) `param1` already bound in tool `test_tool`." in str( - e.value + tool = toolbox_tool.bind_param("param1", "bound-value") + mock_async_tool.bind_param.assert_called_once_with( + "param1", "bound-value", True ) - -@pytest.mark.asyncio -async def test_toolbox_tool_bind_params_invalid_params(auth_toolbox_tool): - async for tool in auth_toolbox_tool: - with pytest.raises(ValueError) as e: - tool = tool.bind_params({"param1": "bound-value"}) - assert "Parameter(s) param1 already authenticated and cannot be bound." in str( - e.value + assert ( + tool._ToolboxTool__async_tool._AsyncToolboxTool__bound_params + == expected_bound_param + ) + assert isinstance(tool, ToolboxTool) + + @pytest.mark.parametrize( + "auth_tokens, expected_auth_tokens", + [ + ( + {"test-auth-source": lambda: "test-token"}, + {"test-auth-source": lambda: "test-token"}, + ), + ( + { + "test-auth-source": lambda: "test-token", + "another-auth-source": lambda: "another-token", + }, + { + "test-auth-source": lambda: "test-token", + "another-auth-source": lambda: "another-token", + }, + ), + ], + ) + def test_toolbox_tool_add_auth_tokens( + self, + auth_tokens, + expected_auth_tokens, + mock_async_auth_tool, + auth_toolbox_tool, + ): + auth_toolbox_tool._ToolboxTool__async_tool._AsyncToolboxTool__auth_tokens = ( + expected_auth_tokens + ) + auth_toolbox_tool._ToolboxTool__async_tool.add_auth_tokens.return_value = ( + mock_async_auth_tool ) - -@pytest.mark.asyncio -async def test_toolbox_tool_bind_param(toolbox_tool): - async for tool in toolbox_tool: - tool = tool.bind_param("param1", "bound-value") - assert tool._bound_params == {"param1": "bound-value"} - - -@pytest.mark.asyncio -@pytest.mark.parametrize("strict", [True, False]) -async def test_toolbox_tool_bind_param_invalid(toolbox_tool, strict): - async for tool in toolbox_tool: - if strict: - with pytest.raises(ValueError) as e: - tool = tool.bind_param("param3", "bound-value", strict=strict) - assert "Parameter(s) param3 missing and cannot be bound." in str(e.value) - else: - with pytest.warns(UserWarning) as record: - tool = tool.bind_param("param3", "bound-value", strict=strict) - assert len(record) == 1 - assert "Parameter(s) param3 missing and cannot be bound." in str( - record[0].message + tool = auth_toolbox_tool.add_auth_tokens(auth_tokens) + mock_async_auth_tool.add_auth_tokens.assert_called_once_with(auth_tokens, True) + for source, getter in expected_auth_tokens.items(): + assert ( + tool._ToolboxTool__async_tool._AsyncToolboxTool__auth_tokens[source]() + == getter() ) + assert isinstance(tool, ToolboxTool) - -@pytest.mark.asyncio -async def test_toolbox_tool_bind_param_duplicate(toolbox_tool): - async for tool in toolbox_tool: - tool = tool.bind_param("param1", "bound-value") - with pytest.raises(ValueError) as e: - tool = tool.bind_param("param1", "bound-value") - assert "Parameter(s) `param1` already bound in tool `test_tool`." in str( - e.value + def test_toolbox_tool_add_auth_token(self, mock_async_auth_tool, auth_toolbox_tool): + get_id_token = lambda: "test-token" + expected_auth_tokens = {"test-auth-source": get_id_token} + auth_toolbox_tool._ToolboxTool__async_tool._AsyncToolboxTool__auth_tokens = ( + expected_auth_tokens + ) + auth_toolbox_tool._ToolboxTool__async_tool.add_auth_token.return_value = ( + mock_async_auth_tool ) + tool = auth_toolbox_tool.add_auth_token("test-auth-source", get_id_token) + mock_async_auth_tool.add_auth_token.assert_called_once_with( + "test-auth-source", get_id_token, True + ) -@pytest.mark.asyncio -@pytest.mark.parametrize( - "auth_tokens, expected_auth_tokens", - [ - ( - {"test-auth-source": lambda: "test-token"}, - {"test-auth-source": lambda: "test-token"}, - ), - ( - { - "test-auth-source": lambda: "test-token", - "another-auth-source": lambda: "another-token", - }, - { - "test-auth-source": lambda: "test-token", - "another-auth-source": lambda: "another-token", - }, - ), - ], -) -async def test_toolbox_tool_add_auth_tokens( - auth_toolbox_tool, auth_tokens, expected_auth_tokens -): - async for tool in auth_toolbox_tool: - tool = tool.add_auth_tokens(auth_tokens) - for source, getter in expected_auth_tokens.items(): - assert tool._auth_tokens[source]() == getter() - - -@pytest.mark.asyncio -async def test_toolbox_tool_add_auth_tokens_duplicate(auth_toolbox_tool): - async for tool in auth_toolbox_tool: - tool = tool.add_auth_tokens({"test-auth-source": lambda: "test-token"}) - with pytest.raises(ValueError) as e: - tool = tool.add_auth_tokens({"test-auth-source": lambda: "test-token"}) assert ( - "Authentication source(s) `test-auth-source` already registered in tool `test_tool`." - in str(e.value) + tool._ToolboxTool__async_tool._AsyncToolboxTool__auth_tokens[ + "test-auth-source" + ]() + == "test-token" ) + assert isinstance(tool, ToolboxTool) - -@pytest.mark.asyncio -async def test_toolbox_tool_add_auth_token(auth_toolbox_tool): - async for tool in auth_toolbox_tool: - tool = tool.add_auth_token("test-auth-source", lambda: "test-token") - assert tool._auth_tokens["test-auth-source"]() == "test-token" - - -@pytest.mark.asyncio -async def test_toolbox_tool_validate_auth_strict(auth_toolbox_tool): - async for tool in auth_toolbox_tool: + def test_toolbox_tool_validate_auth_strict(self, auth_toolbox_tool): + auth_toolbox_tool._ToolboxTool__async_tool._arun = Mock( + side_effect=PermissionError( + "Parameter(s) `param1` of tool test_tool require authentication" + ) + ) with pytest.raises(PermissionError) as e: - tool._ToolboxTool__validate_auth(strict=True) - assert ( - "Parameter(s) `param1` of tool test_tool require authentication, but no valid authentication sources are registered. Please register the required sources before use." - in str(e.value) + auth_toolbox_tool._run() + assert "Parameter(s) `param1` of tool test_tool require authentication" in str( + e.value ) - - -@pytest.mark.asyncio -async def test_toolbox_tool_call_with_callable_bound_params(toolbox_tool): - async for tool in toolbox_tool: - tool = tool.bind_param("param1", lambda: "bound-value") - result = await tool.ainvoke({"param2": 123}) - assert result == {"result": "test-result"} - - -@pytest.mark.asyncio -async def test_toolbox_tool_call(toolbox_tool): - async for tool in toolbox_tool: - result = await tool.ainvoke({"param1": "test-value", "param2": 123}) - assert result == {"result": "test-result"} - - -@pytest.mark.asyncio -async def test_toolbox_sync_tool_call_(toolbox_tool): - async for tool in toolbox_tool: - with pytest.raises(NotImplementedError) as e: - result = tool.invoke({"param1": "test-value", "param2": 123}) - assert "Sync tool calls not supported yet." in str(e.value) - - -@pytest.mark.asyncio -async def test_toolbox_tool_call_with_bound_params(toolbox_tool): - async for tool in toolbox_tool: - tool = tool.bind_params({"param1": "bound-value"}) - result = await tool.ainvoke({"param2": 123}) - assert result == {"result": "test-result"} - - -@pytest.mark.asyncio -async def test_toolbox_tool_call_with_auth_tokens(auth_toolbox_tool): - async for tool in auth_toolbox_tool: - tool = tool.add_auth_tokens({"test-auth-source": lambda: "test-token"}) - result = await tool.ainvoke({"param2": 123}) - assert result == {"result": "test-result"} - - -@pytest.mark.asyncio -async def test_toolbox_tool_call_with_auth_tokens_insecure(auth_toolbox_tool): - async for tool in auth_toolbox_tool: - with pytest.warns( - UserWarning, - match="Sending ID token over HTTP. User data may be exposed. Use HTTPS for secure communication.", - ): - tool._url = "http://test-url" - tool = tool.add_auth_tokens({"test-auth-source": lambda: "test-token"}) - result = await tool.ainvoke({"param2": 123}) - assert result == {"result": "test-result"} - - -@pytest.mark.asyncio -async def test_toolbox_tool_call_with_invalid_input(toolbox_tool): - async for tool in toolbox_tool: - with pytest.raises(ValidationError) as e: - await tool.ainvoke({"param1": 123, "param2": "invalid"}) - assert "2 validation errors for test_tool" in str(e.value) - assert "param1\n Input should be a valid string" in str(e.value) - assert "param2\n Input should be a valid integer" in str(e.value) - - -@pytest.mark.asyncio -async def test_toolbox_tool_call_with_empty_input(toolbox_tool): - async for tool in toolbox_tool: - with pytest.raises(ValidationError) as e: - await tool.ainvoke({}) - assert "2 validation errors for test_tool" in str(e.value) - assert "param1\n Field required" in str(e.value) - assert "param2\n Field required" in str(e.value)