diff --git a/packages/toolbox-core/src/toolbox_core/sync_client.py b/packages/toolbox-core/src/toolbox_core/sync_client.py index 96d4bba7..7ec94a44 100644 --- a/packages/toolbox-core/src/toolbox_core/sync_client.py +++ b/packages/toolbox-core/src/toolbox_core/sync_client.py @@ -56,7 +56,6 @@ def __init__( async def create_client(): return ToolboxClient(url, client_headers=client_headers) - # Ignoring type since we're already checking the existence of a loop above. self.__async_client = asyncio.run_coroutine_threadsafe( create_client(), self.__class__.__loop ).result() diff --git a/packages/toolbox-langchain/README.md b/packages/toolbox-langchain/README.md index 9f698694..fca7736b 100644 --- a/packages/toolbox-langchain/README.md +++ b/packages/toolbox-langchain/README.md @@ -227,7 +227,7 @@ tools = toolbox.load_toolset() auth_tool = tools[0].add_auth_token_getter("my_auth", get_auth_token) # Single token -multi_auth_tool = tools[0].add_auth_token_getters({"my_auth", get_auth_token}) # Multiple tokens +multi_auth_tool = tools[0].add_auth_token_getters({"auth_1": get_auth_1}, {"auth_2": get_auth_2}) # Multiple tokens # OR diff --git a/packages/toolbox-langchain/integration.cloudbuild.yaml b/packages/toolbox-langchain/integration.cloudbuild.yaml index 644794fb..51f0ce81 100644 --- a/packages/toolbox-langchain/integration.cloudbuild.yaml +++ b/packages/toolbox-langchain/integration.cloudbuild.yaml @@ -15,10 +15,11 @@ steps: - id: Install library requirements name: 'python:${_VERSION}' + dir: 'packages/toolbox-langchain' args: - install - '-r' - - 'packages/toolbox-langchain/requirements.txt' + - 'requirements.txt' - '--user' entrypoint: pip - id: Install test requirements diff --git a/packages/toolbox-langchain/pyproject.toml b/packages/toolbox-langchain/pyproject.toml index f4f5b7aa..9aaa254a 100644 --- a/packages/toolbox-langchain/pyproject.toml +++ b/packages/toolbox-langchain/pyproject.toml @@ -9,6 +9,8 @@ authors = [ {name = "Google LLC", email = "googleapis-packages@google.com"} ] dependencies = [ + # TODO: Bump lowest supported version to 0.2.0 + "toolbox-core>=0.1.0,<1.0.0", "langchain-core>=0.2.23,<1.0.0", "PyYAML>=6.0.1,<7.0.0", "pydantic>=2.7.0,<3.0.0", diff --git a/packages/toolbox-langchain/requirements.txt b/packages/toolbox-langchain/requirements.txt index 5fd65843..3ada831d 100644 --- a/packages/toolbox-langchain/requirements.txt +++ b/packages/toolbox-langchain/requirements.txt @@ -1,3 +1,4 @@ +-e ../toolbox-core langchain-core==0.3.56 PyYAML==6.0.2 pydantic==2.11.4 diff --git a/packages/toolbox-langchain/src/toolbox_langchain/async_client.py b/packages/toolbox-langchain/src/toolbox_langchain/async_client.py index c58bbfdf..95e384c8 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/async_client.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/async_client.py @@ -16,9 +16,9 @@ from warnings import warn from aiohttp import ClientSession +from toolbox_core.client import ToolboxClient as ToolboxCoreClient -from .tools import AsyncToolboxTool -from .utils import ManifestSchema, _load_manifest +from .async_tools import AsyncToolboxTool # This class is an internal implementation detail and is not exposed to the @@ -38,8 +38,7 @@ def __init__( url: The base URL of the Toolbox service. session: An HTTP client session. """ - self.__url = url - self.__session = session + self.__core_client = ToolboxCoreClient(url=url, session=session) async def aload_tool( self, @@ -48,7 +47,6 @@ async def aload_tool( auth_tokens: Optional[dict[str, Callable[[], str]]] = None, auth_headers: Optional[dict[str, Callable[[], str]]] = None, bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, - strict: bool = True, ) -> AsyncToolboxTool: """ Loads the tool with the given tool name from the Toolbox service. @@ -61,9 +59,6 @@ async def aload_tool( auth_headers: Deprecated. Use `auth_token_getters` instead. bound_params: An optional mapping of parameter names to their bound values. - strict: If True, raises a ValueError if any of the given bound - parameters are missing from the schema or require - authentication. If False, only issues a warning. Returns: A tool loaded from the Toolbox. @@ -94,18 +89,12 @@ async def aload_tool( ) auth_token_getters = auth_headers - url = f"{self.__url}/api/tool/{tool_name}" - manifest: ManifestSchema = await _load_manifest(url, self.__session) - - return AsyncToolboxTool( - tool_name, - manifest.tools[tool_name], - self.__url, - self.__session, - auth_token_getters, - bound_params, - strict, + core_tool = await self.__core_client.load_tool( + name=tool_name, + auth_token_getters=auth_token_getters, + bound_params=bound_params, ) + return AsyncToolboxTool(core_tool=core_tool) async def aload_toolset( self, @@ -114,7 +103,7 @@ async def aload_toolset( auth_tokens: Optional[dict[str, Callable[[], str]]] = None, auth_headers: Optional[dict[str, Callable[[], str]]] = None, bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, - strict: bool = True, + strict: bool = False, ) -> list[AsyncToolboxTool]: """ Loads tools from the Toolbox service, optionally filtered by toolset @@ -129,9 +118,11 @@ async def aload_toolset( auth_headers: Deprecated. Use `auth_token_getters` instead. bound_params: An optional mapping of parameter names to their bound values. - strict: If True, raises a ValueError if any of the given bound - parameters are missing from the schema or require - authentication. If False, only issues a warning. + strict: If True, raises an error if *any* loaded tool instance fails + to utilize at least one provided parameter or auth token (if any + provided). If False (default), raises an error only if a + user-provided parameter or auth token cannot be applied to *any* + loaded tool across the set. Returns: A list of all tools loaded from the Toolbox. @@ -162,22 +153,16 @@ async def aload_toolset( ) auth_token_getters = auth_headers - url = f"{self.__url}/api/toolset/{toolset_name or ''}" - manifest: ManifestSchema = await _load_manifest(url, self.__session) - tools: list[AsyncToolboxTool] = [] - - for tool_name, tool_schema in manifest.tools.items(): - tools.append( - AsyncToolboxTool( - tool_name, - tool_schema, - self.__url, - self.__session, - auth_token_getters, - bound_params, - strict, - ) - ) + core_tools = await self.__core_client.load_toolset( + name=toolset_name, + auth_token_getters=auth_token_getters, + bound_params=bound_params, + strict=strict, + ) + + tools = [] + for core_tool in core_tools: + tools.append(AsyncToolboxTool(core_tool=core_tool)) return tools def load_tool( @@ -187,7 +172,6 @@ def load_tool( auth_tokens: Optional[dict[str, Callable[[], str]]] = None, auth_headers: Optional[dict[str, Callable[[], str]]] = None, bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, - strict: bool = True, ) -> AsyncToolboxTool: raise NotImplementedError("Synchronous methods not supported by async client.") @@ -198,6 +182,6 @@ def load_toolset( auth_tokens: Optional[dict[str, Callable[[], str]]] = None, auth_headers: Optional[dict[str, Callable[[], str]]] = None, bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, - strict: bool = True, + strict: bool = False, ) -> list[AsyncToolboxTool]: raise NotImplementedError("Synchronous methods not supported by async client.") diff --git a/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py b/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py index 40e21ee6..627b18e1 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py @@ -12,22 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from copy import deepcopy -from typing import Any, Callable, TypeVar, Union -from warnings import warn +from typing import Any, Callable, Union -from aiohttp import ClientSession from langchain_core.tools import BaseTool - -from .utils import ( - ToolSchema, - _find_auth_params, - _find_bound_params, - _invoke_tool, - _schema_to_model, -) - -T = TypeVar("T") +from toolbox_core.tool import ToolboxTool as ToolboxCoreTool +from toolbox_core.utils import params_to_pydantic_model # This class is an internal implementation detail and is not exposed to the @@ -41,109 +30,28 @@ class AsyncToolboxTool(BaseTool): def __init__( self, - name: str, - schema: ToolSchema, - url: str, - session: ClientSession, - auth_token_getters: dict[str, Callable[[], str]] = {}, - bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, - strict: bool = True, + core_tool: ToolboxCoreTool, ) -> None: """ Initializes an AsyncToolboxTool instance. Args: - name: The name of the tool. - schema: The tool schema. - url: The base URL of the Toolbox service. - session: The HTTP client session. - auth_token_getters: A mapping of authentication source names to - functions that retrieve ID tokens. - bound_params: A mapping of parameter names to their bound - values. - strict: If True, raises a ValueError if any of the given bound - parameters is missing from the schema or requires - authentication. If False, only issues a warning. + core_tool: The underlying core async ToolboxTool instance. """ - # If the schema is not already a ToolSchema instance, we create one from - # its attributes. This allows flexibility in how the schema is provided, - # accepting both a ToolSchema object and a dictionary of schema - # attributes. - if not isinstance(schema, ToolSchema): - schema = ToolSchema(**schema) - - auth_params, non_auth_params = _find_auth_params(schema.parameters) - non_auth_bound_params, non_auth_non_bound_params = _find_bound_params( - non_auth_params, list(bound_params) - ) - - # Check if the user is trying to bind a param that is authenticated or - # is missing from the given schema. - auth_bound_params: list[str] = [] - missing_bound_params: list[str] = [] - for bound_param in bound_params: - if bound_param in [param.name for param in auth_params]: - auth_bound_params.append(bound_param) - elif bound_param not in [param.name for param in non_auth_params]: - missing_bound_params.append(bound_param) - - # Create error messages for any params that are found to be - # authenticated or missing. - messages: list[str] = [] - if auth_bound_params: - messages.append( - f"Parameter(s) {', '.join(auth_bound_params)} already authenticated and cannot be bound." - ) - if missing_bound_params: - messages.append( - f"Parameter(s) {', '.join(missing_bound_params)} missing and cannot be bound." - ) - - # Join any error messages and raise them as an error or warning, - # depending on the value of the strict flag. - if messages: - message = "\n\n".join(messages) - if strict: - raise ValueError(message) - warn(message) - - # Bind values for parameters present in the schema that don't require - # authentication. - bound_params = { - param_name: param_value - for param_name, param_value in bound_params.items() - if param_name in [param.name for param in non_auth_bound_params] - } - - # Update the tools schema to validate only the presence of parameters - # that neither require authentication nor are bound. - schema.parameters = non_auth_non_bound_params - # Due to how pydantic works, we must initialize the underlying # BaseTool class before assigning values to member variables. super().__init__( - name=name, - description=schema.description, - args_schema=_schema_to_model(model_name=name, schema=schema.parameters), + name=core_tool.__name__, + description=core_tool.__doc__, + args_schema=params_to_pydantic_model(core_tool._name, core_tool._params), ) + self.__core_tool = core_tool - self.__name = name - self.__schema = schema - self.__url = url - self.__session = session - self.__auth_token_getters = auth_token_getters - self.__auth_params = auth_params - self.__bound_params = bound_params - - # Warn users about any missing authentication so they can add it before - # tool invocation. - self.__validate_auth(strict=False) - - def _run(self, **kwargs: Any) -> dict[str, Any]: + def _run(self, **kwargs: Any) -> str: raise NotImplementedError("Synchronous methods not supported by async tools.") - async def _arun(self, **kwargs: Any) -> dict[str, Any]: + async def _arun(self, **kwargs: Any) -> str: """ The coroutine that invokes the tool with the given arguments. @@ -154,140 +62,10 @@ async def _arun(self, **kwargs: Any) -> dict[str, Any]: A dictionary containing the parsed JSON response from the tool invocation. """ - - # If the tool had parameters that require authentication, then right - # before invoking that tool, we check whether all these required - # authentication sources have been registered or not. - self.__validate_auth() - - # Evaluate dynamic parameter values if any - evaluated_params = {} - for param_name, param_value in self.__bound_params.items(): - if callable(param_value): - evaluated_params[param_name] = param_value() - else: - evaluated_params[param_name] = param_value - - # Merge bound parameters with the provided arguments - kwargs.update(evaluated_params) - - return await _invoke_tool( - self.__url, self.__session, self.__name, kwargs, self.__auth_token_getters - ) - - def __validate_auth(self, strict: bool = True) -> None: - """ - Checks if a tool meets the authentication requirements. - - A tool is considered authenticated if all of its parameters meet at - least one of the following conditions: - - * The parameter has at least one registered authentication source. - * The parameter requires no authentication. - - Args: - strict: If True, raises a PermissionError if any required - authentication sources are not registered. If False, only issues - a warning. - - Raises: - PermissionError: If strict is True and any required authentication - sources are not registered. - """ - is_authenticated: bool = not self.__schema.authRequired - params_missing_auth: list[str] = [] - - # Check tool for at least 1 required auth source - for src in self.__schema.authRequired: - if src in self.__auth_token_getters: - is_authenticated = True - break - - # Check each parameter for at least 1 required auth source - for param in self.__auth_params: - if not param.authSources: - raise ValueError("Auth sources cannot be None.") - has_auth = False - for src in param.authSources: - - # Find first auth source that is specified - if src in self.__auth_token_getters: - has_auth = True - break - if not has_auth: - params_missing_auth.append(param.name) - - messages: list[str] = [] - - if not is_authenticated: - messages.append( - f"Tool {self.__name} requires authentication, but no valid authentication sources are registered. Please register the required sources before use." - ) - - if params_missing_auth: - messages.append( - f"Parameter(s) `{', '.join(params_missing_auth)}` of tool {self.__name} require authentication, but no valid authentication sources are registered. Please register the required sources before use." - ) - - if messages: - message = "\n\n".join(messages) - if strict: - raise PermissionError(message) - warn(message) - - def __create_copy( - self, - *, - auth_token_getters: dict[str, Callable[[], str]] = {}, - bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, - strict: bool, - ) -> "AsyncToolboxTool": - """ - Creates a copy of the current AsyncToolboxTool instance, allowing for - modification of auth tokens and bound params. - - This method enables the creation of new tool instances with inherited - properties from the current instance, while optionally updating the auth - tokens and bound params. This is useful for creating variations of the - tool with additional auth tokens or bound params without modifying the - original instance, ensuring immutability. - - Args: - auth_token_getters: A dictionary of auth source names to functions - that retrieve ID tokens. These tokens will be merged with the - existing auth tokens. - bound_params: A dictionary of parameter names to their - bound values or functions to retrieve the values. These params - will be merged with the existing bound params. - strict: If True, raises a ValueError if any of the given bound - parameters is missing from the schema or requires - authentication. If False, only issues a warning. - - Returns: - A new AsyncToolboxTool instance that is a deep copy of the current - instance, with added auth tokens or bound params. - """ - new_schema = deepcopy(self.__schema) - - # Reconstruct the complete parameter schema by merging the auth - # parameters back with the non-auth parameters. This is necessary to - # accurately validate the new combination of auth tokens and bound - # params in the constructor of the new AsyncToolboxTool instance, ensuring - # that any overlaps or conflicts are correctly identified and reported - # as errors or warnings, depending on the given `strict` flag. - new_schema.parameters += self.__auth_params - return AsyncToolboxTool( - name=self.__name, - schema=new_schema, - url=self.__url, - session=self.__session, - auth_token_getters={**self.__auth_token_getters, **auth_token_getters}, - bound_params={**self.__bound_params, **bound_params}, - strict=strict, - ) + return await self.__core_tool(**kwargs) def add_auth_token_getters( - self, auth_token_getters: dict[str, Callable[[], str]], strict: bool = True + self, auth_token_getters: dict[str, Callable[[], str]] ) -> "AsyncToolboxTool": """ Registers functions to retrieve ID tokens for the corresponding @@ -296,36 +74,21 @@ def add_auth_token_getters( Args: auth_token_getters: A dictionary of authentication source names to the functions that return corresponding ID token getters. - strict: If True, a ValueError is raised if any of the provided auth - parameters is already bound. If False, only a warning is issued. Returns: A new AsyncToolboxTool instance that is a deep copy of the current - instance, with added auth tokens. + instance, with added auth token getters. Raises: ValueError: If any of the provided auth parameters is already registered. - ValueError: If any of the provided auth parameters is already bound - and strict is True. """ - - # Check if the authentication source is already registered. - dupe_tokens: list[str] = [] - for auth_token, _ in auth_token_getters.items(): - if auth_token in self.__auth_token_getters: - dupe_tokens.append(auth_token) - - if dupe_tokens: - raise ValueError( - f"Authentication source(s) `{', '.join(dupe_tokens)}` already registered in tool `{self.__name}`." - ) - - return self.__create_copy(auth_token_getters=auth_token_getters, strict=strict) + new_core_tool = self.__core_tool.add_auth_token_getters(auth_token_getters) + return AsyncToolboxTool(core_tool=new_core_tool) def add_auth_token_getter( - self, auth_source: str, get_id_token: Callable[[], str], strict: bool = True + self, auth_source: str, get_id_token: Callable[[], str] ) -> "AsyncToolboxTool": """ Registers a function to retrieve an ID token for a given authentication @@ -334,24 +97,20 @@ def add_auth_token_getter( Args: auth_source: The name of the authentication source. get_id_token: A function that returns the ID token. - strict: If True, a ValueError is raised if the provided auth - parameter is already bound. If False, only a warning is issued. Returns: A new ToolboxTool instance that is a deep copy of the current - instance, with added auth token. + instance, with added auth token getter. Raises: ValueError: If the provided auth parameter is already registered. - ValueError: If the provided auth parameter is already bound and - strict is True. + """ - return self.add_auth_token_getters({auth_source: get_id_token}, strict=strict) + return self.add_auth_token_getters({auth_source: get_id_token}) def bind_params( self, bound_params: dict[str, Union[Any, Callable[[], Any]]], - strict: bool = True, ) -> "AsyncToolboxTool": """ Registers values or functions to retrieve the value for the @@ -360,9 +119,6 @@ def bind_params( Args: bound_params: A dictionary of the bound parameter name to the value or function of the bound value. - strict: If True, a ValueError is raised if any of the provided bound - params is not defined in the tool's schema, or requires - authentication. If False, only a warning is issued. Returns: A new AsyncToolboxTool instance that is a deep copy of the current @@ -370,29 +126,14 @@ def bind_params( Raises: ValueError: If any of the provided bound params is already bound. - ValueError: if any of the provided bound params is not defined in - the tool's schema, or requires authentication, and strict is - True. """ - - # Check if the parameter is already bound. - dupe_params: list[str] = [] - for param_name, _ in bound_params.items(): - if param_name in self.__bound_params: - dupe_params.append(param_name) - - if dupe_params: - raise ValueError( - f"Parameter(s) `{', '.join(dupe_params)}` already bound in tool `{self.__name}`." - ) - - return self.__create_copy(bound_params=bound_params, strict=strict) + new_core_tool = self.__core_tool.bind_params(bound_params) + return AsyncToolboxTool(core_tool=new_core_tool) def bind_param( self, param_name: str, param_value: Union[Any, Callable[[], Any]], - strict: bool = True, ) -> "AsyncToolboxTool": """ Registers a value or a function to retrieve the value for a given bound @@ -402,9 +143,6 @@ def bind_param( param_name: The name of the bound parameter. param_value: The value of the bound parameter, or a callable that returns the value. - strict: If True, a ValueError is raised if the provided bound param - is not defined in the tool's schema, or requires authentication. - If False, only a warning is issued. Returns: A new ToolboxTool instance that is a deep copy of the current @@ -412,7 +150,5 @@ def bind_param( Raises: ValueError: If the provided bound param is already bound. - ValueError: if the provided bound param is not defined in the tool's - schema, or requires authentication, and strict is True. """ - return self.bind_params({param_name: param_value}, strict) + return self.bind_params({param_name: param_value}) diff --git a/packages/toolbox-langchain/src/toolbox_langchain/client.py b/packages/toolbox-langchain/src/toolbox_langchain/client.py index 3c75779c..1d395585 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/client.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/client.py @@ -12,22 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import asyncio -from threading import Thread -from typing import Any, Awaitable, Callable, Optional, TypeVar, Union +from asyncio import to_thread +from typing import Any, Callable, Optional, Union +from warnings import warn -from aiohttp import ClientSession +from toolbox_core.sync_client import ToolboxSyncClient as ToolboxCoreSyncClient +from toolbox_core.sync_tool import ToolboxSyncTool -from .async_client import AsyncToolboxClient from .tools import ToolboxTool -T = TypeVar("T") - class ToolboxClient: - __session: Optional[ClientSession] = None - __loop: Optional[asyncio.AbstractEventLoop] = None - __thread: Optional[Thread] = None def __init__( self, @@ -39,51 +34,7 @@ def __init__( Args: url: The base URL of the Toolbox service. """ - - # Running a loop in a background thread allows us to support async - # methods from non-async environments. - if ToolboxClient.__loop is None: - loop = asyncio.new_event_loop() - thread = Thread(target=loop.run_forever, daemon=True) - thread.start() - ToolboxClient.__thread = thread - ToolboxClient.__loop = loop - - async def __start_session() -> None: - - # Use a default session if none is provided. This leverages connection - # pooling for better performance by reusing a single session throughout - # the application's lifetime. - if ToolboxClient.__session is None: - ToolboxClient.__session = ClientSession() - - coro = __start_session() - - asyncio.run_coroutine_threadsafe(coro, ToolboxClient.__loop).result() - - if not ToolboxClient.__session: - raise ValueError("Session cannot be None.") - self.__async_client = AsyncToolboxClient(url, ToolboxClient.__session) - - def __run_as_sync(self, coro: Awaitable[T]) -> T: - """Run an async coroutine synchronously""" - if not self.__loop: - raise Exception( - "Cannot call synchronous methods before the background loop is initialized." - ) - return asyncio.run_coroutine_threadsafe(coro, self.__loop).result() - - async def __run_as_async(self, coro: Awaitable[T]) -> T: - """Run an async coroutine asynchronously""" - - # If a loop has not been provided, attempt to run in current thread. - if not self.__loop: - return await coro - - # Otherwise, run in the background thread. - return await asyncio.wrap_future( - asyncio.run_coroutine_threadsafe(coro, self.__loop) - ) + self.__core_client = ToolboxCoreSyncClient(url=url) async def aload_tool( self, @@ -92,7 +43,6 @@ async def aload_tool( auth_tokens: Optional[dict[str, Callable[[], str]]] = None, auth_headers: Optional[dict[str, Callable[[], str]]] = None, bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, - strict: bool = True, ) -> ToolboxTool: """ Loads the tool with the given tool name from the Toolbox service. @@ -105,27 +55,43 @@ async def aload_tool( auth_headers: Deprecated. Use `auth_token_getters` instead. bound_params: An optional mapping of parameter names to their bound values. - strict: If True, raises a ValueError if any of the given bound - parameters are missing from the schema or require - authentication. If False, only issues a warning. Returns: A tool loaded from the Toolbox. """ - async_tool = await self.__run_as_async( - self.__async_client.aload_tool( - tool_name, - auth_token_getters, - auth_tokens, - auth_headers, - bound_params, - strict, - ) + if auth_tokens: + if auth_token_getters: + warn( + "Both `auth_token_getters` and `auth_tokens` are provided. `auth_tokens` is deprecated, and `auth_token_getters` will be used.", + DeprecationWarning, + ) + else: + warn( + "Argument `auth_tokens` is deprecated. Use `auth_token_getters` instead.", + DeprecationWarning, + ) + auth_token_getters = auth_tokens + + if auth_headers: + if auth_token_getters: + warn( + "Both `auth_token_getters` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_token_getters` will be used.", + DeprecationWarning, + ) + else: + warn( + "Argument `auth_headers` is deprecated. Use `auth_token_getters` instead.", + DeprecationWarning, + ) + auth_token_getters = auth_headers + + core_tool = await to_thread( + self.__core_client.load_tool, + name=tool_name, + auth_token_getters=auth_token_getters, + bound_params=bound_params, ) - - if not self.__loop or not self.__thread: - raise ValueError("Background loop or thread cannot be None.") - return ToolboxTool(async_tool, self.__loop, self.__thread) + return ToolboxTool(core_tool=core_tool) async def aload_toolset( self, @@ -134,7 +100,7 @@ async def aload_toolset( auth_tokens: Optional[dict[str, Callable[[], str]]] = None, auth_headers: Optional[dict[str, Callable[[], str]]] = None, bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, - strict: bool = True, + strict: bool = False, ) -> list[ToolboxTool]: """ Loads tools from the Toolbox service, optionally filtered by toolset @@ -149,30 +115,52 @@ async def aload_toolset( auth_headers: Deprecated. Use `auth_token_getters` instead. bound_params: An optional mapping of parameter names to their bound values. - strict: If True, raises a ValueError if any of the given bound - parameters are missing from the schema or require - authentication. If False, only issues a warning. + strict: If True, raises an error if *any* loaded tool instance fails + to utilize at least one provided parameter or auth token (if any + provided). If False (default), raises an error only if a + user-provided parameter or auth token cannot be applied to *any* + loaded tool across the set. Returns: A list of all tools loaded from the Toolbox. """ - async_tools = await self.__run_as_async( - self.__async_client.aload_toolset( - toolset_name, - auth_token_getters, - auth_tokens, - auth_headers, - bound_params, - strict, - ) + if auth_tokens: + if auth_token_getters: + warn( + "Both `auth_token_getters` and `auth_tokens` are provided. `auth_tokens` is deprecated, and `auth_token_getters` will be used.", + DeprecationWarning, + ) + else: + warn( + "Argument `auth_tokens` is deprecated. Use `auth_token_getters` instead.", + DeprecationWarning, + ) + auth_token_getters = auth_tokens + + if auth_headers: + if auth_token_getters: + warn( + "Both `auth_token_getters` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_token_getters` will be used.", + DeprecationWarning, + ) + else: + warn( + "Argument `auth_headers` is deprecated. Use `auth_token_getters` instead.", + DeprecationWarning, + ) + auth_token_getters = auth_headers + + core_tools = await to_thread( + self.__core_client.load_toolset, + name=toolset_name, + auth_token_getters=auth_token_getters, + bound_params=bound_params, + strict=strict, ) - tools: list[ToolboxTool] = [] - - if not self.__loop or not self.__thread: - raise ValueError("Background loop or thread cannot be None.") - for async_tool in async_tools: - tools.append(ToolboxTool(async_tool, self.__loop, self.__thread)) + tools = [] + for core_tool in core_tools: + tools.append(ToolboxTool(core_tool=core_tool)) return tools def load_tool( @@ -182,7 +170,6 @@ def load_tool( auth_tokens: Optional[dict[str, Callable[[], str]]] = None, auth_headers: Optional[dict[str, Callable[[], str]]] = None, bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, - strict: bool = True, ) -> ToolboxTool: """ Loads the tool with the given tool name from the Toolbox service. @@ -195,27 +182,42 @@ def load_tool( auth_headers: Deprecated. Use `auth_token_getters` instead. bound_params: An optional mapping of parameter names to their bound values. - strict: If True, raises a ValueError if any of the given bound - parameters are missing from the schema or require - authentication. If False, only issues a warning. Returns: A tool loaded from the Toolbox. """ - async_tool = self.__run_as_sync( - self.__async_client.aload_tool( - tool_name, - auth_token_getters, - auth_tokens, - auth_headers, - bound_params, - strict, - ) + if auth_tokens: + if auth_token_getters: + warn( + "Both `auth_token_getters` and `auth_tokens` are provided. `auth_tokens` is deprecated, and `auth_token_getters` will be used.", + DeprecationWarning, + ) + else: + warn( + "Argument `auth_tokens` is deprecated. Use `auth_token_getters` instead.", + DeprecationWarning, + ) + auth_token_getters = auth_tokens + + if auth_headers: + if auth_token_getters: + warn( + "Both `auth_token_getters` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_token_getters` will be used.", + DeprecationWarning, + ) + else: + warn( + "Argument `auth_headers` is deprecated. Use `auth_token_getters` instead.", + DeprecationWarning, + ) + auth_token_getters = auth_headers + + core_sync_tool = self.__core_client.load_tool( + name=tool_name, + auth_token_getters=auth_token_getters, + bound_params=bound_params, ) - - if not self.__loop or not self.__thread: - raise ValueError("Background loop or thread cannot be None.") - return ToolboxTool(async_tool, self.__loop, self.__thread) + return ToolboxTool(core_tool=core_sync_tool) def load_toolset( self, @@ -224,7 +226,7 @@ def load_toolset( auth_tokens: Optional[dict[str, Callable[[], str]]] = None, auth_headers: Optional[dict[str, Callable[[], str]]] = None, bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, - strict: bool = True, + strict: bool = False, ) -> list[ToolboxTool]: """ Loads tools from the Toolbox service, optionally filtered by toolset @@ -239,27 +241,49 @@ def load_toolset( auth_headers: Deprecated. Use `auth_token_getters` instead. bound_params: An optional mapping of parameter names to their bound values. - strict: If True, raises a ValueError if any of the given bound - parameters are missing from the schema or require - authentication. If False, only issues a warning. + strict: If True, raises an error if *any* loaded tool instance fails + to utilize at least one provided parameter or auth token (if any + provided). If False (default), raises an error only if a + user-provided parameter or auth token cannot be applied to *any* + loaded tool across the set. Returns: A list of all tools loaded from the Toolbox. """ - async_tools = self.__run_as_sync( - self.__async_client.aload_toolset( - toolset_name, - auth_token_getters, - auth_tokens, - auth_headers, - bound_params, - strict, - ) + if auth_tokens: + if auth_token_getters: + warn( + "Both `auth_token_getters` and `auth_tokens` are provided. `auth_tokens` is deprecated, and `auth_token_getters` will be used.", + DeprecationWarning, + ) + else: + warn( + "Argument `auth_tokens` is deprecated. Use `auth_token_getters` instead.", + DeprecationWarning, + ) + auth_token_getters = auth_tokens + + if auth_headers: + if auth_token_getters: + warn( + "Both `auth_token_getters` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_token_getters` will be used.", + DeprecationWarning, + ) + else: + warn( + "Argument `auth_headers` is deprecated. Use `auth_token_getters` instead.", + DeprecationWarning, + ) + auth_token_getters = auth_headers + + core_sync_tools = self.__core_client.load_toolset( + name=toolset_name, + auth_token_getters=auth_token_getters, + bound_params=bound_params, + strict=strict, ) - if not self.__loop or not self.__thread: - raise ValueError("Background loop or thread cannot be None.") - tools: list[ToolboxTool] = [] - for async_tool in async_tools: - tools.append(ToolboxTool(async_tool, self.__loop, self.__thread)) + tools = [] + for core_sync_tool in core_sync_tools: + tools.append(ToolboxTool(core_tool=core_sync_tool)) return tools diff --git a/packages/toolbox-langchain/src/toolbox_langchain/tools.py b/packages/toolbox-langchain/src/toolbox_langchain/tools.py index feb2a597..fd7ab197 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/tools.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/tools.py @@ -12,16 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import asyncio -from asyncio import AbstractEventLoop -from threading import Thread -from typing import Any, Awaitable, Callable, TypeVar, Union +from asyncio import to_thread +from typing import Any, Callable, Union from langchain_core.tools import BaseTool - -from .async_tools import AsyncToolboxTool - -T = TypeVar("T") +from toolbox_core.sync_tool import ToolboxSyncTool as ToolboxCoreSyncTool +from toolbox_core.utils import params_to_pydantic_model class ToolboxTool(BaseTool): @@ -32,59 +28,32 @@ class ToolboxTool(BaseTool): def __init__( self, - async_tool: AsyncToolboxTool, - loop: AbstractEventLoop, - thread: Thread, + core_tool: ToolboxCoreSyncTool, ) -> None: """ Initializes a ToolboxTool instance. Args: - async_tool: The underlying AsyncToolboxTool instance. - loop: The event loop used to run asynchronous tasks. - thread: The thread to run blocking operations in. + core_tool: The underlying core sync ToolboxTool instance. """ # Due to how pydantic works, we must initialize the underlying # BaseTool class before assigning values to member variables. super().__init__( - name=async_tool.name, - description=async_tool.description, - args_schema=async_tool.args_schema, - ) - - self.__async_tool = async_tool - self.__loop = loop - self.__thread = thread - - def __run_as_sync(self, coro: Awaitable[T]) -> T: - """Run an async coroutine synchronously""" - if not self.__loop: - raise Exception( - "Cannot call synchronous methods before the background loop is initialized." - ) - return asyncio.run_coroutine_threadsafe(coro, self.__loop).result() - - async def __run_as_async(self, coro: Awaitable[T]) -> T: - """Run an async coroutine asynchronously""" - - # If a loop has not been provided, attempt to run in current thread. - if not self.__loop: - return await coro - - # Otherwise, run in the background thread. - return await asyncio.wrap_future( - asyncio.run_coroutine_threadsafe(coro, self.__loop) + name=core_tool.__name__, + description=core_tool.__doc__, + args_schema=params_to_pydantic_model(core_tool._name, core_tool._params), ) + self.__core_tool = core_tool - def _run(self, **kwargs: Any) -> dict[str, Any]: - return self.__run_as_sync(self.__async_tool._arun(**kwargs)) + def _run(self, **kwargs: Any) -> str: + return self.__core_tool(**kwargs) - async def _arun(self, **kwargs: Any) -> dict[str, Any]: - return await self.__run_as_async(self.__async_tool._arun(**kwargs)) + async def _arun(self, **kwargs: Any) -> str: + return await to_thread(self.__core_tool, **kwargs) def add_auth_token_getters( - self, auth_token_getters: dict[str, Callable[[], str]], strict: bool = True + self, auth_token_getters: dict[str, Callable[[], str]] ) -> "ToolboxTool": """ Registers functions to retrieve ID tokens for the corresponding @@ -93,27 +62,20 @@ def add_auth_token_getters( Args: auth_token_getters: A dictionary of authentication source names to the functions that return corresponding ID token. - strict: If True, a ValueError is raised if any of the provided auth - parameters is already bound. If False, only a warning is issued. Returns: A new ToolboxTool instance that is a deep copy of the current - instance, with added auth tokens. + instance, with added auth token getters. Raises: ValueError: If any of the provided auth parameters is already registered. - ValueError: If any of the provided auth parameters is already bound - and strict is True. """ - return ToolboxTool( - self.__async_tool.add_auth_token_getters(auth_token_getters, strict), - self.__loop, - self.__thread, - ) + new_core_tool = self.__core_tool.add_auth_token_getters(auth_token_getters) + return ToolboxTool(core_tool=new_core_tool) def add_auth_token_getter( - self, auth_source: str, get_id_token: Callable[[], str], strict: bool = True + self, auth_source: str, get_id_token: Callable[[], str] ) -> "ToolboxTool": """ Registers a function to retrieve an ID token for a given authentication @@ -122,28 +84,19 @@ def add_auth_token_getter( Args: auth_source: The name of the authentication source. get_id_token: A function that returns the ID token. - strict: If True, a ValueError is raised if the provided auth - parameter is already bound. If False, only a warning is issued. Returns: A new ToolboxTool instance that is a deep copy of the current - instance, with added auth token. + instance, with added auth token getter. Raises: ValueError: If the provided auth parameter is already registered. - ValueError: If the provided auth parameter is already bound and - strict is True. """ - return ToolboxTool( - self.__async_tool.add_auth_token_getter(auth_source, get_id_token, strict), - self.__loop, - self.__thread, - ) + return self.add_auth_token_getters({auth_source: get_id_token}) def bind_params( self, bound_params: dict[str, Union[Any, Callable[[], Any]]], - strict: bool = True, ) -> "ToolboxTool": """ Registers values or functions to retrieve the value for the @@ -152,9 +105,6 @@ def bind_params( Args: bound_params: A dictionary of the bound parameter name to the value or function of the bound value. - strict: If True, a ValueError is raised if any of the provided bound - params is not defined in the tool's schema, or requires - authentication. If False, only a warning is issued. Returns: A new ToolboxTool instance that is a deep copy of the current @@ -162,21 +112,14 @@ def bind_params( Raises: ValueError: If any of the provided bound params is already bound. - ValueError: if any of the provided bound params is not defined in - the tool's schema, or require authentication, and strict is - True. """ - return ToolboxTool( - self.__async_tool.bind_params(bound_params, strict), - self.__loop, - self.__thread, - ) + new_core_tool = self.__core_tool.bind_params(bound_params) + return ToolboxTool(core_tool=new_core_tool) def bind_param( self, param_name: str, param_value: Union[Any, Callable[[], Any]], - strict: bool = True, ) -> "ToolboxTool": """ Registers a value or a function to retrieve the value for a given bound @@ -186,9 +129,6 @@ def bind_param( param_name: The name of the bound parameter. param_value: The value of the bound parameter, or a callable that returns the value. - strict: If True, a ValueError is raised if the provided bound - param is not defined in the tool's schema, or requires - authentication. If False, only a warning is issued. Returns: A new ToolboxTool instance that is a deep copy of the current @@ -196,11 +136,5 @@ def bind_param( Raises: ValueError: If the provided bound param is already bound. - ValueError: if the provided bound param is not defined in the tool's - schema, or requires authentication, and strict is True. """ - return ToolboxTool( - self.__async_tool.bind_param(param_name, param_value, strict), - self.__loop, - self.__thread, - ) + return self.bind_params({param_name: param_value}) diff --git a/packages/toolbox-langchain/src/toolbox_langchain/utils.py b/packages/toolbox-langchain/src/toolbox_langchain/utils.py deleted file mode 100644 index 985c7bfe..00000000 --- a/packages/toolbox-langchain/src/toolbox_langchain/utils.py +++ /dev/null @@ -1,268 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -from typing import Any, Callable, Optional, Type, cast -from warnings import warn - -from aiohttp import ClientSession -from deprecated import deprecated -from langchain_core.tools import ToolException -from pydantic import BaseModel, Field, create_model - - -class ParameterSchema(BaseModel): - """ - Schema for a tool parameter. - """ - - name: str - type: str - description: str - authSources: Optional[list[str]] = None - items: Optional["ParameterSchema"] = None - - -class ToolSchema(BaseModel): - """ - Schema for a tool. - """ - - description: str - parameters: list[ParameterSchema] - authRequired: list[str] = [] - - -class ManifestSchema(BaseModel): - """ - Schema for the Toolbox manifest. - """ - - serverVersion: str - tools: dict[str, ToolSchema] - - -async def _load_manifest(url: str, session: ClientSession) -> ManifestSchema: - """ - Asynchronously fetches and parses the JSON manifest schema from the given - URL. - - Args: - url: The URL to fetch the JSON from. - session: The HTTP client session. - - Returns: - The parsed Toolbox manifest. - - Raises: - json.JSONDecodeError: If the response is not valid JSON. - ValueError: If the response is not a valid manifest. - """ - async with session.get(url) as response: - # TODO: Remove as it masks error messages. - response.raise_for_status() - try: - # TODO: Simply use response.json() - parsed_json = json.loads(await response.text()) - except json.JSONDecodeError as e: - raise json.JSONDecodeError( - f"Failed to parse JSON from {url}: {e}", e.doc, e.pos - ) from e - try: - return ManifestSchema(**parsed_json) - except ValueError as e: - raise ValueError(f"Invalid JSON data from {url}: {e}") from e - - -def _schema_to_model(model_name: str, schema: list[ParameterSchema]) -> Type[BaseModel]: - """ - Converts the given manifest schema to a Pydantic BaseModel class. - - Args: - model_name: The name of the model to create. - schema: The schema to convert. - - Returns: - A Pydantic BaseModel class. - """ - field_definitions = {} - for field in schema: - field_definitions[field.name] = cast( - Any, - ( - _parse_type(field), - Field(description=field.description), - ), - ) - - return create_model(model_name, **field_definitions) - - -def _parse_type(schema_: ParameterSchema) -> Any: - """ - Converts a schema type to a JSON type. - - Args: - schema_: The ParameterSchema to convert. - - Returns: - A valid JSON type. - - Raises: - ValueError: If the given type is not supported. - """ - type_ = schema_.type - - if type_ == "string": - return str - elif type_ == "integer": - return int - elif type_ == "float": - return float - elif type_ == "boolean": - return bool - elif type_ == "array": - if isinstance(schema_, ParameterSchema) and schema_.items: - return list[_parse_type(schema_.items)] # type: ignore - else: - raise ValueError(f"Schema missing field items") - else: - raise ValueError(f"Unsupported schema type: {type_}") - - -@deprecated("Please use `_get_auth_tokens` instead.") -def _get_auth_headers(id_token_getters: dict[str, Callable[[], str]]) -> dict[str, str]: - """ - Deprecated. Use `_get_auth_tokens` instead. - """ - return _get_auth_tokens(id_token_getters) - - -def _get_auth_tokens(id_token_getters: dict[str, Callable[[], str]]) -> dict[str, str]: - """ - Gets ID tokens for the given auth sources in the getters map and returns - tokens to be included in tool invocation. - - Args: - id_token_getters: A dict that maps auth source names to the functions - that return its ID token. - - Returns: - A dictionary of tokens to be included in the tool invocation. - """ - auth_tokens = {} - for auth_source, get_id_token in id_token_getters.items(): - auth_tokens[f"{auth_source}_token"] = get_id_token() - return auth_tokens - - -async def _invoke_tool( - url: str, - session: ClientSession, - tool_name: str, - data: dict, - id_token_getters: dict[str, Callable[[], str]], -) -> dict: - """ - Asynchronously makes an API call to the Toolbox service to invoke a tool. - - Args: - url: The base URL of the Toolbox service. - session: The HTTP client session. - tool_name: The name of the tool to invoke. - data: The input data for the tool. - id_token_getters: A dict that maps auth source names to the functions - that return its ID token. - - Returns: - A dictionary containing the parsed JSON response from the tool - invocation. - - Raises: - ToolException: If the Toolbox service returns an error. - """ - url = f"{url}/api/tool/{tool_name}/invoke" - auth_tokens = _get_auth_tokens(id_token_getters) - - # ID tokens contain sensitive user information (claims). Transmitting these - # over HTTP exposes the data to interception and unauthorized access. Always - # use HTTPS to ensure secure communication and protect user privacy. - if auth_tokens and not url.startswith("https://"): - warn( - "Sending ID token over HTTP. User data may be exposed. Use HTTPS for secure communication." - ) - - async with session.post( - url, - json=data, - headers=auth_tokens, - ) as response: - ret = await response.json() - if "error" in ret: - raise ToolException(ret) - return ret.get("result", ret) - - -def _find_auth_params( - params: list[ParameterSchema], -) -> tuple[list[ParameterSchema], list[ParameterSchema]]: - """ - Separates parameters into those that are authenticated and those that are not. - - Args: - params: A list of ParameterSchema objects. - - Returns: - A tuple containing two lists: - - auth_params: A list of ParameterSchema objects that require authentication. - - non_auth_params: A list of ParameterSchema objects that do not require authentication. - """ - _auth_params: list[ParameterSchema] = [] - _non_auth_params: list[ParameterSchema] = [] - - for param in params: - if param.authSources: - _auth_params.append(param) - else: - _non_auth_params.append(param) - - return (_auth_params, _non_auth_params) - - -def _find_bound_params( - params: list[ParameterSchema], bound_params: list[str] -) -> tuple[list[ParameterSchema], list[ParameterSchema]]: - """ - Separates parameters into those that are bound and those that are not. - - Args: - params: A list of ParameterSchema objects. - bound_params: A list of parameter names that are bound. - - Returns: - A tuple containing two lists: - - bound_params: A list of ParameterSchema objects whose names are in the bound_params list. - - non_bound_params: A list of ParameterSchema objects whose names are not in the bound_params list. - """ - - _bound_params: list[ParameterSchema] = [] - _non_bound_params: list[ParameterSchema] = [] - - for param in params: - if param.name in bound_params: - _bound_params.append(param) - else: - _non_bound_params.append(param) - - return (_bound_params, _non_bound_params) diff --git a/packages/toolbox-langchain/tests/test_async_client.py b/packages/toolbox-langchain/tests/test_async_client.py index 25ad78eb..6b398560 100644 --- a/packages/toolbox-langchain/tests/test_async_client.py +++ b/packages/toolbox-langchain/tests/test_async_client.py @@ -17,10 +17,12 @@ import pytest from aiohttp import ClientSession +from toolbox_core.client import ToolboxClient as ToolboxCoreClient +from toolbox_core.protocol import ParameterSchema as CoreParameterSchema +from toolbox_core.tool import ToolboxTool as ToolboxCoreTool from toolbox_langchain.async_client import AsyncToolboxClient from toolbox_langchain.async_tools import AsyncToolboxTool -from toolbox_langchain.utils import ManifestSchema URL = "http://test_url" MANIFEST_JSON = { @@ -52,131 +54,277 @@ @pytest.mark.asyncio class TestAsyncToolboxClient: - @pytest.fixture() - def manifest_schema(self): - return ManifestSchema(**MANIFEST_JSON) - @pytest.fixture() def mock_session(self): return AsyncMock(spec=ClientSession) + @pytest.fixture + def mock_core_client_instance(self, mock_session): + mock = AsyncMock(spec=ToolboxCoreClient) + + async def mock_load_tool_impl(name, auth_token_getters, bound_params): + tool_schema_dict = MANIFEST_JSON["tools"].get(name) + if not tool_schema_dict: + raise ValueError(f"Tool '{name}' not in mock manifest_dict") + + core_params = [ + CoreParameterSchema(**p) for p in tool_schema_dict["parameters"] + ] + # Return a mock that looks like toolbox_core.tool.ToolboxTool + core_tool_mock = AsyncMock(spec=ToolboxCoreTool) + core_tool_mock.__name__ = name + core_tool_mock.__doc__ = tool_schema_dict["description"] + core_tool_mock._name = name + core_tool_mock._params = core_params + # Add other necessary attributes or method mocks if AsyncToolboxTool uses them + return core_tool_mock + + mock.load_tool = AsyncMock(side_effect=mock_load_tool_impl) + + async def mock_load_toolset_impl( + name, auth_token_getters, bound_params, strict + ): + core_tools_list = [] + for tool_name_iter, tool_schema_dict in MANIFEST_JSON["tools"].items(): + core_params = [ + CoreParameterSchema(**p) for p in tool_schema_dict["parameters"] + ] + core_tool_mock = AsyncMock(spec=ToolboxCoreTool) + core_tool_mock.__name__ = tool_name_iter + core_tool_mock.__doc__ = tool_schema_dict["description"] + core_tool_mock._name = tool_name_iter + core_tool_mock._params = core_params + core_tools_list.append(core_tool_mock) + return core_tools_list + + mock.load_toolset = AsyncMock(side_effect=mock_load_toolset_impl) + # Mock the session attribute if it's directly accessed by AsyncToolboxClient tests + mock._ToolboxClient__session = mock_session + return mock + @pytest.fixture() - def mock_client(self, mock_session): - return AsyncToolboxClient(URL, session=mock_session) + def mock_client(self, mock_session, mock_core_client_instance): + # Patch the ToolboxCoreClient constructor used by AsyncToolboxClient + with patch( + "toolbox_langchain.async_client.ToolboxCoreClient", + return_value=mock_core_client_instance, + ): + client = AsyncToolboxClient(URL, session=mock_session) + # Ensure the mocked core client is used + client._AsyncToolboxClient__core_client = mock_core_client_instance + return client async def test_create_with_existing_session(self, mock_client, mock_session): - assert mock_client._AsyncToolboxClient__session == mock_session + # AsyncToolboxClient stores the core_client, which stores the session + assert ( + mock_client._AsyncToolboxClient__core_client._ToolboxClient__session + == mock_session + ) - @patch("toolbox_langchain.async_client._load_manifest") async def test_aload_tool( - self, mock_load_manifest, mock_client, mock_session, manifest_schema + self, + mock_client, ): tool_name = "test_tool_1" - mock_load_manifest.return_value = manifest_schema + test_bound_params = {"bp1": "value1"} - tool = await mock_client.aload_tool(tool_name) + tool = await mock_client.aload_tool(tool_name, bound_params=test_bound_params) - mock_load_manifest.assert_called_once_with( - f"{URL}/api/tool/{tool_name}", mock_session + # Assert that the core client's load_tool was called correctly + mock_client._AsyncToolboxClient__core_client.load_tool.assert_called_once_with( + name=tool_name, auth_token_getters={}, bound_params=test_bound_params ) assert isinstance(tool, AsyncToolboxTool) - assert tool.name == tool_name + assert ( + tool.name == tool_name + ) # AsyncToolboxTool gets its name from the core_tool - @patch("toolbox_langchain.async_client._load_manifest") - async def test_aload_tool_auth_headers_deprecated( - self, mock_load_manifest, mock_client, manifest_schema - ): + async def test_aload_tool_auth_headers_deprecated(self, mock_client): tool_name = "test_tool_1" - mock_manifest = manifest_schema - mock_load_manifest.return_value = mock_manifest + auth_lambda = lambda: "Bearer token" with catch_warnings(record=True) as w: simplefilter("always") await mock_client.aload_tool( - tool_name, auth_headers={"Authorization": lambda: "Bearer token"} + tool_name, + auth_headers={"Authorization": auth_lambda}, ) assert len(w) == 1 assert issubclass(w[-1].category, DeprecationWarning) assert "auth_headers" in str(w[-1].message) + assert "Use `auth_token_getters` instead" in str(w[-1].message) - @patch("toolbox_langchain.async_client._load_manifest") - async def test_aload_tool_auth_headers_and_tokens( - self, mock_load_manifest, mock_client, manifest_schema - ): + mock_client._AsyncToolboxClient__core_client.load_tool.assert_called_once_with( + name=tool_name, + auth_token_getters={"Authorization": auth_lambda}, + bound_params={}, + ) + + async def test_aload_tool_auth_headers_and_getters_precedence(self, mock_client): tool_name = "test_tool_1" - mock_manifest = manifest_schema - mock_load_manifest.return_value = mock_manifest + auth_getters = {"test_source": lambda: "id_token_from_getters"} + auth_headers_lambda = lambda: "Bearer token_from_headers" + with catch_warnings(record=True) as w: simplefilter("always") await mock_client.aload_tool( tool_name, - auth_headers={"Authorization": lambda: "Bearer token"}, - auth_token_getters={"test": lambda: "token"}, + auth_headers={"Authorization": auth_headers_lambda}, + auth_token_getters=auth_getters, ) assert len(w) == 1 assert issubclass(w[-1].category, DeprecationWarning) assert "auth_headers" in str(w[-1].message) + assert "`auth_token_getters` will be used" in str(w[-1].message) - @patch("toolbox_langchain.async_client._load_manifest") - async def test_aload_toolset( - self, mock_load_manifest, mock_client, mock_session, manifest_schema - ): - mock_manifest = manifest_schema - mock_load_manifest.return_value = mock_manifest - tools = await mock_client.aload_toolset() + mock_client._AsyncToolboxClient__core_client.load_tool.assert_called_once_with( + name=tool_name, auth_token_getters=auth_getters, bound_params={} + ) + + async def test_aload_tool_auth_tokens_deprecated(self, mock_client): + tool_name = "test_tool_1" + token_lambda = lambda: "id_token" + with catch_warnings(record=True) as w: + simplefilter("always") + await mock_client.aload_tool( + tool_name, + auth_tokens={"some_token_key": token_lambda}, + ) + assert len(w) == 1 + assert issubclass(w[-1].category, DeprecationWarning) + assert "auth_tokens" in str(w[-1].message) + assert "Use `auth_token_getters` instead" in str(w[-1].message) + + mock_client._AsyncToolboxClient__core_client.load_tool.assert_called_once_with( + name=tool_name, + auth_token_getters={"some_token_key": token_lambda}, + bound_params={}, + ) + + async def test_aload_tool_auth_tokens_and_getters_precedence(self, mock_client): + tool_name = "test_tool_1" + auth_getters = {"real_source": lambda: "token_from_getters"} + token_lambda = lambda: "token_from_auth_tokens" + + with catch_warnings(record=True) as w: + simplefilter("always") + await mock_client.aload_tool( + tool_name, + auth_tokens={"deprecated_source": token_lambda}, + auth_token_getters=auth_getters, + ) + assert len(w) == 1 + assert issubclass(w[-1].category, DeprecationWarning) + assert "auth_tokens" in str(w[-1].message) + assert "`auth_token_getters` will be used" in str(w[-1].message) - mock_load_manifest.assert_called_once_with(f"{URL}/api/toolset/", mock_session) + mock_client._AsyncToolboxClient__core_client.load_tool.assert_called_once_with( + name=tool_name, auth_token_getters=auth_getters, bound_params={} + ) + + async def test_aload_toolset(self, mock_client): + test_bound_params = {"bp_set": "value_set"} + tools = await mock_client.aload_toolset( + bound_params=test_bound_params, strict=True + ) + + mock_client._AsyncToolboxClient__core_client.load_toolset.assert_called_once_with( + name=None, + auth_token_getters={}, + bound_params=test_bound_params, + strict=True, + ) assert len(tools) == 2 for tool in tools: assert isinstance(tool, AsyncToolboxTool) assert tool.name in ["test_tool_1", "test_tool_2"] - @patch("toolbox_langchain.async_client._load_manifest") - async def test_aload_toolset_with_toolset_name( - self, mock_load_manifest, mock_client, mock_session, manifest_schema - ): + async def test_aload_toolset_with_toolset_name(self, mock_client): toolset_name = "test_toolset_1" - mock_manifest = manifest_schema - mock_load_manifest.return_value = mock_manifest tools = await mock_client.aload_toolset(toolset_name=toolset_name) - mock_load_manifest.assert_called_once_with( - f"{URL}/api/toolset/{toolset_name}", mock_session + mock_client._AsyncToolboxClient__core_client.load_toolset.assert_called_once_with( + name=toolset_name, auth_token_getters={}, bound_params={}, strict=False ) assert len(tools) == 2 for tool in tools: assert isinstance(tool, AsyncToolboxTool) - assert tool.name in ["test_tool_1", "test_tool_2"] - @patch("toolbox_langchain.async_client._load_manifest") - async def test_aload_toolset_auth_headers_deprecated( - self, mock_load_manifest, mock_client, manifest_schema + async def test_aload_toolset_auth_headers_deprecated(self, mock_client): + auth_lambda = lambda: "Bearer token" + with catch_warnings(record=True) as w: + simplefilter("always") + await mock_client.aload_toolset(auth_headers={"Authorization": auth_lambda}) + assert len(w) == 1 + assert issubclass(w[-1].category, DeprecationWarning) + assert "auth_headers" in str(w[-1].message) + assert "Use `auth_token_getters` instead" in str(w[-1].message) + + mock_client._AsyncToolboxClient__core_client.load_toolset.assert_called_once_with( + name=None, + auth_token_getters={"Authorization": auth_lambda}, + bound_params={}, + strict=False, + ) + + async def test_aload_toolset_auth_headers_and_getters_precedence( # Renamed for clarity + self, mock_client ): - mock_manifest = manifest_schema - mock_load_manifest.return_value = mock_manifest + auth_getters = {"test_source": lambda: "id_token_from_getters"} + auth_headers_lambda = lambda: "Bearer token_from_headers" with catch_warnings(record=True) as w: simplefilter("always") await mock_client.aload_toolset( - auth_headers={"Authorization": lambda: "Bearer token"} + auth_headers={"Authorization": auth_headers_lambda}, + auth_token_getters=auth_getters, ) assert len(w) == 1 assert issubclass(w[-1].category, DeprecationWarning) assert "auth_headers" in str(w[-1].message) + assert "`auth_token_getters` will be used" in str(w[-1].message) - @patch("toolbox_langchain.async_client._load_manifest") - async def test_aload_toolset_auth_headers_and_tokens( - self, mock_load_manifest, mock_client, manifest_schema - ): - mock_manifest = manifest_schema - mock_load_manifest.return_value = mock_manifest + mock_client._AsyncToolboxClient__core_client.load_toolset.assert_called_once_with( + name=None, + auth_token_getters=auth_getters, + bound_params={}, + strict=False, # auth_getters takes precedence + ) + + async def test_aload_toolset_auth_tokens_deprecated(self, mock_client): + token_lambda = lambda: "id_token" with catch_warnings(record=True) as w: simplefilter("always") await mock_client.aload_toolset( - auth_headers={"Authorization": lambda: "Bearer token"}, - auth_token_getters={"test": lambda: "token"}, + auth_tokens={"some_token_key": token_lambda} ) assert len(w) == 1 assert issubclass(w[-1].category, DeprecationWarning) - assert "auth_headers" in str(w[-1].message) + assert "auth_tokens" in str(w[-1].message) + assert "Use `auth_token_getters` instead" in str(w[-1].message) + + mock_client._AsyncToolboxClient__core_client.load_toolset.assert_called_once_with( + name=None, + auth_token_getters={"some_token_key": token_lambda}, + bound_params={}, + strict=False, + ) + + async def test_aload_toolset_auth_tokens_and_getters_precedence(self, mock_client): + auth_getters = {"real_source": lambda: "token_from_getters"} + token_lambda = lambda: "token_from_auth_tokens" + with catch_warnings(record=True) as w: + simplefilter("always") + await mock_client.aload_toolset( + auth_tokens={"deprecated_source": token_lambda}, + auth_token_getters=auth_getters, + ) + assert len(w) == 1 + assert issubclass(w[-1].category, DeprecationWarning) + assert "auth_tokens" in str(w[-1].message) + assert "`auth_token_getters` will be used" in str(w[-1].message) + + mock_client._AsyncToolboxClient__core_client.load_toolset.assert_called_once_with( + name=None, auth_token_getters=auth_getters, bound_params={}, strict=False + ) async def test_load_tool_not_implemented(self, mock_client): with pytest.raises(NotImplementedError) as excinfo: diff --git a/packages/toolbox-langchain/tests/test_async_tools.py b/packages/toolbox-langchain/tests/test_async_tools.py index e23aee85..96bd7660 100644 --- a/packages/toolbox-langchain/tests/test_async_tools.py +++ b/packages/toolbox-langchain/tests/test_async_tools.py @@ -12,11 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import types # For MappingProxyType from unittest.mock import AsyncMock, Mock, patch import pytest import pytest_asyncio from pydantic import ValidationError +from toolbox_core.protocol import ParameterSchema as CoreParameterSchema +from toolbox_core.tool import ToolboxTool as ToolboxCoreTool from toolbox_langchain.async_tools import AsyncToolboxTool @@ -24,7 +27,7 @@ @pytest.mark.asyncio class TestAsyncToolboxTool: @pytest.fixture - def tool_schema(self): + def tool_schema_dict(self): return { "description": "Test Tool Description", "parameters": [ @@ -34,9 +37,10 @@ def tool_schema(self): } @pytest.fixture - def auth_tool_schema(self): + def auth_tool_schema_dict(self): return { "description": "Test Tool Description", + "authRequired": ["test-auth-source"], "parameters": [ { "name": "param1", @@ -48,133 +52,193 @@ def auth_tool_schema(self): ], } + def _create_core_tool_from_dict( + self, session, name, schema_dict, url, initial_auth_getters=None + ): + core_params_schemas = [ + CoreParameterSchema(**p) for p in schema_dict["parameters"] + ] + + tool_constructor_params = [] + required_authn_for_core = {} + for p_schema in core_params_schemas: + if p_schema.authSources: + required_authn_for_core[p_schema.name] = p_schema.authSources + else: + tool_constructor_params.append(p_schema) + + return ToolboxCoreTool( + session=session, + base_url=url, + name=name, + description=schema_dict["description"], + params=tool_constructor_params, + required_authn_params=types.MappingProxyType(required_authn_for_core), + required_authz_tokens=schema_dict.get("authRequired", []), + auth_service_token_getters=types.MappingProxyType( + initial_auth_getters or {} + ), + bound_params=types.MappingProxyType({}), + client_headers=types.MappingProxyType({}), + ) + @pytest_asyncio.fixture @patch("aiohttp.ClientSession") - async def toolbox_tool(self, MockClientSession, tool_schema): + async def toolbox_tool(self, MockClientSession, tool_schema_dict): mock_session = MockClientSession.return_value - mock_session.post.return_value.__aenter__.return_value.raise_for_status = Mock() - mock_session.post.return_value.__aenter__.return_value.json = AsyncMock( - return_value={"result": "test-result"} - ) - tool = AsyncToolboxTool( + mock_response = mock_session.post.return_value.__aenter__.return_value + mock_response.raise_for_status = Mock() + mock_response.json = AsyncMock(return_value={"result": "test-result"}) + mock_response.status = 200 # *** Fix: Set status for the mock response *** + + core_tool_instance = self._create_core_tool_from_dict( + session=mock_session, name="test_tool", - schema=tool_schema, + schema_dict=tool_schema_dict, url="http://test_url", - session=mock_session, ) + tool = AsyncToolboxTool(core_tool=core_tool_instance) return tool @pytest_asyncio.fixture @patch("aiohttp.ClientSession") - async def auth_toolbox_tool(self, MockClientSession, auth_tool_schema): + async def auth_toolbox_tool(self, MockClientSession, auth_tool_schema_dict): mock_session = MockClientSession.return_value - mock_session.post.return_value.__aenter__.return_value.raise_for_status = Mock() - mock_session.post.return_value.__aenter__.return_value.json = AsyncMock( - return_value={"result": "test-result"} + mock_response = mock_session.post.return_value.__aenter__.return_value + mock_response.raise_for_status = Mock() + mock_response.json = AsyncMock(return_value={"result": "test-result"}) + mock_response.status = 200 # *** Fix: Set status for the mock response *** + + core_tool_instance = self._create_core_tool_from_dict( + session=mock_session, + name="test_tool", + schema_dict=auth_tool_schema_dict, + url="https://test-url", ) - with pytest.warns( - UserWarning, - match=r"Parameter\(s\) `param1` of tool test_tool require authentication", - ): - tool = AsyncToolboxTool( - name="test_tool", - schema=auth_tool_schema, - url="https://test-url", - session=mock_session, - ) + tool = AsyncToolboxTool(core_tool=core_tool_instance) return tool @patch("aiohttp.ClientSession") - async def test_toolbox_tool_init(self, MockClientSession, tool_schema): + async def test_toolbox_tool_init(self, MockClientSession, tool_schema_dict): mock_session = MockClientSession.return_value - tool = AsyncToolboxTool( + mock_response = mock_session.post.return_value.__aenter__.return_value + mock_response.status = 200 + core_tool_instance = self._create_core_tool_from_dict( + session=mock_session, name="test_tool", - schema=tool_schema, + schema_dict=tool_schema_dict, url="https://test-url", - session=mock_session, ) + tool = AsyncToolboxTool(core_tool=core_tool_instance) assert tool.name == "test_tool" - assert tool.description == "Test Tool Description" + assert tool.description == core_tool_instance.__doc__ @pytest.mark.parametrize( - "params, expected_bound_params", + "params_to_bind", [ - ({"param1": "bound-value"}, {"param1": "bound-value"}), - ({"param1": lambda: "bound-value"}, {"param1": lambda: "bound-value"}), - ( - {"param1": "bound-value", "param2": 123}, - {"param1": "bound-value", "param2": 123}, - ), + ({"param1": "bound-value"}), + ({"param1": lambda: "bound-value"}), + ({"param1": "bound-value", "param2": 123}), ], ) - async def test_toolbox_tool_bind_params( - self, toolbox_tool, params, expected_bound_params - ): - tool = toolbox_tool.bind_params(params) - for key, value in expected_bound_params.items(): - if callable(value): - assert value() == tool._AsyncToolboxTool__bound_params[key]() - else: - assert value == tool._AsyncToolboxTool__bound_params[key] - - @pytest.mark.parametrize("strict", [True, False]) - async def test_toolbox_tool_bind_params_invalid(self, toolbox_tool, strict): - if strict: - with pytest.raises(ValueError) as e: - tool = toolbox_tool.bind_params( - {"param3": "bound-value"}, strict=strict - ) - assert "Parameter(s) param3 missing and cannot be bound." in str(e.value) - else: - with pytest.warns(UserWarning) as record: - tool = toolbox_tool.bind_params( - {"param3": "bound-value"}, strict=strict - ) - assert len(record) == 1 - assert "Parameter(s) param3 missing and cannot be bound." in str( - record[0].message + async def test_toolbox_tool_bind_params(self, toolbox_tool, params_to_bind): + original_core_tool = toolbox_tool._AsyncToolboxTool__core_tool + with patch.object( + original_core_tool, "bind_params", wraps=original_core_tool.bind_params + ) as mock_core_bind_params: + new_langchain_tool = toolbox_tool.bind_params(params_to_bind) + mock_core_bind_params.assert_called_once_with(params_to_bind) + assert isinstance( + new_langchain_tool._AsyncToolboxTool__core_tool, ToolboxCoreTool + ) + new_core_tool_signature_params = ( + new_langchain_tool._AsyncToolboxTool__core_tool.__signature__.parameters ) + for bound_param_name in params_to_bind.keys(): + assert bound_param_name not in new_core_tool_signature_params + + async def test_toolbox_tool_bind_params_invalid(self, toolbox_tool): + with pytest.raises( + ValueError, match="unable to bind parameters: no parameter named param3" + ): + toolbox_tool.bind_params({"param3": "bound-value"}) async def test_toolbox_tool_bind_params_duplicate(self, toolbox_tool): tool = toolbox_tool.bind_params({"param1": "bound-value"}) - with pytest.raises(ValueError) as e: - tool = tool.bind_params({"param1": "bound-value"}) - assert "Parameter(s) `param1` already bound in tool `test_tool`." in str( - e.value - ) + with pytest.raises( + ValueError, + match="cannot re-bind parameter: parameter 'param1' is already bound", + ): + tool.bind_params({"param1": "bound-value"}) async def test_toolbox_tool_bind_params_invalid_params(self, auth_toolbox_tool): - with pytest.raises(ValueError) as e: + auth_core_tool = auth_toolbox_tool._AsyncToolboxTool__core_tool + # Verify that 'param1' is not in the list of bindable parameters for the core tool + # because it requires authentication. + assert "param1" not in [p.name for p in auth_core_tool._ToolboxTool__params] + with pytest.raises( + ValueError, match="unable to bind parameters: no parameter named param1" + ): auth_toolbox_tool.bind_params({"param1": "bound-value"}) - assert "Parameter(s) param1 already authenticated and cannot be bound." in str( - e.value + + async def test_toolbox_tool_add_valid_auth_token_getter(self, auth_toolbox_tool): + get_token_lambda = lambda: "test-token-value" + original_core_tool = auth_toolbox_tool._AsyncToolboxTool__core_tool + with patch.object( + original_core_tool, + "add_auth_token_getters", + wraps=original_core_tool.add_auth_token_getters, + ) as mock_core_add_getters: + tool = auth_toolbox_tool.add_auth_token_getters( + {"test-auth-source": get_token_lambda} + ) + mock_core_add_getters.assert_called_once_with( + {"test-auth-source": get_token_lambda} + ) + core_tool_after_add = tool._AsyncToolboxTool__core_tool + assert ( + "test-auth-source" + in core_tool_after_add._ToolboxTool__auth_service_token_getters + ) + assert ( + core_tool_after_add._ToolboxTool__auth_service_token_getters[ + "test-auth-source" + ] + is get_token_lambda + ) + assert not core_tool_after_add._ToolboxTool__required_authn_params.get( + "param1" + ) + assert ( + "test-auth-source" + not in core_tool_after_add._ToolboxTool__required_authz_tokens + ) + + async def test_toolbox_tool_add_unused_auth_token_getter_raises_error( + self, auth_toolbox_tool + ): + unused_lambda = lambda: "another-token" + with pytest.raises(ValueError) as excinfo: + auth_toolbox_tool.add_auth_token_getters( + {"another-auth-source": unused_lambda} + ) + assert ( + "Authentication source(s) `another-auth-source` unused by tool `test_tool`" + in str(excinfo.value) ) - @pytest.mark.parametrize( - "auth_token_getters, expected_auth_token_getters", - [ - ( - {"test-auth-source": lambda: "test-token"}, - {"test-auth-source": lambda: "test-token"}, - ), - ( + valid_lambda = lambda: "test-token" + with pytest.raises(ValueError) as excinfo_mixed: + auth_toolbox_tool.add_auth_token_getters( { - "test-auth-source": lambda: "test-token", - "another-auth-source": lambda: "another-token", - }, - { - "test-auth-source": lambda: "test-token", - "another-auth-source": lambda: "another-token", - }, - ), - ], - ) - async def test_toolbox_tool_add_auth_token_getters( - self, auth_toolbox_tool, auth_token_getters, expected_auth_token_getters - ): - tool = auth_toolbox_tool.add_auth_token_getters(auth_token_getters) - for source, getter in expected_auth_token_getters.items(): - assert tool._AsyncToolboxTool__auth_token_getters[source]() == getter() + "test-auth-source": valid_lambda, + "another-auth-source": unused_lambda, + } + ) + assert ( + "Authentication source(s) `another-auth-source` unused by tool `test_tool`" + in str(excinfo_mixed.value) + ) async def test_toolbox_tool_add_auth_token_getters_duplicate( self, auth_toolbox_tool @@ -182,45 +246,44 @@ async def test_toolbox_tool_add_auth_token_getters_duplicate( tool = auth_toolbox_tool.add_auth_token_getters( {"test-auth-source": lambda: "test-token"} ) - with pytest.raises(ValueError) as e: - tool = tool.add_auth_token_getters( - {"test-auth-source": lambda: "test-token"} - ) - assert ( - "Authentication source(s) `test-auth-source` already registered in tool `test_tool`." - in str(e.value) - ) + with pytest.raises( + ValueError, + match="Authentication source\\(s\\) `test-auth-source` already registered in tool `test_tool`\\.", + ): + tool.add_auth_token_getters({"test-auth-source": lambda: "test-token"}) - async def test_toolbox_tool_validate_auth_strict(self, auth_toolbox_tool): - with pytest.raises(PermissionError) as e: - auth_toolbox_tool._AsyncToolboxTool__validate_auth(strict=True) - assert "Parameter(s) `param1` of tool test_tool require authentication" in str( - e.value - ) + async def test_toolbox_tool_call_requires_auth_strict(self, auth_toolbox_tool): + with pytest.raises( + PermissionError, + match="One or more of the following authn services are required to invoke this tool: test-auth-source", + ): + await auth_toolbox_tool.ainvoke({"param2": 123}) async def test_toolbox_tool_call(self, toolbox_tool): result = await toolbox_tool.ainvoke({"param1": "test-value", "param2": 123}) assert result == "test-result" - toolbox_tool._AsyncToolboxTool__session.post.assert_called_once_with( + core_tool = toolbox_tool._AsyncToolboxTool__core_tool + core_tool._ToolboxTool__session.post.assert_called_once_with( "http://test_url/api/tool/test_tool/invoke", json={"param1": "test-value", "param2": 123}, headers={}, ) @pytest.mark.parametrize( - "bound_param, expected_value", + "bound_param_map, expected_value", [ ({"param1": "bound-value"}, "bound-value"), ({"param1": lambda: "dynamic-value"}, "dynamic-value"), ], ) async def test_toolbox_tool_call_with_bound_params( - self, toolbox_tool, bound_param, expected_value + self, toolbox_tool, bound_param_map, expected_value ): - tool = toolbox_tool.bind_params(bound_param) + tool = toolbox_tool.bind_params(bound_param_map) result = await tool.ainvoke({"param2": 123}) assert result == "test-result" - toolbox_tool._AsyncToolboxTool__session.post.assert_called_once_with( + core_tool = tool._AsyncToolboxTool__core_tool + core_tool._ToolboxTool__session.post.assert_called_once_with( "http://test_url/api/tool/test_tool/invoke", json={"param1": expected_value, "param2": 123}, headers={}, @@ -232,29 +295,53 @@ async def test_toolbox_tool_call_with_auth_tokens(self, auth_toolbox_tool): ) result = await tool.ainvoke({"param2": 123}) assert result == "test-result" - auth_toolbox_tool._AsyncToolboxTool__session.post.assert_called_once_with( + core_tool = tool._AsyncToolboxTool__core_tool + core_tool._ToolboxTool__session.post.assert_called_once_with( "https://test-url/api/tool/test_tool/invoke", json={"param2": 123}, headers={"test-auth-source_token": "test-token"}, ) - async def test_toolbox_tool_call_with_auth_tokens_insecure(self, auth_toolbox_tool): + async def test_toolbox_tool_call_with_auth_tokens_insecure( + self, auth_toolbox_tool, auth_tool_schema_dict + ): + core_tool_of_auth_tool = auth_toolbox_tool._AsyncToolboxTool__core_tool + mock_session = core_tool_of_auth_tool._ToolboxTool__session + with pytest.warns( UserWarning, match="Sending ID token over HTTP. User data may be exposed. Use HTTPS for secure communication.", ): - auth_toolbox_tool._AsyncToolboxTool__url = "http://test-url" - tool = auth_toolbox_tool.add_auth_token_getters( - {"test-auth-source": lambda: "test-token"} - ) - result = await tool.ainvoke({"param2": 123}) - assert result == "test-result" - auth_toolbox_tool._AsyncToolboxTool__session.post.assert_called_once_with( - "http://test-url/api/tool/test_tool/invoke", - json={"param2": 123}, - headers={"test-auth-source_token": "test-token"}, + insecure_core_tool = self._create_core_tool_from_dict( + session=mock_session, + name="test_tool", + schema_dict=auth_tool_schema_dict, + url="http://test-url", ) + insecure_auth_langchain_tool = AsyncToolboxTool(core_tool=insecure_core_tool) + + tool_with_getter = insecure_auth_langchain_tool.add_auth_token_getters( + {"test-auth-source": lambda: "test-token"} + ) + result = await tool_with_getter.ainvoke({"param2": 123}) + assert result == "test-result" + + modified_core_tool_in_new_tool = tool_with_getter._AsyncToolboxTool__core_tool + assert ( + modified_core_tool_in_new_tool._ToolboxTool__base_url == "http://test-url" + ) + assert ( + modified_core_tool_in_new_tool._ToolboxTool__url + == "http://test-url/api/tool/test_tool/invoke" + ) + + modified_core_tool_in_new_tool._ToolboxTool__session.post.assert_called_once_with( + "http://test-url/api/tool/test_tool/invoke", + json={"param2": 123}, + headers={"test-auth-source_token": "test-token"}, + ) + async def test_toolbox_tool_call_with_invalid_input(self, toolbox_tool): with pytest.raises(ValidationError) as e: await toolbox_tool.ainvoke({"param1": 123, "param2": "invalid"}) diff --git a/packages/toolbox-langchain/tests/test_client.py b/packages/toolbox-langchain/tests/test_client.py index 62999019..98f29e53 100644 --- a/packages/toolbox-langchain/tests/test_client.py +++ b/packages/toolbox-langchain/tests/test_client.py @@ -16,6 +16,9 @@ import pytest from pydantic import BaseModel +from toolbox_core.protocol import ParameterSchema as CoreParameterSchema +from toolbox_core.sync_tool import ToolboxSyncTool as ToolboxCoreSyncTool +from toolbox_core.utils import params_to_pydantic_model from toolbox_langchain.client import ToolboxClient from toolbox_langchain.tools import ToolboxTool @@ -23,237 +26,397 @@ URL = "http://test_url" +def create_mock_core_sync_tool( + name="mock-sync-tool", + doc="Mock sync description.", + model_name="MockSyncModel", + params=None, +): + mock_tool = Mock(spec=ToolboxCoreSyncTool) + mock_tool.__name__ = name + mock_tool.__doc__ = doc + mock_tool._name = model_name + if params is None: + mock_tool._params = [ + CoreParameterSchema(name="param1", type="string", description="Param 1") + ] + else: + mock_tool._params = params + return mock_tool + + +def assert_pydantic_models_equivalent( + model_cls1: type[BaseModel], model_cls2: type[BaseModel], expected_model_name: str +): + assert issubclass(model_cls1, BaseModel), "model_cls1 is not a Pydantic BaseModel" + assert issubclass(model_cls2, BaseModel), "model_cls2 is not a Pydantic BaseModel" + + assert ( + model_cls1.__name__ == expected_model_name + ), f"model_cls1 name mismatch: expected {expected_model_name}, got {model_cls1.__name__}" + assert ( + model_cls2.__name__ == expected_model_name + ), f"model_cls2 name mismatch: expected {expected_model_name}, got {model_cls2.__name__}" + + fields1 = model_cls1.model_fields + fields2 = model_cls2.model_fields + + assert ( + fields1.keys() == fields2.keys() + ), f"Field names mismatch: {fields1.keys()} != {fields2.keys()}" + + for field_name in fields1.keys(): + field_info1 = fields1[field_name] + field_info2 = fields2[field_name] + + assert ( + field_info1.annotation == field_info2.annotation + ), f"Field '{field_name}': Annotation mismatch ({field_info1.annotation} != {field_info2.annotation})" + assert ( + field_info1.description == field_info2.description + ), f"Field '{field_name}': Description mismatch ('{field_info1.description}' != '{field_info2.description}')" + is_required1 = ( + field_info1.is_required() + if hasattr(field_info1, "is_required") + else not field_info1.is_nullable() + ) + is_required2 = ( + field_info2.is_required() + if hasattr(field_info2, "is_required") + else not field_info2.is_nullable() + ) + assert ( + is_required1 == is_required2 + ), f"Field '{field_name}': Required status mismatch ({is_required1} != {is_required2})" + + class TestToolboxClient: @pytest.fixture() def toolbox_client(self): client = ToolboxClient(URL) assert isinstance(client, ToolboxClient) - assert client._ToolboxClient__async_client is not None + assert client._ToolboxClient__core_client is not None + return client - # Check that the background loop was created and started - assert client._ToolboxClient__loop is not None - assert client._ToolboxClient__loop.is_running() + @patch("toolbox_core.sync_client.ToolboxSyncClient.load_tool") + def test_load_tool(self, mock_core_load_tool, toolbox_client): + mock_core_tool_instance = create_mock_core_sync_tool( + name="test_tool_sync", + doc="Sync tool description.", + model_name="TestToolSyncModel", + params=[ + CoreParameterSchema( + name="sp1", type="integer", description="Sync Param 1" + ) + ], + ) + mock_core_load_tool.return_value = mock_core_tool_instance - return client + langchain_tool = toolbox_client.load_tool("test_tool") + + assert isinstance(langchain_tool, ToolboxTool) + assert langchain_tool.name == mock_core_tool_instance.__name__ + assert langchain_tool.description == mock_core_tool_instance.__doc__ + + # Generate the expected schema once for comparison + expected_args_schema = params_to_pydantic_model( + mock_core_tool_instance._name, mock_core_tool_instance._params + ) + + assert_pydantic_models_equivalent( + langchain_tool.args_schema, + expected_args_schema, + mock_core_tool_instance._name, + ) + + mock_core_load_tool.assert_called_once_with( + name="test_tool", auth_token_getters={}, bound_params={} + ) + + @patch("toolbox_core.sync_client.ToolboxSyncClient.load_toolset") + def test_load_toolset(self, mock_core_load_toolset, toolbox_client): + mock_core_tool_instance1 = create_mock_core_sync_tool( + name="tool-0", doc="desc 0", model_name="T0Model" + ) + mock_core_tool_instance2 = create_mock_core_sync_tool( + name="tool-1", doc="desc 1", model_name="T1Model", params=[] + ) + + mock_core_load_toolset.return_value = [ + mock_core_tool_instance1, + mock_core_tool_instance2, + ] + + langchain_tools = toolbox_client.load_toolset() + assert len(langchain_tools) == 2 + + tool_instances_mocks = [mock_core_tool_instance1, mock_core_tool_instance2] + for i, tool_instance_mock in enumerate(tool_instances_mocks): + langchain_tool = langchain_tools[i] + assert isinstance(langchain_tool, ToolboxTool) + assert langchain_tool.name == tool_instance_mock.__name__ + assert langchain_tool.description == tool_instance_mock.__doc__ - @patch("toolbox_langchain.client.AsyncToolboxClient.aload_tool") - def test_load_tool(self, mock_aload_tool, toolbox_client): - mock_tool = Mock(spec=ToolboxTool) - mock_tool.name = "mock-tool" - mock_tool.description = "mock description" - mock_tool.args_schema = BaseModel - mock_aload_tool.return_value = mock_tool - - tool = toolbox_client.load_tool("test_tool") - - assert tool.name == mock_tool.name - assert tool.description == mock_tool.description - assert tool.args_schema == mock_tool.args_schema - mock_aload_tool.assert_called_once_with("test_tool", {}, None, None, {}, True) - - @patch("toolbox_langchain.client.AsyncToolboxClient.aload_toolset") - def test_load_toolset(self, mock_aload_toolset, toolbox_client): - mock_tools = [Mock(spec=ToolboxTool), Mock(spec=ToolboxTool)] - mock_tools[0].name = "mock-tool-0" - mock_tools[0].description = "mock description 0" - mock_tools[0].args_schema = BaseModel - mock_tools[1].name = "mock-tool-1" - mock_tools[1].description = "mock description 1" - mock_tools[1].args_schema = BaseModel - mock_aload_toolset.return_value = mock_tools - - tools = toolbox_client.load_toolset() - assert len(tools) == len(mock_tools) - assert all( - a.name == b.name - and a.description == b.description - and a.args_schema == b.args_schema - for a, b in zip(tools, mock_tools) + expected_args_schema = params_to_pydantic_model( + tool_instance_mock._name, tool_instance_mock._params + ) + assert_pydantic_models_equivalent( + langchain_tool.args_schema, + expected_args_schema, + tool_instance_mock._name, + ) + + mock_core_load_toolset.assert_called_once_with( + name=None, auth_token_getters={}, bound_params={}, strict=False ) - mock_aload_toolset.assert_called_once_with(None, {}, None, None, {}, True) @pytest.mark.asyncio - @patch("toolbox_langchain.client.AsyncToolboxClient.aload_tool") - async def test_aload_tool(self, mock_aload_tool, toolbox_client): - mock_tool = Mock(spec=ToolboxTool) - mock_tool.name = "mock-tool" - mock_tool.description = "mock description" - mock_tool.args_schema = BaseModel - mock_aload_tool.return_value = mock_tool - - tool = await toolbox_client.aload_tool("test_tool") - assert tool.name == mock_tool.name - assert tool.description == mock_tool.description - assert tool.args_schema == mock_tool.args_schema - mock_aload_tool.assert_called_once_with("test_tool", {}, None, None, {}, True) + @patch("toolbox_core.sync_client.ToolboxSyncClient.load_tool") + async def test_aload_tool(self, mock_sync_core_load_tool, toolbox_client): + mock_core_sync_tool_instance = create_mock_core_sync_tool( + name="test_async_loaded_tool", + doc="Async loaded sync tool description.", + model_name="AsyncTestToolModel", + ) + mock_sync_core_load_tool.return_value = mock_core_sync_tool_instance + + langchain_tool = await toolbox_client.aload_tool("test_tool") + + assert isinstance(langchain_tool, ToolboxTool) + assert langchain_tool.name == mock_core_sync_tool_instance.__name__ + assert langchain_tool.description == mock_core_sync_tool_instance.__doc__ + + expected_args_schema = params_to_pydantic_model( + mock_core_sync_tool_instance._name, mock_core_sync_tool_instance._params + ) + assert_pydantic_models_equivalent( + langchain_tool.args_schema, + expected_args_schema, + mock_core_sync_tool_instance._name, + ) + + mock_sync_core_load_tool.assert_called_once_with( + name="test_tool", auth_token_getters={}, bound_params={} + ) @pytest.mark.asyncio - @patch("toolbox_langchain.client.AsyncToolboxClient.aload_toolset") - async def test_aload_toolset(self, mock_aload_toolset, toolbox_client): - mock_tools = [Mock(spec=ToolboxTool), Mock(spec=ToolboxTool)] - mock_tools[0].name = "mock-tool-0" - mock_tools[0].description = "mock description 0" - mock_tools[0].args_schema = BaseModel - mock_tools[1].name = "mock-tool-1" - mock_tools[1].description = "mock description 1" - mock_tools[1].args_schema = BaseModel - mock_aload_toolset.return_value = mock_tools - - tools = await toolbox_client.aload_toolset() - assert len(tools) == len(mock_tools) - assert all( - a.name == b.name - and a.description == b.description - and a.args_schema == b.args_schema - for a, b in zip(tools, mock_tools) + @patch("toolbox_core.sync_client.ToolboxSyncClient.load_toolset") + async def test_aload_toolset(self, mock_sync_core_load_toolset, toolbox_client): + mock_core_sync_tool1 = create_mock_core_sync_tool( + name="async-tool-0", doc="async desc 0", model_name="AT0Model" ) - mock_aload_toolset.assert_called_once_with(None, {}, None, None, {}, True) - - @patch("toolbox_langchain.client.AsyncToolboxClient.aload_tool") - def test_load_tool_with_args(self, mock_aload_tool, toolbox_client): - mock_tool = Mock(spec=ToolboxTool) - mock_tool.name = "mock-tool" - mock_tool.description = "mock description" - mock_tool.args_schema = BaseModel - mock_aload_tool.return_value = mock_tool + mock_core_sync_tool2 = create_mock_core_sync_tool( + name="async-tool-1", + doc="async desc 1", + model_name="AT1Model", + params=[CoreParameterSchema(name="p1", type="string", description="P1")], + ) + + mock_sync_core_load_toolset.return_value = [ + mock_core_sync_tool1, + mock_core_sync_tool2, + ] + + langchain_tools = await toolbox_client.aload_toolset() + assert len(langchain_tools) == 2 + + tool_instances_mocks = [mock_core_sync_tool1, mock_core_sync_tool2] + for i, tool_instance_mock in enumerate(tool_instances_mocks): + langchain_tool = langchain_tools[i] + assert isinstance(langchain_tool, ToolboxTool) + assert langchain_tool.name == tool_instance_mock.__name__ + + expected_args_schema = params_to_pydantic_model( + tool_instance_mock._name, tool_instance_mock._params + ) + assert_pydantic_models_equivalent( + langchain_tool.args_schema, + expected_args_schema, + tool_instance_mock._name, + ) + + mock_sync_core_load_toolset.assert_called_once_with( + name=None, auth_token_getters={}, bound_params={}, strict=False + ) + + @patch("toolbox_core.sync_client.ToolboxSyncClient.load_tool") + def test_load_tool_with_args(self, mock_core_load_tool, toolbox_client): + mock_core_tool_instance = create_mock_core_sync_tool() + mock_core_load_tool.return_value = mock_core_tool_instance + auth_token_getters = {"token_getter1": lambda: "value1"} - auth_tokens = {"token1": lambda: "value2"} - auth_headers = {"header1": lambda: "value3"} + auth_tokens_deprecated = {"token_deprecated": lambda: "value_dep"} + auth_headers_deprecated = {"header_deprecated": lambda: "value_head_dep"} bound_params = {"param1": "value4"} + # Scenario 1: auth_token_getters takes precedence + with pytest.warns(DeprecationWarning) as record: + tool = toolbox_client.load_tool( + "test_tool_name", + auth_token_getters=auth_token_getters, + auth_tokens=auth_tokens_deprecated, + auth_headers=auth_headers_deprecated, + bound_params=bound_params, + ) + assert len(record) == 2 + messages = sorted([str(r.message) for r in record]) + # Warning for auth_headers when auth_token_getters is also present + assert ( + "Both `auth_token_getters` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_token_getters` will be used." + in messages + ) + # Warning for auth_tokens when auth_token_getters is also present + assert ( + "Both `auth_token_getters` and `auth_tokens` are provided. `auth_tokens` is deprecated, and `auth_token_getters` will be used." + in messages + ) - tool = toolbox_client.load_tool( - "test_tool_name", + assert isinstance(tool, ToolboxTool) + mock_core_load_tool.assert_called_with( + name="test_tool_name", auth_token_getters=auth_token_getters, - auth_tokens=auth_tokens, - auth_headers=auth_headers, bound_params=bound_params, - strict=False, ) + mock_core_load_tool.reset_mock() + + # Scenario 2: auth_tokens and auth_headers provided, auth_token_getters is default (empty initially) + with pytest.warns(DeprecationWarning) as record: + toolbox_client.load_tool( + "test_tool_name_2", + auth_tokens=auth_tokens_deprecated, # This will be used for auth_token_getters + auth_headers=auth_headers_deprecated, # This will warn as auth_token_getters is now populated + bound_params=bound_params, + ) + assert len(record) == 2 + messages = sorted([str(r.message) for r in record]) + + assert ( + messages[0] + == "Argument `auth_tokens` is deprecated. Use `auth_token_getters` instead." + ) + assert ( + messages[1] + == "Both `auth_token_getters` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_token_getters` will be used." + ) + + expected_getters_for_call = auth_tokens_deprecated + + mock_core_load_tool.assert_called_with( + name="test_tool_name_2", + auth_token_getters=expected_getters_for_call, + bound_params=bound_params, + ) + mock_core_load_tool.reset_mock() + + with pytest.warns( + DeprecationWarning, + match="Argument `auth_headers` is deprecated. Use `auth_token_getters` instead.", + ) as record: + toolbox_client.load_tool( + "test_tool_name_3", + auth_headers=auth_headers_deprecated, + bound_params=bound_params, + ) + assert len(record) == 1 - assert tool.name == mock_tool.name - assert tool.description == mock_tool.description - assert tool.args_schema == mock_tool.args_schema - mock_aload_tool.assert_called_once_with( - "test_tool_name", - auth_token_getters, - auth_tokens, - auth_headers, - bound_params, - False, + mock_core_load_tool.assert_called_with( + name="test_tool_name_3", + auth_token_getters=auth_headers_deprecated, + bound_params=bound_params, ) - @patch("toolbox_langchain.client.AsyncToolboxClient.aload_toolset") - def test_load_toolset_with_args(self, mock_aload_toolset, toolbox_client): - mock_tools = [Mock(spec=ToolboxTool), Mock(spec=ToolboxTool)] - mock_tools[0].name = "mock-tool-0" - mock_tools[0].description = "mock description 0" - mock_tools[0].args_schema = BaseModel - mock_tools[1].name = "mock-tool-1" - mock_tools[1].description = "mock description 1" - mock_tools[1].args_schema = BaseModel - mock_aload_toolset.return_value = mock_tools + @patch("toolbox_core.sync_client.ToolboxSyncClient.load_toolset") + def test_load_toolset_with_args(self, mock_core_load_toolset, toolbox_client): + mock_core_tool_instance = create_mock_core_sync_tool(model_name="MySetModel") + mock_core_load_toolset.return_value = [mock_core_tool_instance] auth_token_getters = {"token_getter1": lambda: "value1"} - auth_tokens = {"token1": lambda: "value2"} - auth_headers = {"header1": lambda: "value3"} + auth_tokens_deprecated = {"token_deprecated": lambda: "value_dep"} + auth_headers_deprecated = {"header_deprecated": lambda: "value_head_dep"} bound_params = {"param1": "value4"} + toolset_name = "my_toolset" - tools = toolbox_client.load_toolset( - toolset_name="my_toolset", + with pytest.warns(DeprecationWarning) as record: + tools = toolbox_client.load_toolset( + toolset_name=toolset_name, + auth_token_getters=auth_token_getters, + auth_tokens=auth_tokens_deprecated, + auth_headers=auth_headers_deprecated, + bound_params=bound_params, + strict=True, + ) + assert len(record) == 2 + + assert len(tools) == 1 + assert isinstance(tools[0], ToolboxTool) + mock_core_load_toolset.assert_called_with( + name=toolset_name, auth_token_getters=auth_token_getters, - auth_tokens=auth_tokens, - auth_headers=auth_headers, bound_params=bound_params, - strict=False, - ) - - assert len(tools) == len(mock_tools) - assert all( - a.name == b.name - and a.description == b.description - and a.args_schema == b.args_schema - for a, b in zip(tools, mock_tools) - ) - mock_aload_toolset.assert_called_once_with( - "my_toolset", - auth_token_getters, - auth_tokens, - auth_headers, - bound_params, - False, + strict=True, ) @pytest.mark.asyncio - @patch("toolbox_langchain.client.AsyncToolboxClient.aload_tool") - async def test_aload_tool_with_args(self, mock_aload_tool, toolbox_client): - mock_tool = Mock(spec=ToolboxTool) - mock_tool.name = "mock-tool" - mock_tool.description = "mock description" - mock_tool.args_schema = BaseModel - mock_aload_tool.return_value = mock_tool + @patch("toolbox_core.sync_client.ToolboxSyncClient.load_tool") + async def test_aload_tool_with_args(self, mock_sync_core_load_tool, toolbox_client): + mock_core_tool_instance = create_mock_core_sync_tool( + model_name="MyAsyncToolModel" + ) + mock_sync_core_load_tool.return_value = mock_core_tool_instance auth_token_getters = {"token_getter1": lambda: "value1"} - auth_tokens = {"token1": lambda: "value2"} - auth_headers = {"header1": lambda: "value3"} + auth_tokens_deprecated = {"token_deprecated": lambda: "value_dep"} + auth_headers_deprecated = {"header_deprecated": lambda: "value_head_dep"} bound_params = {"param1": "value4"} - tool = await toolbox_client.aload_tool( - "test_tool", - auth_token_getters, - auth_tokens, - auth_headers, - bound_params, - False, - ) - assert tool.name == mock_tool.name - assert tool.description == mock_tool.description - assert tool.args_schema == mock_tool.args_schema - mock_aload_tool.assert_called_once_with( - "test_tool", - auth_token_getters, - auth_tokens, - auth_headers, - bound_params, - False, + with pytest.warns(DeprecationWarning) as record: + tool = await toolbox_client.aload_tool( + "test_tool", + auth_token_getters=auth_token_getters, + auth_tokens=auth_tokens_deprecated, + auth_headers=auth_headers_deprecated, + bound_params=bound_params, + ) + assert len(record) == 2 + + assert isinstance(tool, ToolboxTool) + mock_sync_core_load_tool.assert_called_with( + name="test_tool", + auth_token_getters=auth_token_getters, + bound_params=bound_params, ) @pytest.mark.asyncio - @patch("toolbox_langchain.client.AsyncToolboxClient.aload_toolset") - async def test_aload_toolset_with_args(self, mock_aload_toolset, toolbox_client): - mock_tools = [Mock(spec=ToolboxTool), Mock(spec=ToolboxTool)] - mock_tools[0].name = "mock-tool-0" - mock_tools[0].description = "mock description 0" - mock_tools[0].args_schema = BaseModel - mock_tools[1].name = "mock-tool-1" - mock_tools[1].description = "mock description 1" - mock_tools[1].args_schema = BaseModel - mock_aload_toolset.return_value = mock_tools + @patch("toolbox_core.sync_client.ToolboxSyncClient.load_toolset") + async def test_aload_toolset_with_args( + self, mock_sync_core_load_toolset, toolbox_client + ): + mock_core_tool_instance = create_mock_core_sync_tool( + model_name="MyAsyncSetModel" + ) + mock_sync_core_load_toolset.return_value = [mock_core_tool_instance] auth_token_getters = {"token_getter1": lambda: "value1"} - auth_tokens = {"token1": lambda: "value2"} - auth_headers = {"header1": lambda: "value3"} + auth_tokens_deprecated = {"token_deprecated": lambda: "value_dep"} + auth_headers_deprecated = {"header_deprecated": lambda: "value_head_dep"} bound_params = {"param1": "value4"} + toolset_name = "my_async_toolset" - tools = await toolbox_client.aload_toolset( - "my_toolset", - auth_token_getters, - auth_tokens, - auth_headers, - bound_params, - False, - ) - assert len(tools) == len(mock_tools) - assert all( - a.name == b.name - and a.description == b.description - and a.args_schema == b.args_schema - for a, b in zip(tools, mock_tools) - ) - mock_aload_toolset.assert_called_once_with( - "my_toolset", - auth_token_getters, - auth_tokens, - auth_headers, - bound_params, - False, + with pytest.warns(DeprecationWarning) as record: + tools = await toolbox_client.aload_toolset( + toolset_name, + auth_token_getters=auth_token_getters, + auth_tokens=auth_tokens_deprecated, + auth_headers=auth_headers_deprecated, + bound_params=bound_params, + strict=True, + ) + assert len(record) == 2 + + assert len(tools) == 1 + assert isinstance(tools[0], ToolboxTool) + mock_sync_core_load_toolset.assert_called_with( + name=toolset_name, + auth_token_getters=auth_token_getters, + bound_params=bound_params, + strict=True, ) diff --git a/packages/toolbox-langchain/tests/test_e2e.py b/packages/toolbox-langchain/tests/test_e2e.py index 214ea305..12002717 100644 --- a/packages/toolbox-langchain/tests/test_e2e.py +++ b/packages/toolbox-langchain/tests/test_e2e.py @@ -36,7 +36,6 @@ import pytest import pytest_asyncio -from langchain_core.tools import ToolException from pydantic import ValidationError from toolbox_langchain.client import ToolboxClient @@ -54,7 +53,7 @@ def toolbox(self): @pytest_asyncio.fixture(scope="function") async def get_n_rows_tool(self, toolbox): tool = await toolbox.aload_tool("get-n-rows") - assert tool._ToolboxTool__async_tool._AsyncToolboxTool__name == "get-n-rows" + assert tool._ToolboxTool__core_tool.__name__ == "get-n-rows" return tool #### Basic e2e tests @@ -71,7 +70,7 @@ async def test_aload_toolset_specific( toolset = await toolbox.aload_toolset(toolset_name) assert len(toolset) == expected_length for tool in toolset: - name = tool._ToolboxTool__async_tool._AsyncToolboxTool__name + name = tool._ToolboxTool__core_tool.__name__ assert name in expected_tools async def test_aload_toolset_all(self, toolbox): @@ -85,7 +84,7 @@ async def test_aload_toolset_all(self, toolbox): "get-row-by-content-auth", ] for tool in toolset: - name = tool._ToolboxTool__async_tool._AsyncToolboxTool__name + name = tool._ToolboxTool__core_tool.__name__ assert name in tool_names async def test_run_tool_async(self, get_n_rows_tool): @@ -114,11 +113,14 @@ async def test_run_tool_wrong_param_type(self, get_n_rows_tool): @pytest.mark.asyncio async def test_run_tool_unauth_with_auth(self, toolbox, auth_token2): """Tests running a tool that doesn't require auth, with auth provided.""" - tool = await toolbox.aload_tool( - "get-row-by-id", auth_token_getters={"my-test-auth": lambda: auth_token2} - ) - response = await tool.ainvoke({"id": "2"}) - assert "row2" in response + with pytest.raises( + ValueError, + match="Validation failed for tool 'get-row-by-id': unused auth tokens: my-test-auth.", + ): + await toolbox.aload_tool( + "get-row-by-id", + auth_token_getters={"my-test-auth": lambda: auth_token2}, + ) async def test_run_tool_no_auth(self, toolbox): """Tests running a tool requiring auth without providing auth.""" @@ -127,7 +129,7 @@ async def test_run_tool_no_auth(self, toolbox): ) with pytest.raises( PermissionError, - match="Tool get-row-by-id-auth requires authentication, but no valid authentication sources are registered. Please register the required sources before use.", + match="One or more of the following authn services are required to invoke this tool: my-test-auth", ): await tool.ainvoke({"id": "2"}) @@ -138,8 +140,8 @@ async def test_run_tool_wrong_auth(self, toolbox, auth_token2): ) auth_tool = tool.add_auth_token_getter("my-test-auth", lambda: auth_token2) with pytest.raises( - ToolException, - match="{'status': 'Unauthorized', 'error': 'tool invocation not authorized. Please make sure your specify correct auth headers'}", + Exception, + match="tool invocation not authorized. Please make sure your specify correct auth headers", ): await auth_tool.ainvoke({"id": "2"}) @@ -157,7 +159,7 @@ async def test_run_tool_param_auth_no_auth(self, toolbox): tool = await toolbox.aload_tool("get-row-by-email-auth") with pytest.raises( PermissionError, - match="Parameter\(s\) `email` of tool get-row-by-email-auth require authentication\, but no valid authentication sources are registered\. Please register the required sources before use\.", + match="One or more of the following authn services are required to invoke this tool: my-test-auth", ): await tool.ainvoke({"email": ""}) @@ -179,8 +181,8 @@ async def test_run_tool_param_auth_no_field(self, toolbox, auth_token1): auth_token_getters={"my-test-auth": lambda: auth_token1}, ) with pytest.raises( - ToolException, - match="{'status': 'Bad Request', 'error': 'provided parameters were invalid: error parsing authenticated parameter \"data\": no field named row_data in claims'}", + Exception, + match='provided parameters were invalid: error parsing authenticated parameter "data": no field named row_data in claims', ): await tool.ainvoke({}) @@ -196,7 +198,7 @@ def toolbox(self): @pytest.fixture(scope="function") def get_n_rows_tool(self, toolbox): tool = toolbox.load_tool("get-n-rows") - assert tool._ToolboxTool__async_tool._AsyncToolboxTool__name == "get-n-rows" + assert tool._ToolboxTool__core_tool.__name__ == "get-n-rows" return tool #### Basic e2e tests @@ -213,7 +215,7 @@ def test_load_toolset_specific( toolset = toolbox.load_toolset(toolset_name) assert len(toolset) == expected_length for tool in toolset: - name = tool._ToolboxTool__async_tool._AsyncToolboxTool__name + name = tool._ToolboxTool__core_tool.__name__ assert name in expected_tools def test_aload_toolset_all(self, toolbox): @@ -227,7 +229,7 @@ def test_aload_toolset_all(self, toolbox): "get-row-by-content-auth", ] for tool in toolset: - name = tool._ToolboxTool__async_tool._AsyncToolboxTool__name + name = tool._ToolboxTool__core_tool.__name__ assert name in tool_names @pytest.mark.asyncio @@ -256,11 +258,14 @@ def test_run_tool_wrong_param_type(self, get_n_rows_tool): #### Auth tests def test_run_tool_unauth_with_auth(self, toolbox, auth_token2): """Tests running a tool that doesn't require auth, with auth provided.""" - tool = toolbox.load_tool( - "get-row-by-id", auth_token_getters={"my-test-auth": lambda: auth_token2} - ) - response = tool.invoke({"id": "2"}) - assert "row2" in response + with pytest.raises( + ValueError, + match="Validation failed for tool 'get-row-by-id': unused auth tokens: my-test-auth.", + ): + toolbox.load_tool( + "get-row-by-id", + auth_token_getters={"my-test-auth": lambda: auth_token2}, + ) def test_run_tool_no_auth(self, toolbox): """Tests running a tool requiring auth without providing auth.""" @@ -269,7 +274,7 @@ def test_run_tool_no_auth(self, toolbox): ) with pytest.raises( PermissionError, - match="Tool get-row-by-id-auth requires authentication, but no valid authentication sources are registered. Please register the required sources before use.", + match="One or more of the following authn services are required to invoke this tool: my-test-auth", ): tool.invoke({"id": "2"}) @@ -280,8 +285,8 @@ def test_run_tool_wrong_auth(self, toolbox, auth_token2): ) auth_tool = tool.add_auth_token_getter("my-test-auth", lambda: auth_token2) with pytest.raises( - ToolException, - match="{'status': 'Unauthorized', 'error': 'tool invocation not authorized. Please make sure your specify correct auth headers'}", + Exception, + match="tool invocation not authorized. Please make sure your specify correct auth headers", ): auth_tool.invoke({"id": "2"}) @@ -299,7 +304,7 @@ def test_run_tool_param_auth_no_auth(self, toolbox): tool = toolbox.load_tool("get-row-by-email-auth") with pytest.raises( PermissionError, - match="Parameter\(s\) `email` of tool get-row-by-email-auth require authentication\, but no valid authentication sources are registered\. Please register the required sources before use\.", + match="One or more of the following authn services are required to invoke this tool: my-test-auth", ): tool.invoke({"email": ""}) @@ -321,7 +326,7 @@ def test_run_tool_param_auth_no_field(self, toolbox, auth_token1): auth_token_getters={"my-test-auth": lambda: auth_token1}, ) with pytest.raises( - ToolException, - match="{'status': 'Bad Request', 'error': 'provided parameters were invalid: error parsing authenticated parameter \"data\": no field named row_data in claims'}", + Exception, + match='provided parameters were invalid: error parsing authenticated parameter "data": no field named row_data in claims', ): tool.invoke({}) diff --git a/packages/toolbox-langchain/tests/test_tools.py b/packages/toolbox-langchain/tests/test_tools.py index 751005af..90fddf4b 100644 --- a/packages/toolbox-langchain/tests/test_tools.py +++ b/packages/toolbox-langchain/tests/test_tools.py @@ -12,21 +12,68 @@ # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import Mock +import asyncio +from unittest.mock import AsyncMock, Mock, call, patch import pytest from pydantic import BaseModel +from toolbox_core.protocol import ParameterSchema as CoreParameterSchema +from toolbox_core.sync_tool import ToolboxSyncTool as ToolboxCoreSyncTool +from toolbox_core.tool import ToolboxTool as CoreAsyncTool +from toolbox_core.utils import params_to_pydantic_model -from toolbox_langchain.async_tools import AsyncToolboxTool from toolbox_langchain.tools import ToolboxTool +def assert_pydantic_models_equivalent( + model_cls1: type[BaseModel], model_cls2: type[BaseModel], expected_model_name: str +): + assert issubclass(model_cls1, BaseModel), "model_cls1 is not a Pydantic BaseModel" + assert issubclass(model_cls2, BaseModel), "model_cls2 is not a Pydantic BaseModel" + assert ( + model_cls1.__name__ == expected_model_name + ), f"model_cls1 name mismatch: expected {expected_model_name}, got {model_cls1.__name__}" + assert ( + model_cls2.__name__ == expected_model_name + ), f"model_cls2 name mismatch: expected {expected_model_name}, got {model_cls2.__name__}" + + fields1 = model_cls1.model_fields + fields2 = model_cls2.model_fields + + assert ( + fields1.keys() == fields2.keys() + ), f"Field names mismatch: {fields1.keys()} != {fields2.keys()}" + + for field_name in fields1.keys(): + field_info1 = fields1[field_name] + field_info2 = fields2[field_name] + + assert ( + field_info1.annotation == field_info2.annotation + ), f"Field '{field_name}': Annotation mismatch ({field_info1.annotation} != {field_info2.annotation})" + assert ( + field_info1.description == field_info2.description + ), f"Field '{field_name}': Description mismatch ('{field_info1.description}' != '{field_info2.description}')" + is_required1 = ( + field_info1.is_required() + if hasattr(field_info1, "is_required") + else not field_info1.is_nullable() + ) + is_required2 = ( + field_info2.is_required() + if hasattr(field_info2, "is_required") + else not field_info2.is_nullable() + ) + assert ( + is_required1 == is_required2 + ), f"Field '{field_name}': Required status mismatch ({is_required1} != {is_required2})" + + class TestToolboxTool: @pytest.fixture - def tool_schema(self): + def tool_schema_dict(self): return { "description": "Test Tool Description", - "name": "test_tool", "parameters": [ {"name": "param1", "type": "string", "description": "Param 1"}, {"name": "param2", "type": "integer", "description": "Param 2"}, @@ -34,10 +81,10 @@ def tool_schema(self): } @pytest.fixture - def auth_tool_schema(self): + def auth_tool_schema_dict(self): return { - "description": "Test Tool Description", - "name": "test_tool", + "description": "Test Auth Tool Description", + "authRequired": ["test-auth-source"], "parameters": [ { "name": "param1", @@ -49,190 +96,210 @@ def auth_tool_schema(self): ], } - @pytest.fixture(scope="function") - def mock_async_tool(self, tool_schema): - mock_async_tool = Mock(spec=AsyncToolboxTool) - mock_async_tool.name = "test_tool" - mock_async_tool.description = "test description" - mock_async_tool.args_schema = BaseModel - mock_async_tool._AsyncToolboxTool__name = "test_tool" - mock_async_tool._AsyncToolboxTool__schema = tool_schema - mock_async_tool._AsyncToolboxTool__url = "http://test_url" - mock_async_tool._AsyncToolboxTool__session = Mock() - mock_async_tool._AsyncToolboxTool__auth_token_getters = {} - mock_async_tool._AsyncToolboxTool__bound_params = {} - return mock_async_tool - - @pytest.fixture(scope="function") - def mock_async_auth_tool(self, auth_tool_schema): - mock_async_tool = Mock(spec=AsyncToolboxTool) - mock_async_tool.name = "test_tool" - mock_async_tool.description = "test description" - mock_async_tool.args_schema = BaseModel - mock_async_tool._AsyncToolboxTool__name = "test_tool" - mock_async_tool._AsyncToolboxTool__schema = auth_tool_schema - mock_async_tool._AsyncToolboxTool__url = "http://test_url" - mock_async_tool._AsyncToolboxTool__session = Mock() - mock_async_tool._AsyncToolboxTool__auth_token_getters = {} - mock_async_tool._AsyncToolboxTool__bound_params = {} - return mock_async_tool - @pytest.fixture - def toolbox_tool(self, mock_async_tool): - return ToolboxTool( - async_tool=mock_async_tool, - loop=Mock(), - thread=Mock(), + def mock_core_tool(self, tool_schema_dict): + sync_mock = Mock(spec=ToolboxCoreSyncTool) + + sync_mock.__name__ = "test_tool_name_for_langchain" + sync_mock.__doc__ = tool_schema_dict["description"] + sync_mock._name = "TestToolPydanticModel" + sync_mock._params = [ + CoreParameterSchema(**p) for p in tool_schema_dict["parameters"] + ] + + mock_async_tool_attr = AsyncMock(spec=CoreAsyncTool) + mock_async_tool_attr.return_value = "dummy_internal_async_tool_result" + sync_mock._ToolboxSyncTool__async_tool = mock_async_tool_attr + sync_mock._ToolboxSyncTool__loop = Mock(spec=asyncio.AbstractEventLoop) + sync_mock._ToolboxSyncTool__thread = Mock() + + new_mock_instance_for_methods = Mock(spec=ToolboxCoreSyncTool) + new_mock_instance_for_methods.__name__ = sync_mock.__name__ + new_mock_instance_for_methods.__doc__ = sync_mock.__doc__ + new_mock_instance_for_methods._name = sync_mock._name + new_mock_instance_for_methods._params = sync_mock._params + new_mock_instance_for_methods._ToolboxSyncTool__async_tool = AsyncMock( + spec=CoreAsyncTool + ) + new_mock_instance_for_methods._ToolboxSyncTool__loop = Mock( + spec=asyncio.AbstractEventLoop + ) + new_mock_instance_for_methods._ToolboxSyncTool__thread = Mock() + + sync_mock.add_auth_token_getters = Mock( + return_value=new_mock_instance_for_methods ) + sync_mock.bind_params = Mock(return_value=new_mock_instance_for_methods) + + return sync_mock @pytest.fixture - def auth_toolbox_tool(self, mock_async_auth_tool): - return ToolboxTool( - async_tool=mock_async_auth_tool, - loop=Mock(), - thread=Mock(), + def mock_core_sync_auth_tool(self, auth_tool_schema_dict): + sync_mock = Mock(spec=ToolboxCoreSyncTool) + sync_mock.__name__ = "test_auth_tool_lc_name" + sync_mock.__doc__ = auth_tool_schema_dict["description"] + sync_mock._name = "TestAuthToolPydanticModel" + sync_mock._params = [ + CoreParameterSchema(**p) for p in auth_tool_schema_dict["parameters"] + ] + + mock_async_tool_attr = AsyncMock(spec=CoreAsyncTool) + mock_async_tool_attr.return_value = "dummy_internal_async_auth_tool_result" + sync_mock._ToolboxSyncTool__async_tool = mock_async_tool_attr + sync_mock._ToolboxSyncTool__loop = Mock(spec=asyncio.AbstractEventLoop) + sync_mock._ToolboxSyncTool__thread = Mock() + + new_mock_instance_for_methods = Mock(spec=ToolboxCoreSyncTool) + new_mock_instance_for_methods.__name__ = sync_mock.__name__ + new_mock_instance_for_methods.__doc__ = sync_mock.__doc__ + new_mock_instance_for_methods._name = sync_mock._name + new_mock_instance_for_methods._params = sync_mock._params + new_mock_instance_for_methods._ToolboxSyncTool__async_tool = AsyncMock( + spec=CoreAsyncTool + ) + new_mock_instance_for_methods._ToolboxSyncTool__loop = Mock( + spec=asyncio.AbstractEventLoop ) + new_mock_instance_for_methods._ToolboxSyncTool__thread = Mock() - def test_toolbox_tool_init(self, mock_async_tool): - tool = ToolboxTool( - async_tool=mock_async_tool, - loop=Mock(), - thread=Mock(), + sync_mock.add_auth_token_getters = Mock( + return_value=new_mock_instance_for_methods + ) + sync_mock.bind_params = Mock(return_value=new_mock_instance_for_methods) + return sync_mock + + @pytest.fixture + def toolbox_tool(self, mock_core_tool): + return ToolboxTool(core_tool=mock_core_tool) + + @pytest.fixture + def auth_toolbox_tool(self, mock_core_sync_auth_tool): + return ToolboxTool(core_tool=mock_core_sync_auth_tool) + + def test_toolbox_tool_init(self, mock_core_tool): + tool = ToolboxTool(core_tool=mock_core_tool) + + assert tool.name == mock_core_tool.__name__ + assert tool.description == mock_core_tool.__doc__ + assert tool._ToolboxTool__core_tool == mock_core_tool + + expected_args_schema = params_to_pydantic_model( + mock_core_tool._name, mock_core_tool._params + ) + assert_pydantic_models_equivalent( + tool.args_schema, expected_args_schema, mock_core_tool._name ) - async_tool = tool._ToolboxTool__async_tool - assert async_tool.name == mock_async_tool.name - assert async_tool.description == mock_async_tool.description - assert async_tool.args_schema == mock_async_tool.args_schema @pytest.mark.parametrize( - "params, expected_bound_params", + "params", [ - ({"param1": "bound-value"}, {"param1": "bound-value"}), - ({"param1": lambda: "bound-value"}, {"param1": lambda: "bound-value"}), - ( - {"param1": "bound-value", "param2": 123}, - {"param1": "bound-value", "param2": 123}, - ), + ({"param1": "bound-value"}), + ({"param1": lambda: "bound-value"}), + ({"param1": "bound-value", "param2": 123}), ], ) def test_toolbox_tool_bind_params( self, params, - expected_bound_params, toolbox_tool, - mock_async_tool, + mock_core_tool, ): - mock_async_tool._AsyncToolboxTool__bound_params = expected_bound_params - mock_async_tool.bind_params.return_value = mock_async_tool + returned_core_tool_mock = mock_core_tool.bind_params.return_value + new_langchain_tool = toolbox_tool.bind_params(params) - tool = toolbox_tool.bind_params(params) - mock_async_tool.bind_params.assert_called_once_with(params, True) - assert isinstance(tool, ToolboxTool) + mock_core_tool.bind_params.assert_called_once_with(params) + assert isinstance(new_langchain_tool, ToolboxTool) + assert new_langchain_tool._ToolboxTool__core_tool == returned_core_tool_mock - for key, value in expected_bound_params.items(): - async_tool_bound_param_val = ( - tool._ToolboxTool__async_tool._AsyncToolboxTool__bound_params[key] - ) - if callable(value): - assert value() == async_tool_bound_param_val() - else: - assert value == async_tool_bound_param_val + def test_toolbox_tool_bind_param(self, toolbox_tool, mock_core_tool): + returned_core_tool_mock = mock_core_tool.bind_params.return_value + new_langchain_tool = toolbox_tool.bind_param("param1", "bound-value") - def test_toolbox_tool_bind_param(self, mock_async_tool, toolbox_tool): - expected_bound_param = {"param1": "bound-value"} - mock_async_tool._AsyncToolboxTool__bound_params = expected_bound_param - mock_async_tool.bind_param.return_value = mock_async_tool - - tool = toolbox_tool.bind_param("param1", "bound-value") - mock_async_tool.bind_param.assert_called_once_with( - "param1", "bound-value", True - ) - - assert ( - tool._ToolboxTool__async_tool._AsyncToolboxTool__bound_params - == expected_bound_param - ) - assert isinstance(tool, ToolboxTool) + mock_core_tool.bind_params.assert_called_once_with({"param1": "bound-value"}) + assert isinstance(new_langchain_tool, ToolboxTool) + assert new_langchain_tool._ToolboxTool__core_tool == returned_core_tool_mock @pytest.mark.parametrize( - "auth_token_getters, expected_auth_token_getters", + "auth_token_getters", [ + ({"test-auth-source": lambda: "test-token"}), ( - {"test-auth-source": lambda: "test-token"}, - {"test-auth-source": lambda: "test-token"}, - ), - ( - { - "test-auth-source": lambda: "test-token", - "another-auth-source": lambda: "another-token", - }, { "test-auth-source": lambda: "test-token", "another-auth-source": lambda: "another-token", - }, + } ), ], ) def test_toolbox_tool_add_auth_token_getters( self, auth_token_getters, - expected_auth_token_getters, - mock_async_auth_tool, auth_toolbox_tool, + mock_core_sync_auth_tool, ): - auth_toolbox_tool._ToolboxTool__async_tool._AsyncToolboxTool__auth_token_getters = ( - expected_auth_token_getters + returned_core_tool_mock = ( + mock_core_sync_auth_tool.add_auth_token_getters.return_value ) - auth_toolbox_tool._ToolboxTool__async_tool.add_auth_token_getters.return_value = ( - mock_async_auth_tool + new_langchain_tool = auth_toolbox_tool.add_auth_token_getters( + auth_token_getters ) - tool = auth_toolbox_tool.add_auth_token_getters(auth_token_getters) - mock_async_auth_tool.add_auth_token_getters.assert_called_once_with( - auth_token_getters, True + mock_core_sync_auth_tool.add_auth_token_getters.assert_called_once_with( + auth_token_getters ) - for source, getter in expected_auth_token_getters.items(): - assert ( - tool._ToolboxTool__async_tool._AsyncToolboxTool__auth_token_getters[ - source - ]() - == getter() - ) - assert isinstance(tool, ToolboxTool) + assert isinstance(new_langchain_tool, ToolboxTool) + assert new_langchain_tool._ToolboxTool__core_tool == returned_core_tool_mock def test_toolbox_tool_add_auth_token_getter( - self, mock_async_auth_tool, auth_toolbox_tool + self, auth_toolbox_tool, mock_core_sync_auth_tool ): get_id_token = lambda: "test-token" - expected_auth_token_getters = {"test-auth-source": get_id_token} - auth_toolbox_tool._ToolboxTool__async_tool._AsyncToolboxTool__auth_token_getters = ( - expected_auth_token_getters - ) - auth_toolbox_tool._ToolboxTool__async_tool.add_auth_token_getter.return_value = ( - mock_async_auth_tool + returned_core_tool_mock = ( + mock_core_sync_auth_tool.add_auth_token_getters.return_value ) - tool = auth_toolbox_tool.add_auth_token_getter("test-auth-source", get_id_token) - mock_async_auth_tool.add_auth_token_getter.assert_called_once_with( - "test-auth-source", get_id_token, True + new_langchain_tool = auth_toolbox_tool.add_auth_token_getter( + "test-auth-source", get_id_token ) - assert ( - tool._ToolboxTool__async_tool._AsyncToolboxTool__auth_token_getters[ - "test-auth-source" - ]() - == "test-token" + mock_core_sync_auth_tool.add_auth_token_getters.assert_called_once_with( + {"test-auth-source": get_id_token} ) - assert isinstance(tool, ToolboxTool) + assert isinstance(new_langchain_tool, ToolboxTool) + assert new_langchain_tool._ToolboxTool__core_tool == returned_core_tool_mock - def test_toolbox_tool_validate_auth_strict(self, auth_toolbox_tool): - auth_toolbox_tool._ToolboxTool__async_tool._arun = Mock( - side_effect=PermissionError( - "Parameter(s) `param1` of tool test_tool require authentication" - ) - ) - with pytest.raises(PermissionError) as e: - auth_toolbox_tool._run() - assert "Parameter(s) `param1` of tool test_tool require authentication" in str( - e.value + def test_toolbox_tool_run(self, toolbox_tool, mock_core_tool): + kwargs_to_run = {"param1": "run_value1", "param2": 100} + expected_result = "sync_run_output" + mock_core_tool.return_value = expected_result + + result = toolbox_tool._run(**kwargs_to_run) + + assert result == expected_result + assert mock_core_tool.call_count == 1 + assert mock_core_tool.call_args == call(**kwargs_to_run) + + @pytest.mark.asyncio + @patch("toolbox_langchain.tools.to_thread", new_callable=AsyncMock) + async def test_toolbox_tool_arun( + self, mock_to_thread_in_tools, toolbox_tool, mock_core_tool + ): + kwargs_to_run = {"param1": "arun_value1", "param2": 200} + expected_result = "async_run_output" + + mock_core_tool.return_value = expected_result + + async def to_thread_side_effect(func, *args, **kwargs_for_func): + return func(**kwargs_for_func) + + mock_to_thread_in_tools.side_effect = to_thread_side_effect + + result = await toolbox_tool._arun(**kwargs_to_run) + + assert result == expected_result + mock_to_thread_in_tools.assert_awaited_once_with( + mock_core_tool, **kwargs_to_run ) + + assert mock_core_tool.call_count == 1 + assert mock_core_tool.call_args == call(**kwargs_to_run) diff --git a/packages/toolbox-langchain/tests/test_utils.py b/packages/toolbox-langchain/tests/test_utils.py deleted file mode 100644 index 488a6aef..00000000 --- a/packages/toolbox-langchain/tests/test_utils.py +++ /dev/null @@ -1,290 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import asyncio -import json -import re -import warnings -from unittest.mock import AsyncMock, Mock, patch - -import aiohttp -import pytest -from pydantic import BaseModel - -from toolbox_langchain.utils import ( - ParameterSchema, - _get_auth_headers, - _invoke_tool, - _load_manifest, - _parse_type, - _schema_to_model, -) - -URL = "https://my-toolbox.com/test" -MOCK_MANIFEST = """ -{ - "serverVersion": "0.0.1", - "tools": { - "test_tool": { - "summary": "Test Tool", - "description": "This is a test tool.", - "parameters": [ - { - "name": "param1", - "type": "string", - "description": "Parameter 1" - }, - { - "name": "param2", - "type": "integer", - "description": "Parameter 2" - } - ] - } - } -} -""" - - -class TestUtils: - @pytest.fixture(scope="module") - def mock_manifest(self): - return aiohttp.ClientResponse( - method="GET", - url=aiohttp.client.URL(URL), - writer=None, - continue100=None, - timer=None, - request_info=None, - traces=None, - session=None, - loop=asyncio.get_event_loop(), - ) - - @pytest.mark.asyncio - @patch("aiohttp.ClientSession.get") - async def test_load_manifest(self, mock_get, mock_manifest): - mock_manifest.raise_for_status = Mock() - mock_manifest.text = AsyncMock(return_value=MOCK_MANIFEST) - - mock_get.return_value = mock_manifest - session = aiohttp.ClientSession() - manifest = await _load_manifest(URL, session) - await session.close() - mock_get.assert_called_once_with(URL) - - assert manifest.serverVersion == "0.0.1" - assert len(manifest.tools) == 1 - - tool = manifest.tools["test_tool"] - assert tool.description == "This is a test tool." - assert tool.parameters == [ - ParameterSchema(name="param1", type="string", description="Parameter 1"), - ParameterSchema(name="param2", type="integer", description="Parameter 2"), - ] - - @pytest.mark.asyncio - @patch("aiohttp.ClientSession.get") - async def test_load_manifest_invalid_json(self, mock_get, mock_manifest): - mock_manifest.raise_for_status = Mock() - mock_manifest.text = AsyncMock(return_value="{ invalid manifest") - mock_get.return_value = mock_manifest - - with pytest.raises(Exception) as e: - session = aiohttp.ClientSession() - await _load_manifest(URL, session) - - mock_get.assert_called_once_with(URL) - assert isinstance(e.value, json.JSONDecodeError) - assert ( - str(e.value) - == "Failed to parse JSON from https://my-toolbox.com/test: Expecting property name enclosed in double quotes: line 1 column 3 (char 2): line 1 column 3 (char 2)" - ) - - @pytest.mark.asyncio - @patch("aiohttp.ClientSession.get") - async def test_load_manifest_invalid_manifest(self, mock_get, mock_manifest): - mock_manifest.raise_for_status = Mock() - mock_manifest.text = AsyncMock(return_value='{ "something": "invalid" }') - mock_get.return_value = mock_manifest - - with pytest.raises(Exception) as e: - session = aiohttp.ClientSession() - await _load_manifest(URL, session) - - mock_get.assert_called_once_with(URL) - assert isinstance(e.value, ValueError) - assert re.match( - r"Invalid JSON data from https://my-toolbox.com/test: 2 validation errors for ManifestSchema\nserverVersion\n Field required \[type=missing, input_value={'something': 'invalid'}, input_type=dict]\n For further information visit https://errors.pydantic.dev/\d+\.\d+/v/missing\ntools\n Field required \[type=missing, input_value={'something': 'invalid'}, input_type=dict]\n For further information visit https://errors.pydantic.dev/\d+\.\d+/v/missing", - str(e.value), - ) - - @pytest.mark.asyncio - @patch("aiohttp.ClientSession.get") - async def test_load_manifest_api_error(self, mock_get, mock_manifest): - error = aiohttp.ClientError("Simulated HTTP Error") - mock_manifest.raise_for_status = Mock() - mock_manifest.text = AsyncMock(side_effect=error) - mock_get.return_value = mock_manifest - - with pytest.raises(aiohttp.ClientError) as exc_info: - session = aiohttp.ClientSession() - await _load_manifest(URL, session) - mock_get.assert_called_once_with(URL) - assert exc_info.value == error - - def test_schema_to_model(self): - schema = [ - ParameterSchema(name="param1", type="string", description="Parameter 1"), - ParameterSchema(name="param2", type="integer", description="Parameter 2"), - ] - model = _schema_to_model("TestModel", schema) - assert issubclass(model, BaseModel) - - assert model.model_fields["param1"].annotation == str - assert model.model_fields["param1"].description == "Parameter 1" - assert model.model_fields["param2"].annotation == int - assert model.model_fields["param2"].description == "Parameter 2" - - def test_schema_to_model_empty(self): - model = _schema_to_model("TestModel", []) - assert issubclass(model, BaseModel) - assert len(model.model_fields) == 0 - - @pytest.mark.parametrize( - "parameter_schema, expected_type", - [ - (ParameterSchema(name="foo", description="bar", type="string"), str), - (ParameterSchema(name="foo", description="bar", type="integer"), int), - (ParameterSchema(name="foo", description="bar", type="float"), float), - (ParameterSchema(name="foo", description="bar", type="boolean"), bool), - ( - ParameterSchema( - name="foo", - description="bar", - type="array", - items=ParameterSchema( - name="foo", description="bar", type="integer" - ), - ), - list[int], - ), - ], - ) - def test_parse_type(self, parameter_schema, expected_type): - assert _parse_type(parameter_schema) == expected_type - - @pytest.mark.parametrize( - "fail_parameter_schema", - [ - (ParameterSchema(name="foo", description="bar", type="invalid")), - ( - ParameterSchema( - name="foo", - description="bar", - type="array", - items=ParameterSchema( - name="foo", description="bar", type="invalid" - ), - ) - ), - ], - ) - def test_parse_type_invalid(self, fail_parameter_schema): - with pytest.raises(ValueError): - _parse_type(fail_parameter_schema) - - @pytest.mark.asyncio - @patch("aiohttp.ClientSession.post") - async def test_invoke_tool(self, mock_post): - mock_response = Mock() - mock_response.raise_for_status = Mock() - mock_response.json = AsyncMock(return_value={"key": "value"}) - mock_post.return_value.__aenter__.return_value = mock_response - - result = await _invoke_tool( - "http://localhost:5000", - aiohttp.ClientSession(), - "tool_name", - {"input": "data"}, - {}, - ) - - mock_post.assert_called_once_with( - "http://localhost:5000/api/tool/tool_name/invoke", - json={"input": "data"}, - headers={}, - ) - assert result == {"key": "value"} - - @pytest.mark.asyncio - @patch("aiohttp.ClientSession.post") - async def test_invoke_tool_unsecure_with_auth(self, mock_post): - mock_response = Mock() - mock_response.raise_for_status = Mock() - mock_response.json = AsyncMock(return_value={"key": "value"}) - mock_post.return_value.__aenter__.return_value = mock_response - - with pytest.warns( - UserWarning, - match="Sending ID token over HTTP. User data may be exposed. Use HTTPS for secure communication.", - ): - result = await _invoke_tool( - "http://localhost:5000", - aiohttp.ClientSession(), - "tool_name", - {"input": "data"}, - {"my_test_auth": lambda: "fake_id_token"}, - ) - - mock_post.assert_called_once_with( - "http://localhost:5000/api/tool/tool_name/invoke", - json={"input": "data"}, - headers={"my_test_auth_token": "fake_id_token"}, - ) - assert result == {"key": "value"} - - @pytest.mark.asyncio - @patch("aiohttp.ClientSession.post") - async def test_invoke_tool_secure_with_auth(self, mock_post): - session = aiohttp.ClientSession() - mock_response = Mock() - mock_response.raise_for_status = Mock() - mock_response.json = AsyncMock(return_value={"key": "value"}) - mock_post.return_value.__aenter__.return_value = mock_response - - with warnings.catch_warnings(): - warnings.simplefilter("error") - result = await _invoke_tool( - "https://localhost:5000", - session, - "tool_name", - {"input": "data"}, - {"my_test_auth": lambda: "fake_id_token"}, - ) - - mock_post.assert_called_once_with( - "https://localhost:5000/api/tool/tool_name/invoke", - json={"input": "data"}, - headers={"my_test_auth_token": "fake_id_token"}, - ) - assert result == {"key": "value"} - - def test_get_auth_headers_deprecation_warning(self): - """Test _get_auth_headers deprecation warning.""" - with pytest.warns( - DeprecationWarning, - match=r"Call to deprecated function \(or staticmethod\) _get_auth_headers\. \(Please use `_get_auth_tokens` instead\.\)$", - ): - _get_auth_headers({"auth_source1": lambda: "test_token"})