From 77bfdb2f9a0bb30a9045b08753ec488384cba26b Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Thu, 8 May 2025 14:05:36 +0530 Subject: [PATCH 01/53] fix(toolbox-langchain)!: Base toolbox-langchain over toolbox-core --- .../src/toolbox_langchain/async_client.py | 52 ++-- .../src/toolbox_langchain/async_tools.py | 289 ++--------------- .../src/toolbox_langchain/utils.py | 268 ---------------- .../tests/test_async_client.py | 2 +- .../toolbox-langchain/tests/test_utils.py | 290 ------------------ 5 files changed, 37 insertions(+), 864 deletions(-) delete mode 100644 packages/toolbox-langchain/src/toolbox_langchain/utils.py delete mode 100644 packages/toolbox-langchain/tests/test_utils.py diff --git a/packages/toolbox-langchain/src/toolbox_langchain/async_client.py b/packages/toolbox-langchain/src/toolbox_langchain/async_client.py index c58bbfdf..d7e39814 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/async_client.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/async_client.py @@ -18,7 +18,8 @@ from aiohttp import ClientSession from .tools import AsyncToolboxTool -from .utils import ManifestSchema, _load_manifest + +from toolbox_core.client import ToolboxClient as ToolboxCoreClient # This class is an internal implementation detail and is not exposed to the @@ -38,8 +39,7 @@ def __init__( url: The base URL of the Toolbox service. session: An HTTP client session. """ - self.__url = url - self.__session = session + self.__core_client = ToolboxCoreClient(url=url, session=session) async def aload_tool( self, @@ -48,7 +48,6 @@ async def aload_tool( auth_tokens: Optional[dict[str, Callable[[], str]]] = None, auth_headers: Optional[dict[str, Callable[[], str]]] = None, bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, - strict: bool = True, ) -> AsyncToolboxTool: """ Loads the tool with the given tool name from the Toolbox service. @@ -61,9 +60,6 @@ async def aload_tool( auth_headers: Deprecated. Use `auth_token_getters` instead. bound_params: An optional mapping of parameter names to their bound values. - strict: If True, raises a ValueError if any of the given bound - parameters are missing from the schema or require - authentication. If False, only issues a warning. Returns: A tool loaded from the Toolbox. @@ -94,18 +90,12 @@ async def aload_tool( ) auth_token_getters = auth_headers - url = f"{self.__url}/api/tool/{tool_name}" - manifest: ManifestSchema = await _load_manifest(url, self.__session) - - return AsyncToolboxTool( - tool_name, - manifest.tools[tool_name], - self.__url, - self.__session, - auth_token_getters, - bound_params, - strict, + core_tool = await self.__core_client.load_tool( + name=tool_name, + auth_token_getters=auth_token_getters, + bound_params=bound_params ) + return AsyncToolboxTool(core_tool=core_tool) async def aload_toolset( self, @@ -162,22 +152,16 @@ async def aload_toolset( ) auth_token_getters = auth_headers - url = f"{self.__url}/api/toolset/{toolset_name or ''}" - manifest: ManifestSchema = await _load_manifest(url, self.__session) - tools: list[AsyncToolboxTool] = [] - - for tool_name, tool_schema in manifest.tools.items(): - tools.append( - AsyncToolboxTool( - tool_name, - tool_schema, - self.__url, - self.__session, - auth_token_getters, - bound_params, - strict, - ) - ) + core_tools = await self.__core_client.load_toolset( + name=toolset_name, + auth_token_getters=auth_token_getters, + bound_params=bound_params, + strict=strict + ) + + tools = [] + for core_tool in core_tools: + tools.append(AsyncToolboxTool(core_tool_instance=core_tool)) return tools def load_tool( diff --git a/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py b/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py index 40e21ee6..1a7b3dd1 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py @@ -16,16 +16,9 @@ from typing import Any, Callable, TypeVar, Union from warnings import warn -from aiohttp import ClientSession from langchain_core.tools import BaseTool +from toolbox_core.tool import ToolboxTool as ToolboxCoreTool -from .utils import ( - ToolSchema, - _find_auth_params, - _find_bound_params, - _invoke_tool, - _schema_to_model, -) T = TypeVar("T") @@ -41,13 +34,7 @@ class AsyncToolboxTool(BaseTool): def __init__( self, - name: str, - schema: ToolSchema, - url: str, - session: ClientSession, - auth_token_getters: dict[str, Callable[[], str]] = {}, - bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, - strict: bool = True, + core_tool: ToolboxCoreTool, ) -> None: """ Initializes an AsyncToolboxTool instance. @@ -61,89 +48,19 @@ def __init__( functions that retrieve ID tokens. bound_params: A mapping of parameter names to their bound values. - strict: If True, raises a ValueError if any of the given bound - parameters is missing from the schema or requires - authentication. If False, only issues a warning. """ - # If the schema is not already a ToolSchema instance, we create one from - # its attributes. This allows flexibility in how the schema is provided, - # accepting both a ToolSchema object and a dictionary of schema - # attributes. - if not isinstance(schema, ToolSchema): - schema = ToolSchema(**schema) - - auth_params, non_auth_params = _find_auth_params(schema.parameters) - non_auth_bound_params, non_auth_non_bound_params = _find_bound_params( - non_auth_params, list(bound_params) - ) - - # Check if the user is trying to bind a param that is authenticated or - # is missing from the given schema. - auth_bound_params: list[str] = [] - missing_bound_params: list[str] = [] - for bound_param in bound_params: - if bound_param in [param.name for param in auth_params]: - auth_bound_params.append(bound_param) - elif bound_param not in [param.name for param in non_auth_params]: - missing_bound_params.append(bound_param) - - # Create error messages for any params that are found to be - # authenticated or missing. - messages: list[str] = [] - if auth_bound_params: - messages.append( - f"Parameter(s) {', '.join(auth_bound_params)} already authenticated and cannot be bound." - ) - if missing_bound_params: - messages.append( - f"Parameter(s) {', '.join(missing_bound_params)} missing and cannot be bound." - ) - - # Join any error messages and raise them as an error or warning, - # depending on the value of the strict flag. - if messages: - message = "\n\n".join(messages) - if strict: - raise ValueError(message) - warn(message) - - # Bind values for parameters present in the schema that don't require - # authentication. - bound_params = { - param_name: param_value - for param_name, param_value in bound_params.items() - if param_name in [param.name for param in non_auth_bound_params] - } - - # Update the tools schema to validate only the presence of parameters - # that neither require authentication nor are bound. - schema.parameters = non_auth_non_bound_params - - # Due to how pydantic works, we must initialize the underlying - # BaseTool class before assigning values to member variables. + self.__core_tool = core_tool super().__init__( - name=name, - description=schema.description, - args_schema=_schema_to_model(model_name=name, schema=schema.parameters), + name=self.__core_tool.__name__, + description=self.__core_tool.__doc__, + args_schema=self.__core_tool._ToolboxTool__pydantic_model, ) - self.__name = name - self.__schema = schema - self.__url = url - self.__session = session - self.__auth_token_getters = auth_token_getters - self.__auth_params = auth_params - self.__bound_params = bound_params - - # Warn users about any missing authentication so they can add it before - # tool invocation. - self.__validate_auth(strict=False) - - def _run(self, **kwargs: Any) -> dict[str, Any]: + def _run(self, **kwargs: Any) -> str: raise NotImplementedError("Synchronous methods not supported by async tools.") - async def _arun(self, **kwargs: Any) -> dict[str, Any]: + async def _arun(self, **kwargs: Any) -> str: """ The coroutine that invokes the tool with the given arguments. @@ -154,140 +71,12 @@ async def _arun(self, **kwargs: Any) -> dict[str, Any]: A dictionary containing the parsed JSON response from the tool invocation. """ + return await self.__core_tool(**kwargs) - # If the tool had parameters that require authentication, then right - # before invoking that tool, we check whether all these required - # authentication sources have been registered or not. - self.__validate_auth() - - # Evaluate dynamic parameter values if any - evaluated_params = {} - for param_name, param_value in self.__bound_params.items(): - if callable(param_value): - evaluated_params[param_name] = param_value() - else: - evaluated_params[param_name] = param_value - - # Merge bound parameters with the provided arguments - kwargs.update(evaluated_params) - - return await _invoke_tool( - self.__url, self.__session, self.__name, kwargs, self.__auth_token_getters - ) - - def __validate_auth(self, strict: bool = True) -> None: - """ - Checks if a tool meets the authentication requirements. - - A tool is considered authenticated if all of its parameters meet at - least one of the following conditions: - * The parameter has at least one registered authentication source. - * The parameter requires no authentication. - - Args: - strict: If True, raises a PermissionError if any required - authentication sources are not registered. If False, only issues - a warning. - - Raises: - PermissionError: If strict is True and any required authentication - sources are not registered. - """ - is_authenticated: bool = not self.__schema.authRequired - params_missing_auth: list[str] = [] - - # Check tool for at least 1 required auth source - for src in self.__schema.authRequired: - if src in self.__auth_token_getters: - is_authenticated = True - break - - # Check each parameter for at least 1 required auth source - for param in self.__auth_params: - if not param.authSources: - raise ValueError("Auth sources cannot be None.") - has_auth = False - for src in param.authSources: - - # Find first auth source that is specified - if src in self.__auth_token_getters: - has_auth = True - break - if not has_auth: - params_missing_auth.append(param.name) - - messages: list[str] = [] - - if not is_authenticated: - messages.append( - f"Tool {self.__name} requires authentication, but no valid authentication sources are registered. Please register the required sources before use." - ) - - if params_missing_auth: - messages.append( - f"Parameter(s) `{', '.join(params_missing_auth)}` of tool {self.__name} require authentication, but no valid authentication sources are registered. Please register the required sources before use." - ) - - if messages: - message = "\n\n".join(messages) - if strict: - raise PermissionError(message) - warn(message) - - def __create_copy( - self, - *, - auth_token_getters: dict[str, Callable[[], str]] = {}, - bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, - strict: bool, - ) -> "AsyncToolboxTool": - """ - Creates a copy of the current AsyncToolboxTool instance, allowing for - modification of auth tokens and bound params. - - This method enables the creation of new tool instances with inherited - properties from the current instance, while optionally updating the auth - tokens and bound params. This is useful for creating variations of the - tool with additional auth tokens or bound params without modifying the - original instance, ensuring immutability. - - Args: - auth_token_getters: A dictionary of auth source names to functions - that retrieve ID tokens. These tokens will be merged with the - existing auth tokens. - bound_params: A dictionary of parameter names to their - bound values or functions to retrieve the values. These params - will be merged with the existing bound params. - strict: If True, raises a ValueError if any of the given bound - parameters is missing from the schema or requires - authentication. If False, only issues a warning. - - Returns: - A new AsyncToolboxTool instance that is a deep copy of the current - instance, with added auth tokens or bound params. - """ - new_schema = deepcopy(self.__schema) - - # Reconstruct the complete parameter schema by merging the auth - # parameters back with the non-auth parameters. This is necessary to - # accurately validate the new combination of auth tokens and bound - # params in the constructor of the new AsyncToolboxTool instance, ensuring - # that any overlaps or conflicts are correctly identified and reported - # as errors or warnings, depending on the given `strict` flag. - new_schema.parameters += self.__auth_params - return AsyncToolboxTool( - name=self.__name, - schema=new_schema, - url=self.__url, - session=self.__session, - auth_token_getters={**self.__auth_token_getters, **auth_token_getters}, - bound_params={**self.__bound_params, **bound_params}, - strict=strict, - ) def add_auth_token_getters( - self, auth_token_getters: dict[str, Callable[[], str]], strict: bool = True + self, auth_token_getters: dict[str, Callable[[], str]] ) -> "AsyncToolboxTool": """ Registers functions to retrieve ID tokens for the corresponding @@ -296,8 +85,6 @@ def add_auth_token_getters( Args: auth_token_getters: A dictionary of authentication source names to the functions that return corresponding ID token getters. - strict: If True, a ValueError is raised if any of the provided auth - parameters is already bound. If False, only a warning is issued. Returns: A new AsyncToolboxTool instance that is a deep copy of the current @@ -306,26 +93,13 @@ def add_auth_token_getters( Raises: ValueError: If any of the provided auth parameters is already registered. - ValueError: If any of the provided auth parameters is already bound - and strict is True. """ - - # Check if the authentication source is already registered. - dupe_tokens: list[str] = [] - for auth_token, _ in auth_token_getters.items(): - if auth_token in self.__auth_token_getters: - dupe_tokens.append(auth_token) - - if dupe_tokens: - raise ValueError( - f"Authentication source(s) `{', '.join(dupe_tokens)}` already registered in tool `{self.__name}`." - ) - - return self.__create_copy(auth_token_getters=auth_token_getters, strict=strict) + new_core_tool = self.__core_tool.add_auth_token_getters(auth_token_getters) + return AsyncToolboxTool(core_tool=new_core_tool) def add_auth_token_getter( - self, auth_source: str, get_id_token: Callable[[], str], strict: bool = True + self, auth_source: str, get_id_token: Callable[[], str] ) -> "AsyncToolboxTool": """ Registers a function to retrieve an ID token for a given authentication @@ -334,8 +108,6 @@ def add_auth_token_getter( Args: auth_source: The name of the authentication source. get_id_token: A function that returns the ID token. - strict: If True, a ValueError is raised if the provided auth - parameter is already bound. If False, only a warning is issued. Returns: A new ToolboxTool instance that is a deep copy of the current @@ -343,15 +115,13 @@ def add_auth_token_getter( Raises: ValueError: If the provided auth parameter is already registered. - ValueError: If the provided auth parameter is already bound and - strict is True. + """ - return self.add_auth_token_getters({auth_source: get_id_token}, strict=strict) + return self.add_auth_token_getters({auth_source: get_id_token}) def bind_params( self, bound_params: dict[str, Union[Any, Callable[[], Any]]], - strict: bool = True, ) -> "AsyncToolboxTool": """ Registers values or functions to retrieve the value for the @@ -360,9 +130,6 @@ def bind_params( Args: bound_params: A dictionary of the bound parameter name to the value or function of the bound value. - strict: If True, a ValueError is raised if any of the provided bound - params is not defined in the tool's schema, or requires - authentication. If False, only a warning is issued. Returns: A new AsyncToolboxTool instance that is a deep copy of the current @@ -370,29 +137,14 @@ def bind_params( Raises: ValueError: If any of the provided bound params is already bound. - ValueError: if any of the provided bound params is not defined in - the tool's schema, or requires authentication, and strict is - True. """ - - # Check if the parameter is already bound. - dupe_params: list[str] = [] - for param_name, _ in bound_params.items(): - if param_name in self.__bound_params: - dupe_params.append(param_name) - - if dupe_params: - raise ValueError( - f"Parameter(s) `{', '.join(dupe_params)}` already bound in tool `{self.__name}`." - ) - - return self.__create_copy(bound_params=bound_params, strict=strict) + new_core_tool = self.__core_tool.bind_params(bound_params) + return AsyncToolboxTool(core_tool=new_core_tool) def bind_param( self, param_name: str, param_value: Union[Any, Callable[[], Any]], - strict: bool = True, ) -> "AsyncToolboxTool": """ Registers a value or a function to retrieve the value for a given bound @@ -402,9 +154,6 @@ def bind_param( param_name: The name of the bound parameter. param_value: The value of the bound parameter, or a callable that returns the value. - strict: If True, a ValueError is raised if the provided bound param - is not defined in the tool's schema, or requires authentication. - If False, only a warning is issued. Returns: A new ToolboxTool instance that is a deep copy of the current @@ -412,7 +161,5 @@ def bind_param( Raises: ValueError: If the provided bound param is already bound. - ValueError: if the provided bound param is not defined in the tool's - schema, or requires authentication, and strict is True. """ - return self.bind_params({param_name: param_value}, strict) + return self.bind_params({param_name: param_value}) diff --git a/packages/toolbox-langchain/src/toolbox_langchain/utils.py b/packages/toolbox-langchain/src/toolbox_langchain/utils.py deleted file mode 100644 index 985c7bfe..00000000 --- a/packages/toolbox-langchain/src/toolbox_langchain/utils.py +++ /dev/null @@ -1,268 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -from typing import Any, Callable, Optional, Type, cast -from warnings import warn - -from aiohttp import ClientSession -from deprecated import deprecated -from langchain_core.tools import ToolException -from pydantic import BaseModel, Field, create_model - - -class ParameterSchema(BaseModel): - """ - Schema for a tool parameter. - """ - - name: str - type: str - description: str - authSources: Optional[list[str]] = None - items: Optional["ParameterSchema"] = None - - -class ToolSchema(BaseModel): - """ - Schema for a tool. - """ - - description: str - parameters: list[ParameterSchema] - authRequired: list[str] = [] - - -class ManifestSchema(BaseModel): - """ - Schema for the Toolbox manifest. - """ - - serverVersion: str - tools: dict[str, ToolSchema] - - -async def _load_manifest(url: str, session: ClientSession) -> ManifestSchema: - """ - Asynchronously fetches and parses the JSON manifest schema from the given - URL. - - Args: - url: The URL to fetch the JSON from. - session: The HTTP client session. - - Returns: - The parsed Toolbox manifest. - - Raises: - json.JSONDecodeError: If the response is not valid JSON. - ValueError: If the response is not a valid manifest. - """ - async with session.get(url) as response: - # TODO: Remove as it masks error messages. - response.raise_for_status() - try: - # TODO: Simply use response.json() - parsed_json = json.loads(await response.text()) - except json.JSONDecodeError as e: - raise json.JSONDecodeError( - f"Failed to parse JSON from {url}: {e}", e.doc, e.pos - ) from e - try: - return ManifestSchema(**parsed_json) - except ValueError as e: - raise ValueError(f"Invalid JSON data from {url}: {e}") from e - - -def _schema_to_model(model_name: str, schema: list[ParameterSchema]) -> Type[BaseModel]: - """ - Converts the given manifest schema to a Pydantic BaseModel class. - - Args: - model_name: The name of the model to create. - schema: The schema to convert. - - Returns: - A Pydantic BaseModel class. - """ - field_definitions = {} - for field in schema: - field_definitions[field.name] = cast( - Any, - ( - _parse_type(field), - Field(description=field.description), - ), - ) - - return create_model(model_name, **field_definitions) - - -def _parse_type(schema_: ParameterSchema) -> Any: - """ - Converts a schema type to a JSON type. - - Args: - schema_: The ParameterSchema to convert. - - Returns: - A valid JSON type. - - Raises: - ValueError: If the given type is not supported. - """ - type_ = schema_.type - - if type_ == "string": - return str - elif type_ == "integer": - return int - elif type_ == "float": - return float - elif type_ == "boolean": - return bool - elif type_ == "array": - if isinstance(schema_, ParameterSchema) and schema_.items: - return list[_parse_type(schema_.items)] # type: ignore - else: - raise ValueError(f"Schema missing field items") - else: - raise ValueError(f"Unsupported schema type: {type_}") - - -@deprecated("Please use `_get_auth_tokens` instead.") -def _get_auth_headers(id_token_getters: dict[str, Callable[[], str]]) -> dict[str, str]: - """ - Deprecated. Use `_get_auth_tokens` instead. - """ - return _get_auth_tokens(id_token_getters) - - -def _get_auth_tokens(id_token_getters: dict[str, Callable[[], str]]) -> dict[str, str]: - """ - Gets ID tokens for the given auth sources in the getters map and returns - tokens to be included in tool invocation. - - Args: - id_token_getters: A dict that maps auth source names to the functions - that return its ID token. - - Returns: - A dictionary of tokens to be included in the tool invocation. - """ - auth_tokens = {} - for auth_source, get_id_token in id_token_getters.items(): - auth_tokens[f"{auth_source}_token"] = get_id_token() - return auth_tokens - - -async def _invoke_tool( - url: str, - session: ClientSession, - tool_name: str, - data: dict, - id_token_getters: dict[str, Callable[[], str]], -) -> dict: - """ - Asynchronously makes an API call to the Toolbox service to invoke a tool. - - Args: - url: The base URL of the Toolbox service. - session: The HTTP client session. - tool_name: The name of the tool to invoke. - data: The input data for the tool. - id_token_getters: A dict that maps auth source names to the functions - that return its ID token. - - Returns: - A dictionary containing the parsed JSON response from the tool - invocation. - - Raises: - ToolException: If the Toolbox service returns an error. - """ - url = f"{url}/api/tool/{tool_name}/invoke" - auth_tokens = _get_auth_tokens(id_token_getters) - - # ID tokens contain sensitive user information (claims). Transmitting these - # over HTTP exposes the data to interception and unauthorized access. Always - # use HTTPS to ensure secure communication and protect user privacy. - if auth_tokens and not url.startswith("https://"): - warn( - "Sending ID token over HTTP. User data may be exposed. Use HTTPS for secure communication." - ) - - async with session.post( - url, - json=data, - headers=auth_tokens, - ) as response: - ret = await response.json() - if "error" in ret: - raise ToolException(ret) - return ret.get("result", ret) - - -def _find_auth_params( - params: list[ParameterSchema], -) -> tuple[list[ParameterSchema], list[ParameterSchema]]: - """ - Separates parameters into those that are authenticated and those that are not. - - Args: - params: A list of ParameterSchema objects. - - Returns: - A tuple containing two lists: - - auth_params: A list of ParameterSchema objects that require authentication. - - non_auth_params: A list of ParameterSchema objects that do not require authentication. - """ - _auth_params: list[ParameterSchema] = [] - _non_auth_params: list[ParameterSchema] = [] - - for param in params: - if param.authSources: - _auth_params.append(param) - else: - _non_auth_params.append(param) - - return (_auth_params, _non_auth_params) - - -def _find_bound_params( - params: list[ParameterSchema], bound_params: list[str] -) -> tuple[list[ParameterSchema], list[ParameterSchema]]: - """ - Separates parameters into those that are bound and those that are not. - - Args: - params: A list of ParameterSchema objects. - bound_params: A list of parameter names that are bound. - - Returns: - A tuple containing two lists: - - bound_params: A list of ParameterSchema objects whose names are in the bound_params list. - - non_bound_params: A list of ParameterSchema objects whose names are not in the bound_params list. - """ - - _bound_params: list[ParameterSchema] = [] - _non_bound_params: list[ParameterSchema] = [] - - for param in params: - if param.name in bound_params: - _bound_params.append(param) - else: - _non_bound_params.append(param) - - return (_bound_params, _non_bound_params) diff --git a/packages/toolbox-langchain/tests/test_async_client.py b/packages/toolbox-langchain/tests/test_async_client.py index 25ad78eb..7b3d38c9 100644 --- a/packages/toolbox-langchain/tests/test_async_client.py +++ b/packages/toolbox-langchain/tests/test_async_client.py @@ -20,7 +20,7 @@ from toolbox_langchain.async_client import AsyncToolboxClient from toolbox_langchain.async_tools import AsyncToolboxTool -from toolbox_langchain.utils import ManifestSchema +from toolbox_core.protocol import ManifestSchema URL = "http://test_url" MANIFEST_JSON = { diff --git a/packages/toolbox-langchain/tests/test_utils.py b/packages/toolbox-langchain/tests/test_utils.py deleted file mode 100644 index 488a6aef..00000000 --- a/packages/toolbox-langchain/tests/test_utils.py +++ /dev/null @@ -1,290 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import asyncio -import json -import re -import warnings -from unittest.mock import AsyncMock, Mock, patch - -import aiohttp -import pytest -from pydantic import BaseModel - -from toolbox_langchain.utils import ( - ParameterSchema, - _get_auth_headers, - _invoke_tool, - _load_manifest, - _parse_type, - _schema_to_model, -) - -URL = "https://my-toolbox.com/test" -MOCK_MANIFEST = """ -{ - "serverVersion": "0.0.1", - "tools": { - "test_tool": { - "summary": "Test Tool", - "description": "This is a test tool.", - "parameters": [ - { - "name": "param1", - "type": "string", - "description": "Parameter 1" - }, - { - "name": "param2", - "type": "integer", - "description": "Parameter 2" - } - ] - } - } -} -""" - - -class TestUtils: - @pytest.fixture(scope="module") - def mock_manifest(self): - return aiohttp.ClientResponse( - method="GET", - url=aiohttp.client.URL(URL), - writer=None, - continue100=None, - timer=None, - request_info=None, - traces=None, - session=None, - loop=asyncio.get_event_loop(), - ) - - @pytest.mark.asyncio - @patch("aiohttp.ClientSession.get") - async def test_load_manifest(self, mock_get, mock_manifest): - mock_manifest.raise_for_status = Mock() - mock_manifest.text = AsyncMock(return_value=MOCK_MANIFEST) - - mock_get.return_value = mock_manifest - session = aiohttp.ClientSession() - manifest = await _load_manifest(URL, session) - await session.close() - mock_get.assert_called_once_with(URL) - - assert manifest.serverVersion == "0.0.1" - assert len(manifest.tools) == 1 - - tool = manifest.tools["test_tool"] - assert tool.description == "This is a test tool." - assert tool.parameters == [ - ParameterSchema(name="param1", type="string", description="Parameter 1"), - ParameterSchema(name="param2", type="integer", description="Parameter 2"), - ] - - @pytest.mark.asyncio - @patch("aiohttp.ClientSession.get") - async def test_load_manifest_invalid_json(self, mock_get, mock_manifest): - mock_manifest.raise_for_status = Mock() - mock_manifest.text = AsyncMock(return_value="{ invalid manifest") - mock_get.return_value = mock_manifest - - with pytest.raises(Exception) as e: - session = aiohttp.ClientSession() - await _load_manifest(URL, session) - - mock_get.assert_called_once_with(URL) - assert isinstance(e.value, json.JSONDecodeError) - assert ( - str(e.value) - == "Failed to parse JSON from https://my-toolbox.com/test: Expecting property name enclosed in double quotes: line 1 column 3 (char 2): line 1 column 3 (char 2)" - ) - - @pytest.mark.asyncio - @patch("aiohttp.ClientSession.get") - async def test_load_manifest_invalid_manifest(self, mock_get, mock_manifest): - mock_manifest.raise_for_status = Mock() - mock_manifest.text = AsyncMock(return_value='{ "something": "invalid" }') - mock_get.return_value = mock_manifest - - with pytest.raises(Exception) as e: - session = aiohttp.ClientSession() - await _load_manifest(URL, session) - - mock_get.assert_called_once_with(URL) - assert isinstance(e.value, ValueError) - assert re.match( - r"Invalid JSON data from https://my-toolbox.com/test: 2 validation errors for ManifestSchema\nserverVersion\n Field required \[type=missing, input_value={'something': 'invalid'}, input_type=dict]\n For further information visit https://errors.pydantic.dev/\d+\.\d+/v/missing\ntools\n Field required \[type=missing, input_value={'something': 'invalid'}, input_type=dict]\n For further information visit https://errors.pydantic.dev/\d+\.\d+/v/missing", - str(e.value), - ) - - @pytest.mark.asyncio - @patch("aiohttp.ClientSession.get") - async def test_load_manifest_api_error(self, mock_get, mock_manifest): - error = aiohttp.ClientError("Simulated HTTP Error") - mock_manifest.raise_for_status = Mock() - mock_manifest.text = AsyncMock(side_effect=error) - mock_get.return_value = mock_manifest - - with pytest.raises(aiohttp.ClientError) as exc_info: - session = aiohttp.ClientSession() - await _load_manifest(URL, session) - mock_get.assert_called_once_with(URL) - assert exc_info.value == error - - def test_schema_to_model(self): - schema = [ - ParameterSchema(name="param1", type="string", description="Parameter 1"), - ParameterSchema(name="param2", type="integer", description="Parameter 2"), - ] - model = _schema_to_model("TestModel", schema) - assert issubclass(model, BaseModel) - - assert model.model_fields["param1"].annotation == str - assert model.model_fields["param1"].description == "Parameter 1" - assert model.model_fields["param2"].annotation == int - assert model.model_fields["param2"].description == "Parameter 2" - - def test_schema_to_model_empty(self): - model = _schema_to_model("TestModel", []) - assert issubclass(model, BaseModel) - assert len(model.model_fields) == 0 - - @pytest.mark.parametrize( - "parameter_schema, expected_type", - [ - (ParameterSchema(name="foo", description="bar", type="string"), str), - (ParameterSchema(name="foo", description="bar", type="integer"), int), - (ParameterSchema(name="foo", description="bar", type="float"), float), - (ParameterSchema(name="foo", description="bar", type="boolean"), bool), - ( - ParameterSchema( - name="foo", - description="bar", - type="array", - items=ParameterSchema( - name="foo", description="bar", type="integer" - ), - ), - list[int], - ), - ], - ) - def test_parse_type(self, parameter_schema, expected_type): - assert _parse_type(parameter_schema) == expected_type - - @pytest.mark.parametrize( - "fail_parameter_schema", - [ - (ParameterSchema(name="foo", description="bar", type="invalid")), - ( - ParameterSchema( - name="foo", - description="bar", - type="array", - items=ParameterSchema( - name="foo", description="bar", type="invalid" - ), - ) - ), - ], - ) - def test_parse_type_invalid(self, fail_parameter_schema): - with pytest.raises(ValueError): - _parse_type(fail_parameter_schema) - - @pytest.mark.asyncio - @patch("aiohttp.ClientSession.post") - async def test_invoke_tool(self, mock_post): - mock_response = Mock() - mock_response.raise_for_status = Mock() - mock_response.json = AsyncMock(return_value={"key": "value"}) - mock_post.return_value.__aenter__.return_value = mock_response - - result = await _invoke_tool( - "http://localhost:5000", - aiohttp.ClientSession(), - "tool_name", - {"input": "data"}, - {}, - ) - - mock_post.assert_called_once_with( - "http://localhost:5000/api/tool/tool_name/invoke", - json={"input": "data"}, - headers={}, - ) - assert result == {"key": "value"} - - @pytest.mark.asyncio - @patch("aiohttp.ClientSession.post") - async def test_invoke_tool_unsecure_with_auth(self, mock_post): - mock_response = Mock() - mock_response.raise_for_status = Mock() - mock_response.json = AsyncMock(return_value={"key": "value"}) - mock_post.return_value.__aenter__.return_value = mock_response - - with pytest.warns( - UserWarning, - match="Sending ID token over HTTP. User data may be exposed. Use HTTPS for secure communication.", - ): - result = await _invoke_tool( - "http://localhost:5000", - aiohttp.ClientSession(), - "tool_name", - {"input": "data"}, - {"my_test_auth": lambda: "fake_id_token"}, - ) - - mock_post.assert_called_once_with( - "http://localhost:5000/api/tool/tool_name/invoke", - json={"input": "data"}, - headers={"my_test_auth_token": "fake_id_token"}, - ) - assert result == {"key": "value"} - - @pytest.mark.asyncio - @patch("aiohttp.ClientSession.post") - async def test_invoke_tool_secure_with_auth(self, mock_post): - session = aiohttp.ClientSession() - mock_response = Mock() - mock_response.raise_for_status = Mock() - mock_response.json = AsyncMock(return_value={"key": "value"}) - mock_post.return_value.__aenter__.return_value = mock_response - - with warnings.catch_warnings(): - warnings.simplefilter("error") - result = await _invoke_tool( - "https://localhost:5000", - session, - "tool_name", - {"input": "data"}, - {"my_test_auth": lambda: "fake_id_token"}, - ) - - mock_post.assert_called_once_with( - "https://localhost:5000/api/tool/tool_name/invoke", - json={"input": "data"}, - headers={"my_test_auth_token": "fake_id_token"}, - ) - assert result == {"key": "value"} - - def test_get_auth_headers_deprecation_warning(self): - """Test _get_auth_headers deprecation warning.""" - with pytest.warns( - DeprecationWarning, - match=r"Call to deprecated function \(or staticmethod\) _get_auth_headers\. \(Please use `_get_auth_tokens` instead\.\)$", - ): - _get_auth_headers({"auth_source1": lambda: "test_token"}) From 53a673c23fecabd5db5558311f93b2bb050c5167 Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Thu, 8 May 2025 15:06:40 +0530 Subject: [PATCH 02/53] fix: add toolbox-core as package dependency --- packages/toolbox-langchain/pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/packages/toolbox-langchain/pyproject.toml b/packages/toolbox-langchain/pyproject.toml index f4f5b7aa..a66987c8 100644 --- a/packages/toolbox-langchain/pyproject.toml +++ b/packages/toolbox-langchain/pyproject.toml @@ -9,6 +9,7 @@ authors = [ {name = "Google LLC", email = "googleapis-packages@google.com"} ] dependencies = [ + "toolbox-core>=0.1.0,<1.0.0", "langchain-core>=0.2.23,<1.0.0", "PyYAML>=6.0.1,<7.0.0", "pydantic>=2.7.0,<3.0.0", From 42f4d35c157ba03f2e0b2e94363bd234f662ac72 Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Thu, 8 May 2025 15:29:26 +0530 Subject: [PATCH 03/53] fix: Base sync client --- .../src/toolbox_langchain/client.py | 271 ++++++++++-------- 1 file changed, 146 insertions(+), 125 deletions(-) diff --git a/packages/toolbox-langchain/src/toolbox_langchain/client.py b/packages/toolbox-langchain/src/toolbox_langchain/client.py index 3c75779c..994e5f93 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/client.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/client.py @@ -13,21 +13,15 @@ # limitations under the License. import asyncio -from threading import Thread -from typing import Any, Awaitable, Callable, Optional, TypeVar, Union +from warnings import warn +from typing import Any, Callable, Optional, Union -from aiohttp import ClientSession - -from .async_client import AsyncToolboxClient from .tools import ToolboxTool +from toolbox_core.sync_client import ToolboxSyncClient as ToolboxCoreSyncClient -T = TypeVar("T") class ToolboxClient: - __session: Optional[ClientSession] = None - __loop: Optional[asyncio.AbstractEventLoop] = None - __thread: Optional[Thread] = None def __init__( self, @@ -39,51 +33,7 @@ def __init__( Args: url: The base URL of the Toolbox service. """ - - # Running a loop in a background thread allows us to support async - # methods from non-async environments. - if ToolboxClient.__loop is None: - loop = asyncio.new_event_loop() - thread = Thread(target=loop.run_forever, daemon=True) - thread.start() - ToolboxClient.__thread = thread - ToolboxClient.__loop = loop - - async def __start_session() -> None: - - # Use a default session if none is provided. This leverages connection - # pooling for better performance by reusing a single session throughout - # the application's lifetime. - if ToolboxClient.__session is None: - ToolboxClient.__session = ClientSession() - - coro = __start_session() - - asyncio.run_coroutine_threadsafe(coro, ToolboxClient.__loop).result() - - if not ToolboxClient.__session: - raise ValueError("Session cannot be None.") - self.__async_client = AsyncToolboxClient(url, ToolboxClient.__session) - - def __run_as_sync(self, coro: Awaitable[T]) -> T: - """Run an async coroutine synchronously""" - if not self.__loop: - raise Exception( - "Cannot call synchronous methods before the background loop is initialized." - ) - return asyncio.run_coroutine_threadsafe(coro, self.__loop).result() - - async def __run_as_async(self, coro: Awaitable[T]) -> T: - """Run an async coroutine asynchronously""" - - # If a loop has not been provided, attempt to run in current thread. - if not self.__loop: - return await coro - - # Otherwise, run in the background thread. - return await asyncio.wrap_future( - asyncio.run_coroutine_threadsafe(coro, self.__loop) - ) + self.__core_sync_client = ToolboxCoreSyncClient(url=url) async def aload_tool( self, @@ -92,7 +42,6 @@ async def aload_tool( auth_tokens: Optional[dict[str, Callable[[], str]]] = None, auth_headers: Optional[dict[str, Callable[[], str]]] = None, bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, - strict: bool = True, ) -> ToolboxTool: """ Loads the tool with the given tool name from the Toolbox service. @@ -105,27 +54,42 @@ async def aload_tool( auth_headers: Deprecated. Use `auth_token_getters` instead. bound_params: An optional mapping of parameter names to their bound values. - strict: If True, raises a ValueError if any of the given bound - parameters are missing from the schema or require - authentication. If False, only issues a warning. Returns: A tool loaded from the Toolbox. """ - async_tool = await self.__run_as_async( - self.__async_client.aload_tool( - tool_name, - auth_token_getters, - auth_tokens, - auth_headers, - bound_params, - strict, - ) + if auth_headers: + if auth_token_getters: + warn( + "Both `auth_token_getters` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_token_getters` will be used.", + DeprecationWarning, + ) + else: + warn( + "Argument `auth_headers` is deprecated. Use `auth_token_getters` instead.", + DeprecationWarning, + ) + auth_token_getters = auth_headers + + if auth_tokens: + if auth_token_getters: + warn( + "Both `auth_token_getters` and `auth_tokens` are provided. `auth_tokens` is deprecated, and `auth_token_getters` will be used.", + DeprecationWarning, + ) + else: + warn( + "Argument `auth_tokens` is deprecated. Use `auth_token_getters` instead.", + DeprecationWarning, + ) + auth_token_getters = auth_tokens + + core_tool = await self.__core_sync_client._ToolboxSyncClient__async_client.load_tool( + name=tool_name, + auth_token_getters=auth_token_getters, + bound_params=bound_params ) - - if not self.__loop or not self.__thread: - raise ValueError("Background loop or thread cannot be None.") - return ToolboxTool(async_tool, self.__loop, self.__thread) + return ToolboxTool(core_tool=core_tool) async def aload_toolset( self, @@ -134,7 +98,7 @@ async def aload_toolset( auth_tokens: Optional[dict[str, Callable[[], str]]] = None, auth_headers: Optional[dict[str, Callable[[], str]]] = None, bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, - strict: bool = True, + strict: bool = False, ) -> list[ToolboxTool]: """ Loads tools from the Toolbox service, optionally filtered by toolset @@ -149,30 +113,51 @@ async def aload_toolset( auth_headers: Deprecated. Use `auth_token_getters` instead. bound_params: An optional mapping of parameter names to their bound values. - strict: If True, raises a ValueError if any of the given bound - parameters are missing from the schema or require - authentication. If False, only issues a warning. + strict: If True, raises an error if *any* loaded tool instance fails + to utilize at least one provided parameter or auth token (if any + provided). If False (default), raises an error only if a + user-provided parameter or auth token cannot be applied to *any* + loaded tool across the set. Returns: A list of all tools loaded from the Toolbox. """ - async_tools = await self.__run_as_async( - self.__async_client.aload_toolset( - toolset_name, - auth_token_getters, - auth_tokens, - auth_headers, - bound_params, - strict, - ) + if auth_headers: + if auth_token_getters: + warn( + "Both `auth_token_getters` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_token_getters` will be used.", + DeprecationWarning, + ) + else: + warn( + "Argument `auth_headers` is deprecated. Use `auth_token_getters` instead.", + DeprecationWarning, + ) + auth_token_getters = auth_headers + + if auth_tokens: + if auth_token_getters: + warn( + "Both `auth_token_getters` and `auth_tokens` are provided. `auth_tokens` is deprecated, and `auth_token_getters` will be used.", + DeprecationWarning, + ) + else: + warn( + "Argument `auth_tokens` is deprecated. Use `auth_token_getters` instead.", + DeprecationWarning, + ) + auth_token_getters = auth_tokens + + core_tools = await self.__core_sync_client._ToolboxSyncClient__async_client.load_toolset( + name=toolset_name, + auth_token_getters=auth_token_getters, + bound_params=bound_params, + strict=strict ) - tools: list[ToolboxTool] = [] - - if not self.__loop or not self.__thread: - raise ValueError("Background loop or thread cannot be None.") - for async_tool in async_tools: - tools.append(ToolboxTool(async_tool, self.__loop, self.__thread)) + tools = [] + for core_tool in core_tools: + tools.append(ToolboxTool(core_tool_instance=core_tool)) return tools def load_tool( @@ -182,7 +167,6 @@ def load_tool( auth_tokens: Optional[dict[str, Callable[[], str]]] = None, auth_headers: Optional[dict[str, Callable[[], str]]] = None, bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, - strict: bool = True, ) -> ToolboxTool: """ Loads the tool with the given tool name from the Toolbox service. @@ -195,27 +179,42 @@ def load_tool( auth_headers: Deprecated. Use `auth_token_getters` instead. bound_params: An optional mapping of parameter names to their bound values. - strict: If True, raises a ValueError if any of the given bound - parameters are missing from the schema or require - authentication. If False, only issues a warning. Returns: A tool loaded from the Toolbox. """ - async_tool = self.__run_as_sync( - self.__async_client.aload_tool( - tool_name, - auth_token_getters, - auth_tokens, - auth_headers, - bound_params, - strict, - ) + if auth_headers: + if auth_token_getters: + warn( + "Both `auth_token_getters` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_token_getters` will be used.", + DeprecationWarning, + ) + else: + warn( + "Argument `auth_headers` is deprecated. Use `auth_token_getters` instead.", + DeprecationWarning, + ) + auth_token_getters = auth_headers + + if auth_tokens: + if auth_token_getters: + warn( + "Both `auth_token_getters` and `auth_tokens` are provided. `auth_tokens` is deprecated, and `auth_token_getters` will be used.", + DeprecationWarning, + ) + else: + warn( + "Argument `auth_tokens` is deprecated. Use `auth_token_getters` instead.", + DeprecationWarning, + ) + auth_token_getters = auth_tokens + + core_tool = self.__core_sync_client.load_tool( + name=tool_name, + auth_token_getters=auth_token_getters, + bound_params=bound_params ) - - if not self.__loop or not self.__thread: - raise ValueError("Background loop or thread cannot be None.") - return ToolboxTool(async_tool, self.__loop, self.__thread) + return ToolboxTool(core_tool=core_tool) def load_toolset( self, @@ -224,7 +223,7 @@ def load_toolset( auth_tokens: Optional[dict[str, Callable[[], str]]] = None, auth_headers: Optional[dict[str, Callable[[], str]]] = None, bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, - strict: bool = True, + strict: bool = False, ) -> list[ToolboxTool]: """ Loads tools from the Toolbox service, optionally filtered by toolset @@ -239,27 +238,49 @@ def load_toolset( auth_headers: Deprecated. Use `auth_token_getters` instead. bound_params: An optional mapping of parameter names to their bound values. - strict: If True, raises a ValueError if any of the given bound - parameters are missing from the schema or require - authentication. If False, only issues a warning. + strict: If True, raises an error if *any* loaded tool instance fails + to utilize at least one provided parameter or auth token (if any + provided). If False (default), raises an error only if a + user-provided parameter or auth token cannot be applied to *any* + loaded tool across the set. Returns: A list of all tools loaded from the Toolbox. """ - async_tools = self.__run_as_sync( - self.__async_client.aload_toolset( - toolset_name, - auth_token_getters, - auth_tokens, - auth_headers, - bound_params, - strict, - ) + if auth_headers: + if auth_token_getters: + warn( + "Both `auth_token_getters` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_token_getters` will be used.", + DeprecationWarning, + ) + else: + warn( + "Argument `auth_headers` is deprecated. Use `auth_token_getters` instead.", + DeprecationWarning, + ) + auth_token_getters = auth_headers + + if auth_tokens: + if auth_token_getters: + warn( + "Both `auth_token_getters` and `auth_tokens` are provided. `auth_tokens` is deprecated, and `auth_token_getters` will be used.", + DeprecationWarning, + ) + else: + warn( + "Argument `auth_tokens` is deprecated. Use `auth_token_getters` instead.", + DeprecationWarning, + ) + auth_token_getters = auth_tokens + + core_tools = self.__core_sync_client.load_toolset( + name=toolset_name, + auth_token_getters=auth_token_getters, + bound_params=bound_params, + strict=strict ) - if not self.__loop or not self.__thread: - raise ValueError("Background loop or thread cannot be None.") - tools: list[ToolboxTool] = [] - for async_tool in async_tools: - tools.append(ToolboxTool(async_tool, self.__loop, self.__thread)) + tools = [] + for core_tool in core_tools: + tools.append(ToolboxTool(core_tool_instance=core_tool)) return tools From 2aeb7180f03c80b648df0e66f6df7923fb2f590d Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Thu, 8 May 2025 17:14:58 +0530 Subject: [PATCH 04/53] fix: Fix running background asyncio in current loop issue --- .../src/toolbox_langchain/async_client.py | 2 +- .../src/toolbox_langchain/client.py | 27 ++++++++++++++++--- 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/packages/toolbox-langchain/src/toolbox_langchain/async_client.py b/packages/toolbox-langchain/src/toolbox_langchain/async_client.py index d7e39814..41f2db60 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/async_client.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/async_client.py @@ -161,7 +161,7 @@ async def aload_toolset( tools = [] for core_tool in core_tools: - tools.append(AsyncToolboxTool(core_tool_instance=core_tool)) + tools.append(AsyncToolboxTool(core_tool=core_tool)) return tools def load_tool( diff --git a/packages/toolbox-langchain/src/toolbox_langchain/client.py b/packages/toolbox-langchain/src/toolbox_langchain/client.py index 994e5f93..52079b03 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/client.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/client.py @@ -84,11 +84,21 @@ async def aload_tool( ) auth_token_getters = auth_tokens - core_tool = await self.__core_sync_client._ToolboxSyncClient__async_client.load_tool( + coro = self.__core_sync_client._ToolboxSyncClient__async_client.load_tool( name=tool_name, auth_token_getters=auth_token_getters, bound_params=bound_params ) + + # If a loop has not been provided, attempt to run in current thread. + if not self.__core_sync_client._ToolboxSyncClient__loop: + return await coro + + # Otherwise, run in the background thread. + core_tool = await asyncio.wrap_future( + asyncio.run_coroutine_threadsafe(coro, self.__core_sync_client._ToolboxSyncClient__loop) + ) + return ToolboxTool(core_tool=core_tool) async def aload_toolset( @@ -148,16 +158,25 @@ async def aload_toolset( ) auth_token_getters = auth_tokens - core_tools = await self.__core_sync_client._ToolboxSyncClient__async_client.load_toolset( + coro = self.__core_sync_client._ToolboxSyncClient__async_client.load_toolset( name=toolset_name, auth_token_getters=auth_token_getters, bound_params=bound_params, strict=strict ) + # If a loop has not been provided, attempt to run in current thread. + if not self.__core_sync_client._ToolboxSyncClient__loop: + return await coro + + # Otherwise, run in the background thread. + core_tools = await asyncio.wrap_future( + asyncio.run_coroutine_threadsafe(coro, self.__core_sync_client._ToolboxSyncClient__loop) + ) + tools = [] for core_tool in core_tools: - tools.append(ToolboxTool(core_tool_instance=core_tool)) + tools.append(ToolboxTool(core_tool=core_tool)) return tools def load_tool( @@ -282,5 +301,5 @@ def load_toolset( tools = [] for core_tool in core_tools: - tools.append(ToolboxTool(core_tool_instance=core_tool)) + tools.append(ToolboxTool(core_tool=core_tool)) return tools From bbc4bbaaa4202094c64e73ff787516d5859fdb1a Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Thu, 8 May 2025 17:43:02 +0530 Subject: [PATCH 05/53] fix: Base toolbox sync & async tools to toolbox core counterparts --- .../src/toolbox_langchain/async_client.py | 3 +- .../src/toolbox_langchain/async_tools.py | 17 +-- .../src/toolbox_langchain/client.py | 20 ++-- .../src/toolbox_langchain/tools.py | 104 ++++-------------- 4 files changed, 38 insertions(+), 106 deletions(-) diff --git a/packages/toolbox-langchain/src/toolbox_langchain/async_client.py b/packages/toolbox-langchain/src/toolbox_langchain/async_client.py index 41f2db60..2e1053a3 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/async_client.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/async_client.py @@ -17,8 +17,7 @@ from aiohttp import ClientSession -from .tools import AsyncToolboxTool - +from .async_tools import AsyncToolboxTool from toolbox_core.client import ToolboxClient as ToolboxCoreClient diff --git a/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py b/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py index 1a7b3dd1..aec0efa4 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py @@ -12,16 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from copy import deepcopy -from typing import Any, Callable, TypeVar, Union -from warnings import warn +from typing import Any, Callable, Union from langchain_core.tools import BaseTool from toolbox_core.tool import ToolboxTool as ToolboxCoreTool -T = TypeVar("T") - # This class is an internal implementation detail and is not exposed to the # end-user. It should not be used directly by external code. Changes to this @@ -40,14 +36,7 @@ def __init__( Initializes an AsyncToolboxTool instance. Args: - name: The name of the tool. - schema: The tool schema. - url: The base URL of the Toolbox service. - session: The HTTP client session. - auth_token_getters: A mapping of authentication source names to - functions that retrieve ID tokens. - bound_params: A mapping of parameter names to their bound - values. + core_tool: The underlying core async ToolboxTool instance. """ self.__core_tool = core_tool @@ -88,7 +77,7 @@ def add_auth_token_getters( Returns: A new AsyncToolboxTool instance that is a deep copy of the current - instance, with added auth tokens. + instance, with added auth token getters. Raises: ValueError: If any of the provided auth parameters is already diff --git a/packages/toolbox-langchain/src/toolbox_langchain/client.py b/packages/toolbox-langchain/src/toolbox_langchain/client.py index 52079b03..27994171 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/client.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/client.py @@ -95,11 +95,11 @@ async def aload_tool( return await coro # Otherwise, run in the background thread. - core_tool = await asyncio.wrap_future( + core_sync_tool = await asyncio.wrap_future( asyncio.run_coroutine_threadsafe(coro, self.__core_sync_client._ToolboxSyncClient__loop) ) - return ToolboxTool(core_tool=core_tool) + return ToolboxTool(core_sync_tool=core_sync_tool) async def aload_toolset( self, @@ -170,13 +170,13 @@ async def aload_toolset( return await coro # Otherwise, run in the background thread. - core_tools = await asyncio.wrap_future( + core_sync_tools = await asyncio.wrap_future( asyncio.run_coroutine_threadsafe(coro, self.__core_sync_client._ToolboxSyncClient__loop) ) tools = [] - for core_tool in core_tools: - tools.append(ToolboxTool(core_tool=core_tool)) + for core_sync_tool in core_sync_tools: + tools.append(ToolboxTool(core_sync_tool=core_sync_tool)) return tools def load_tool( @@ -228,12 +228,12 @@ def load_tool( ) auth_token_getters = auth_tokens - core_tool = self.__core_sync_client.load_tool( + core_sync_tool = self.__core_sync_client.load_tool( name=tool_name, auth_token_getters=auth_token_getters, bound_params=bound_params ) - return ToolboxTool(core_tool=core_tool) + return ToolboxTool(core_sync_tool=core_sync_tool) def load_toolset( self, @@ -292,7 +292,7 @@ def load_toolset( ) auth_token_getters = auth_tokens - core_tools = self.__core_sync_client.load_toolset( + core_sync_tools = self.__core_sync_client.load_toolset( name=toolset_name, auth_token_getters=auth_token_getters, bound_params=bound_params, @@ -300,6 +300,6 @@ def load_toolset( ) tools = [] - for core_tool in core_tools: - tools.append(ToolboxTool(core_tool=core_tool)) + for core_sync_tool in core_sync_tools: + tools.append(ToolboxTool(core_sync_tool=core_sync_tool)) return tools diff --git a/packages/toolbox-langchain/src/toolbox_langchain/tools.py b/packages/toolbox-langchain/src/toolbox_langchain/tools.py index feb2a597..ecdde697 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/tools.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/tools.py @@ -13,15 +13,11 @@ # limitations under the License. import asyncio -from asyncio import AbstractEventLoop -from threading import Thread -from typing import Any, Awaitable, Callable, TypeVar, Union +from typing import Any, Callable, Union from langchain_core.tools import BaseTool +from toolbox_core.sync_tool import ToolboxSyncTool as ToolboxCoreSyncTool -from .async_tools import AsyncToolboxTool - -T = TypeVar("T") class ToolboxTool(BaseTool): @@ -32,56 +28,37 @@ class ToolboxTool(BaseTool): def __init__( self, - async_tool: AsyncToolboxTool, - loop: AbstractEventLoop, - thread: Thread, + core_sync_tool: ToolboxCoreSyncTool, ) -> None: """ Initializes a ToolboxTool instance. Args: - async_tool: The underlying AsyncToolboxTool instance. - loop: The event loop used to run asynchronous tasks. - thread: The thread to run blocking operations in. + core_sync_tool: The underlying core sync ToolboxTool instance. """ - # Due to how pydantic works, we must initialize the underlying - # BaseTool class before assigning values to member variables. + self.__core_sync_tool = core_sync_tool super().__init__( - name=async_tool.name, - description=async_tool.description, - args_schema=async_tool.args_schema, + name=self.__core_sync_tool.__name__, + description=self.__core_sync_tool.__doc__, + args_schema=self.__core_sync_tool._ToolboxSyncTool__pydantic_model, ) - self.__async_tool = async_tool - self.__loop = loop - self.__thread = thread - - def __run_as_sync(self, coro: Awaitable[T]) -> T: - """Run an async coroutine synchronously""" - if not self.__loop: - raise Exception( - "Cannot call synchronous methods before the background loop is initialized." - ) - return asyncio.run_coroutine_threadsafe(coro, self.__loop).result() + def _run(self, **kwargs: Any) -> dict[str, Any]: + return self.__core_sync_tool(**kwargs) - async def __run_as_async(self, coro: Awaitable[T]) -> T: - """Run an async coroutine asynchronously""" + async def _arun(self, **kwargs: Any) -> dict[str, Any]: + coro = self.__core_sync_tool._ToolboxSyncTool__async_tool(**kwargs) # If a loop has not been provided, attempt to run in current thread. - if not self.__loop: + if not self.__core_sync_client._ToolboxSyncClient__loop: return await coro # Otherwise, run in the background thread. - return await asyncio.wrap_future( - asyncio.run_coroutine_threadsafe(coro, self.__loop) + await asyncio.wrap_future( + asyncio.run_coroutine_threadsafe(coro, self.__core_sync_client._ToolboxSyncTool__loop) ) - def _run(self, **kwargs: Any) -> dict[str, Any]: - return self.__run_as_sync(self.__async_tool._arun(**kwargs)) - - async def _arun(self, **kwargs: Any) -> dict[str, Any]: - return await self.__run_as_async(self.__async_tool._arun(**kwargs)) def add_auth_token_getters( self, auth_token_getters: dict[str, Callable[[], str]], strict: bool = True @@ -93,27 +70,21 @@ def add_auth_token_getters( Args: auth_token_getters: A dictionary of authentication source names to the functions that return corresponding ID token. - strict: If True, a ValueError is raised if any of the provided auth - parameters is already bound. If False, only a warning is issued. Returns: A new ToolboxTool instance that is a deep copy of the current - instance, with added auth tokens. + instance, with added auth token getters. Raises: ValueError: If any of the provided auth parameters is already registered. - ValueError: If any of the provided auth parameters is already bound - and strict is True. """ - return ToolboxTool( - self.__async_tool.add_auth_token_getters(auth_token_getters, strict), - self.__loop, - self.__thread, - ) + new_core_sync_tool = self.__core_sync_tool.add_auth_token_getters(auth_token_getters) + return ToolboxTool(core_sync_tool=new_core_sync_tool) + def add_auth_token_getter( - self, auth_source: str, get_id_token: Callable[[], str], strict: bool = True + self, auth_source: str, get_id_token: Callable[[], str] ) -> "ToolboxTool": """ Registers a function to retrieve an ID token for a given authentication @@ -122,8 +93,6 @@ def add_auth_token_getter( Args: auth_source: The name of the authentication source. get_id_token: A function that returns the ID token. - strict: If True, a ValueError is raised if the provided auth - parameter is already bound. If False, only a warning is issued. Returns: A new ToolboxTool instance that is a deep copy of the current @@ -131,19 +100,12 @@ def add_auth_token_getter( Raises: ValueError: If the provided auth parameter is already registered. - ValueError: If the provided auth parameter is already bound and - strict is True. """ - return ToolboxTool( - self.__async_tool.add_auth_token_getter(auth_source, get_id_token, strict), - self.__loop, - self.__thread, - ) + return self.add_auth_token_getters({auth_source: get_id_token}) def bind_params( self, bound_params: dict[str, Union[Any, Callable[[], Any]]], - strict: bool = True, ) -> "ToolboxTool": """ Registers values or functions to retrieve the value for the @@ -152,9 +114,6 @@ def bind_params( Args: bound_params: A dictionary of the bound parameter name to the value or function of the bound value. - strict: If True, a ValueError is raised if any of the provided bound - params is not defined in the tool's schema, or requires - authentication. If False, only a warning is issued. Returns: A new ToolboxTool instance that is a deep copy of the current @@ -162,15 +121,9 @@ def bind_params( Raises: ValueError: If any of the provided bound params is already bound. - ValueError: if any of the provided bound params is not defined in - the tool's schema, or require authentication, and strict is - True. """ - return ToolboxTool( - self.__async_tool.bind_params(bound_params, strict), - self.__loop, - self.__thread, - ) + new_core_sync_tool = self.__core_sync_tool.bind_params(bound_params) + return ToolboxTool(core_sync_tool=new_core_sync_tool) def bind_param( self, @@ -186,9 +139,6 @@ def bind_param( param_name: The name of the bound parameter. param_value: The value of the bound parameter, or a callable that returns the value. - strict: If True, a ValueError is raised if the provided bound - param is not defined in the tool's schema, or requires - authentication. If False, only a warning is issued. Returns: A new ToolboxTool instance that is a deep copy of the current @@ -196,11 +146,5 @@ def bind_param( Raises: ValueError: If the provided bound param is already bound. - ValueError: if the provided bound param is not defined in the tool's - schema, or requires authentication, and strict is True. """ - return ToolboxTool( - self.__async_tool.bind_param(param_name, param_value, strict), - self.__loop, - self.__thread, - ) + return self.bind_params({param_name: param_value}) From 56e14cd41ae6809c9da3d013e2e30df273a50666 Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Thu, 8 May 2025 17:51:46 +0530 Subject: [PATCH 06/53] fix: Fix getting pydantic model from ToolboxSyncTool --- packages/toolbox-langchain/src/toolbox_langchain/tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/toolbox-langchain/src/toolbox_langchain/tools.py b/packages/toolbox-langchain/src/toolbox_langchain/tools.py index ecdde697..37ca1c4e 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/tools.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/tools.py @@ -41,7 +41,7 @@ def __init__( super().__init__( name=self.__core_sync_tool.__name__, description=self.__core_sync_tool.__doc__, - args_schema=self.__core_sync_tool._ToolboxSyncTool__pydantic_model, + args_schema=self.__core_sync_tool._ToolboxSyncTool__async_tool._ToolboxTool__pydantic_model, ) def _run(self, **kwargs: Any) -> dict[str, Any]: From 44fa98e2784ada1e50d2f0926b2035128ebce371 Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Thu, 8 May 2025 19:54:39 +0530 Subject: [PATCH 07/53] fix: Fix issue causing async core tools for creating sync tools --- .../src/toolbox_langchain/client.py | 38 +++++++++++-------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/packages/toolbox-langchain/src/toolbox_langchain/client.py b/packages/toolbox-langchain/src/toolbox_langchain/client.py index 27994171..646abbba 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/client.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/client.py @@ -18,6 +18,7 @@ from .tools import ToolboxTool from toolbox_core.sync_client import ToolboxSyncClient as ToolboxCoreSyncClient +from toolbox_core.sync_tool import ToolboxSyncTool @@ -90,15 +91,16 @@ async def aload_tool( bound_params=bound_params ) - # If a loop has not been provided, attempt to run in current thread. if not self.__core_sync_client._ToolboxSyncClient__loop: - return await coro - - # Otherwise, run in the background thread. - core_sync_tool = await asyncio.wrap_future( - asyncio.run_coroutine_threadsafe(coro, self.__core_sync_client._ToolboxSyncClient__loop) - ) - + # If a loop has not been provided, attempt to run in current thread. + core_tool = await coro + else: + # Otherwise, run in the background thread. + core_tool = await asyncio.wrap_future( + asyncio.run_coroutine_threadsafe(coro, self.__core_sync_client._ToolboxSyncClient__loop) + ) + + core_sync_tool = ToolboxSyncTool(core_tool, self.__core_sync_client._ToolboxSyncClient__loop, self.__core_sync_client._ToolboxSyncClient__thread) return ToolboxTool(core_sync_tool=core_sync_tool) async def aload_toolset( @@ -165,15 +167,19 @@ async def aload_toolset( strict=strict ) - # If a loop has not been provided, attempt to run in current thread. if not self.__core_sync_client._ToolboxSyncClient__loop: - return await coro - - # Otherwise, run in the background thread. - core_sync_tools = await asyncio.wrap_future( - asyncio.run_coroutine_threadsafe(coro, self.__core_sync_client._ToolboxSyncClient__loop) - ) - + # If a loop has not been provided, attempt to run in current thread. + core_tools = await coro + else: + # Otherwise, run in the background thread. + core_tools = await asyncio.wrap_future( + asyncio.run_coroutine_threadsafe(coro, self.__core_sync_client._ToolboxSyncClient__loop) + ) + + core_sync_tools = [ + ToolboxSyncTool(core_tool, self.__core_sync_client._ToolboxSyncClient__loop, self.__core_sync_client._ToolboxSyncClient__thread) + for core_tool in core_tools + ] tools = [] for core_sync_tool in core_sync_tools: tools.append(ToolboxTool(core_sync_tool=core_sync_tool)) From f44dc265262b3c0c67a4192bbb1d188a85ffdca2 Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Thu, 8 May 2025 20:01:27 +0530 Subject: [PATCH 08/53] fix: Fix reading name from correct param --- packages/toolbox-langchain/tests/test_e2e.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/toolbox-langchain/tests/test_e2e.py b/packages/toolbox-langchain/tests/test_e2e.py index 214ea305..cff21533 100644 --- a/packages/toolbox-langchain/tests/test_e2e.py +++ b/packages/toolbox-langchain/tests/test_e2e.py @@ -54,7 +54,7 @@ def toolbox(self): @pytest_asyncio.fixture(scope="function") async def get_n_rows_tool(self, toolbox): tool = await toolbox.aload_tool("get-n-rows") - assert tool._ToolboxTool__async_tool._AsyncToolboxTool__name == "get-n-rows" + assert tool._ToolboxTool__core_sync_tool.__name__ == "get-n-rows" return tool #### Basic e2e tests From 00a3666ec88f62c92581d18c7ba45f0ed2540000 Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Thu, 8 May 2025 20:10:50 +0530 Subject: [PATCH 09/53] fix: Fix issue of unknown parameter due to pydantic initialization --- .../toolbox-langchain/src/toolbox_langchain/async_tools.py | 3 +-- packages/toolbox-langchain/src/toolbox_langchain/tools.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py b/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py index aec0efa4..169aef03 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py @@ -38,13 +38,12 @@ def __init__( Args: core_tool: The underlying core async ToolboxTool instance. """ - - self.__core_tool = core_tool super().__init__( name=self.__core_tool.__name__, description=self.__core_tool.__doc__, args_schema=self.__core_tool._ToolboxTool__pydantic_model, ) + self.__core_tool = core_tool def _run(self, **kwargs: Any) -> str: raise NotImplementedError("Synchronous methods not supported by async tools.") diff --git a/packages/toolbox-langchain/src/toolbox_langchain/tools.py b/packages/toolbox-langchain/src/toolbox_langchain/tools.py index 37ca1c4e..b781a8d3 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/tools.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/tools.py @@ -36,13 +36,12 @@ def __init__( Args: core_sync_tool: The underlying core sync ToolboxTool instance. """ - - self.__core_sync_tool = core_sync_tool super().__init__( name=self.__core_sync_tool.__name__, description=self.__core_sync_tool.__doc__, args_schema=self.__core_sync_tool._ToolboxSyncTool__async_tool._ToolboxTool__pydantic_model, ) + self.__core_sync_tool = core_sync_tool def _run(self, **kwargs: Any) -> dict[str, Any]: return self.__core_sync_tool(**kwargs) From 24d7a557598b237f210d4dc0bae08e191a7965e2 Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Thu, 8 May 2025 20:15:39 +0530 Subject: [PATCH 10/53] fix: Fix nit error + add comment --- .../src/toolbox_langchain/async_tools.py | 9 ++++++--- .../toolbox-langchain/src/toolbox_langchain/tools.py | 9 ++++++--- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py b/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py index 169aef03..f2f26433 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py @@ -38,10 +38,13 @@ def __init__( Args: core_tool: The underlying core async ToolboxTool instance. """ + + # Due to how pydantic works, we must initialize the underlying + # BaseTool class before assigning values to member variables. super().__init__( - name=self.__core_tool.__name__, - description=self.__core_tool.__doc__, - args_schema=self.__core_tool._ToolboxTool__pydantic_model, + name=core_tool.__name__, + description=core_tool.__doc__, + args_schema=core_tool._ToolboxTool__pydantic_model, ) self.__core_tool = core_tool diff --git a/packages/toolbox-langchain/src/toolbox_langchain/tools.py b/packages/toolbox-langchain/src/toolbox_langchain/tools.py index b781a8d3..fb3d6ef0 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/tools.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/tools.py @@ -36,10 +36,13 @@ def __init__( Args: core_sync_tool: The underlying core sync ToolboxTool instance. """ + + # Due to how pydantic works, we must initialize the underlying + # BaseTool class before assigning values to member variables. super().__init__( - name=self.__core_sync_tool.__name__, - description=self.__core_sync_tool.__doc__, - args_schema=self.__core_sync_tool._ToolboxSyncTool__async_tool._ToolboxTool__pydantic_model, + name=core_sync_tool.__name__, + description=core_sync_tool.__doc__, + args_schema=core_sync_tool._ToolboxSyncTool__async_tool._ToolboxTool__pydantic_model, ) self.__core_sync_tool = core_sync_tool From 113a3128ce68220feac973478dc10212267bdc2a Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Thu, 8 May 2025 20:20:08 +0530 Subject: [PATCH 11/53] fix: Fix sync tool name assertion --- packages/toolbox-langchain/tests/test_e2e.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/toolbox-langchain/tests/test_e2e.py b/packages/toolbox-langchain/tests/test_e2e.py index cff21533..689d8c40 100644 --- a/packages/toolbox-langchain/tests/test_e2e.py +++ b/packages/toolbox-langchain/tests/test_e2e.py @@ -196,7 +196,7 @@ def toolbox(self): @pytest.fixture(scope="function") def get_n_rows_tool(self, toolbox): tool = toolbox.load_tool("get-n-rows") - assert tool._ToolboxTool__async_tool._AsyncToolboxTool__name == "get-n-rows" + assert tool._ToolboxTool__core_sync_tool.__name__ == "get-n-rows" return tool #### Basic e2e tests From 3e09cf222d7f831774d4a68958c99ede52dc3f32 Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Thu, 8 May 2025 20:23:38 +0530 Subject: [PATCH 12/53] chore: Temporarily remove unittests --- .../tests/test_async_client.py | 193 ------------ .../tests/test_async_tools.py | 274 ------------------ .../toolbox-langchain/tests/test_client.py | 259 ----------------- .../toolbox-langchain/tests/test_tools.py | 238 --------------- 4 files changed, 964 deletions(-) delete mode 100644 packages/toolbox-langchain/tests/test_async_client.py delete mode 100644 packages/toolbox-langchain/tests/test_async_tools.py delete mode 100644 packages/toolbox-langchain/tests/test_client.py delete mode 100644 packages/toolbox-langchain/tests/test_tools.py diff --git a/packages/toolbox-langchain/tests/test_async_client.py b/packages/toolbox-langchain/tests/test_async_client.py deleted file mode 100644 index 7b3d38c9..00000000 --- a/packages/toolbox-langchain/tests/test_async_client.py +++ /dev/null @@ -1,193 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from unittest.mock import AsyncMock, patch -from warnings import catch_warnings, simplefilter - -import pytest -from aiohttp import ClientSession - -from toolbox_langchain.async_client import AsyncToolboxClient -from toolbox_langchain.async_tools import AsyncToolboxTool -from toolbox_core.protocol import ManifestSchema - -URL = "http://test_url" -MANIFEST_JSON = { - "serverVersion": "1.0.0", - "tools": { - "test_tool_1": { - "description": "Test Tool 1 Description", - "parameters": [ - { - "name": "param1", - "type": "string", - "description": "Param 1", - } - ], - }, - "test_tool_2": { - "description": "Test Tool 2 Description", - "parameters": [ - { - "name": "param2", - "type": "integer", - "description": "Param 2", - } - ], - }, - }, -} - - -@pytest.mark.asyncio -class TestAsyncToolboxClient: - @pytest.fixture() - def manifest_schema(self): - return ManifestSchema(**MANIFEST_JSON) - - @pytest.fixture() - def mock_session(self): - return AsyncMock(spec=ClientSession) - - @pytest.fixture() - def mock_client(self, mock_session): - return AsyncToolboxClient(URL, session=mock_session) - - async def test_create_with_existing_session(self, mock_client, mock_session): - assert mock_client._AsyncToolboxClient__session == mock_session - - @patch("toolbox_langchain.async_client._load_manifest") - async def test_aload_tool( - self, mock_load_manifest, mock_client, mock_session, manifest_schema - ): - tool_name = "test_tool_1" - mock_load_manifest.return_value = manifest_schema - - tool = await mock_client.aload_tool(tool_name) - - mock_load_manifest.assert_called_once_with( - f"{URL}/api/tool/{tool_name}", mock_session - ) - assert isinstance(tool, AsyncToolboxTool) - assert tool.name == tool_name - - @patch("toolbox_langchain.async_client._load_manifest") - async def test_aload_tool_auth_headers_deprecated( - self, mock_load_manifest, mock_client, manifest_schema - ): - tool_name = "test_tool_1" - mock_manifest = manifest_schema - mock_load_manifest.return_value = mock_manifest - with catch_warnings(record=True) as w: - simplefilter("always") - await mock_client.aload_tool( - tool_name, auth_headers={"Authorization": lambda: "Bearer token"} - ) - assert len(w) == 1 - assert issubclass(w[-1].category, DeprecationWarning) - assert "auth_headers" in str(w[-1].message) - - @patch("toolbox_langchain.async_client._load_manifest") - async def test_aload_tool_auth_headers_and_tokens( - self, mock_load_manifest, mock_client, manifest_schema - ): - tool_name = "test_tool_1" - mock_manifest = manifest_schema - mock_load_manifest.return_value = mock_manifest - with catch_warnings(record=True) as w: - simplefilter("always") - await mock_client.aload_tool( - tool_name, - auth_headers={"Authorization": lambda: "Bearer token"}, - auth_token_getters={"test": lambda: "token"}, - ) - assert len(w) == 1 - assert issubclass(w[-1].category, DeprecationWarning) - assert "auth_headers" in str(w[-1].message) - - @patch("toolbox_langchain.async_client._load_manifest") - async def test_aload_toolset( - self, mock_load_manifest, mock_client, mock_session, manifest_schema - ): - mock_manifest = manifest_schema - mock_load_manifest.return_value = mock_manifest - tools = await mock_client.aload_toolset() - - mock_load_manifest.assert_called_once_with(f"{URL}/api/toolset/", mock_session) - assert len(tools) == 2 - for tool in tools: - assert isinstance(tool, AsyncToolboxTool) - assert tool.name in ["test_tool_1", "test_tool_2"] - - @patch("toolbox_langchain.async_client._load_manifest") - async def test_aload_toolset_with_toolset_name( - self, mock_load_manifest, mock_client, mock_session, manifest_schema - ): - toolset_name = "test_toolset_1" - mock_manifest = manifest_schema - mock_load_manifest.return_value = mock_manifest - tools = await mock_client.aload_toolset(toolset_name=toolset_name) - - mock_load_manifest.assert_called_once_with( - f"{URL}/api/toolset/{toolset_name}", mock_session - ) - assert len(tools) == 2 - for tool in tools: - assert isinstance(tool, AsyncToolboxTool) - assert tool.name in ["test_tool_1", "test_tool_2"] - - @patch("toolbox_langchain.async_client._load_manifest") - async def test_aload_toolset_auth_headers_deprecated( - self, mock_load_manifest, mock_client, manifest_schema - ): - mock_manifest = manifest_schema - mock_load_manifest.return_value = mock_manifest - with catch_warnings(record=True) as w: - simplefilter("always") - await mock_client.aload_toolset( - auth_headers={"Authorization": lambda: "Bearer token"} - ) - assert len(w) == 1 - assert issubclass(w[-1].category, DeprecationWarning) - assert "auth_headers" in str(w[-1].message) - - @patch("toolbox_langchain.async_client._load_manifest") - async def test_aload_toolset_auth_headers_and_tokens( - self, mock_load_manifest, mock_client, manifest_schema - ): - mock_manifest = manifest_schema - mock_load_manifest.return_value = mock_manifest - with catch_warnings(record=True) as w: - simplefilter("always") - await mock_client.aload_toolset( - auth_headers={"Authorization": lambda: "Bearer token"}, - auth_token_getters={"test": lambda: "token"}, - ) - assert len(w) == 1 - assert issubclass(w[-1].category, DeprecationWarning) - assert "auth_headers" in str(w[-1].message) - - async def test_load_tool_not_implemented(self, mock_client): - with pytest.raises(NotImplementedError) as excinfo: - mock_client.load_tool("test_tool") - assert "Synchronous methods not supported by async client." in str( - excinfo.value - ) - - async def test_load_toolset_not_implemented(self, mock_client): - with pytest.raises(NotImplementedError) as excinfo: - mock_client.load_toolset() - assert "Synchronous methods not supported by async client." in str( - excinfo.value - ) diff --git a/packages/toolbox-langchain/tests/test_async_tools.py b/packages/toolbox-langchain/tests/test_async_tools.py deleted file mode 100644 index e23aee85..00000000 --- a/packages/toolbox-langchain/tests/test_async_tools.py +++ /dev/null @@ -1,274 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from unittest.mock import AsyncMock, Mock, patch - -import pytest -import pytest_asyncio -from pydantic import ValidationError - -from toolbox_langchain.async_tools import AsyncToolboxTool - - -@pytest.mark.asyncio -class TestAsyncToolboxTool: - @pytest.fixture - def tool_schema(self): - return { - "description": "Test Tool Description", - "parameters": [ - {"name": "param1", "type": "string", "description": "Param 1"}, - {"name": "param2", "type": "integer", "description": "Param 2"}, - ], - } - - @pytest.fixture - def auth_tool_schema(self): - return { - "description": "Test Tool Description", - "parameters": [ - { - "name": "param1", - "type": "string", - "description": "Param 1", - "authSources": ["test-auth-source"], - }, - {"name": "param2", "type": "integer", "description": "Param 2"}, - ], - } - - @pytest_asyncio.fixture - @patch("aiohttp.ClientSession") - async def toolbox_tool(self, MockClientSession, tool_schema): - mock_session = MockClientSession.return_value - mock_session.post.return_value.__aenter__.return_value.raise_for_status = Mock() - mock_session.post.return_value.__aenter__.return_value.json = AsyncMock( - return_value={"result": "test-result"} - ) - tool = AsyncToolboxTool( - name="test_tool", - schema=tool_schema, - url="http://test_url", - session=mock_session, - ) - return tool - - @pytest_asyncio.fixture - @patch("aiohttp.ClientSession") - async def auth_toolbox_tool(self, MockClientSession, auth_tool_schema): - mock_session = MockClientSession.return_value - mock_session.post.return_value.__aenter__.return_value.raise_for_status = Mock() - mock_session.post.return_value.__aenter__.return_value.json = AsyncMock( - return_value={"result": "test-result"} - ) - with pytest.warns( - UserWarning, - match=r"Parameter\(s\) `param1` of tool test_tool require authentication", - ): - tool = AsyncToolboxTool( - name="test_tool", - schema=auth_tool_schema, - url="https://test-url", - session=mock_session, - ) - return tool - - @patch("aiohttp.ClientSession") - async def test_toolbox_tool_init(self, MockClientSession, tool_schema): - mock_session = MockClientSession.return_value - tool = AsyncToolboxTool( - name="test_tool", - schema=tool_schema, - url="https://test-url", - session=mock_session, - ) - assert tool.name == "test_tool" - assert tool.description == "Test Tool Description" - - @pytest.mark.parametrize( - "params, expected_bound_params", - [ - ({"param1": "bound-value"}, {"param1": "bound-value"}), - ({"param1": lambda: "bound-value"}, {"param1": lambda: "bound-value"}), - ( - {"param1": "bound-value", "param2": 123}, - {"param1": "bound-value", "param2": 123}, - ), - ], - ) - async def test_toolbox_tool_bind_params( - self, toolbox_tool, params, expected_bound_params - ): - tool = toolbox_tool.bind_params(params) - for key, value in expected_bound_params.items(): - if callable(value): - assert value() == tool._AsyncToolboxTool__bound_params[key]() - else: - assert value == tool._AsyncToolboxTool__bound_params[key] - - @pytest.mark.parametrize("strict", [True, False]) - async def test_toolbox_tool_bind_params_invalid(self, toolbox_tool, strict): - if strict: - with pytest.raises(ValueError) as e: - tool = toolbox_tool.bind_params( - {"param3": "bound-value"}, strict=strict - ) - assert "Parameter(s) param3 missing and cannot be bound." in str(e.value) - else: - with pytest.warns(UserWarning) as record: - tool = toolbox_tool.bind_params( - {"param3": "bound-value"}, strict=strict - ) - assert len(record) == 1 - assert "Parameter(s) param3 missing and cannot be bound." in str( - record[0].message - ) - - async def test_toolbox_tool_bind_params_duplicate(self, toolbox_tool): - tool = toolbox_tool.bind_params({"param1": "bound-value"}) - with pytest.raises(ValueError) as e: - tool = tool.bind_params({"param1": "bound-value"}) - assert "Parameter(s) `param1` already bound in tool `test_tool`." in str( - e.value - ) - - async def test_toolbox_tool_bind_params_invalid_params(self, auth_toolbox_tool): - with pytest.raises(ValueError) as e: - auth_toolbox_tool.bind_params({"param1": "bound-value"}) - assert "Parameter(s) param1 already authenticated and cannot be bound." in str( - e.value - ) - - @pytest.mark.parametrize( - "auth_token_getters, expected_auth_token_getters", - [ - ( - {"test-auth-source": lambda: "test-token"}, - {"test-auth-source": lambda: "test-token"}, - ), - ( - { - "test-auth-source": lambda: "test-token", - "another-auth-source": lambda: "another-token", - }, - { - "test-auth-source": lambda: "test-token", - "another-auth-source": lambda: "another-token", - }, - ), - ], - ) - async def test_toolbox_tool_add_auth_token_getters( - self, auth_toolbox_tool, auth_token_getters, expected_auth_token_getters - ): - tool = auth_toolbox_tool.add_auth_token_getters(auth_token_getters) - for source, getter in expected_auth_token_getters.items(): - assert tool._AsyncToolboxTool__auth_token_getters[source]() == getter() - - async def test_toolbox_tool_add_auth_token_getters_duplicate( - self, auth_toolbox_tool - ): - tool = auth_toolbox_tool.add_auth_token_getters( - {"test-auth-source": lambda: "test-token"} - ) - with pytest.raises(ValueError) as e: - tool = tool.add_auth_token_getters( - {"test-auth-source": lambda: "test-token"} - ) - assert ( - "Authentication source(s) `test-auth-source` already registered in tool `test_tool`." - in str(e.value) - ) - - async def test_toolbox_tool_validate_auth_strict(self, auth_toolbox_tool): - with pytest.raises(PermissionError) as e: - auth_toolbox_tool._AsyncToolboxTool__validate_auth(strict=True) - assert "Parameter(s) `param1` of tool test_tool require authentication" in str( - e.value - ) - - async def test_toolbox_tool_call(self, toolbox_tool): - result = await toolbox_tool.ainvoke({"param1": "test-value", "param2": 123}) - assert result == "test-result" - toolbox_tool._AsyncToolboxTool__session.post.assert_called_once_with( - "http://test_url/api/tool/test_tool/invoke", - json={"param1": "test-value", "param2": 123}, - headers={}, - ) - - @pytest.mark.parametrize( - "bound_param, expected_value", - [ - ({"param1": "bound-value"}, "bound-value"), - ({"param1": lambda: "dynamic-value"}, "dynamic-value"), - ], - ) - async def test_toolbox_tool_call_with_bound_params( - self, toolbox_tool, bound_param, expected_value - ): - tool = toolbox_tool.bind_params(bound_param) - result = await tool.ainvoke({"param2": 123}) - assert result == "test-result" - toolbox_tool._AsyncToolboxTool__session.post.assert_called_once_with( - "http://test_url/api/tool/test_tool/invoke", - json={"param1": expected_value, "param2": 123}, - headers={}, - ) - - async def test_toolbox_tool_call_with_auth_tokens(self, auth_toolbox_tool): - tool = auth_toolbox_tool.add_auth_token_getters( - {"test-auth-source": lambda: "test-token"} - ) - result = await tool.ainvoke({"param2": 123}) - assert result == "test-result" - auth_toolbox_tool._AsyncToolboxTool__session.post.assert_called_once_with( - "https://test-url/api/tool/test_tool/invoke", - json={"param2": 123}, - headers={"test-auth-source_token": "test-token"}, - ) - - async def test_toolbox_tool_call_with_auth_tokens_insecure(self, auth_toolbox_tool): - with pytest.warns( - UserWarning, - match="Sending ID token over HTTP. User data may be exposed. Use HTTPS for secure communication.", - ): - auth_toolbox_tool._AsyncToolboxTool__url = "http://test-url" - tool = auth_toolbox_tool.add_auth_token_getters( - {"test-auth-source": lambda: "test-token"} - ) - result = await tool.ainvoke({"param2": 123}) - assert result == "test-result" - auth_toolbox_tool._AsyncToolboxTool__session.post.assert_called_once_with( - "http://test-url/api/tool/test_tool/invoke", - json={"param2": 123}, - headers={"test-auth-source_token": "test-token"}, - ) - - async def test_toolbox_tool_call_with_invalid_input(self, toolbox_tool): - with pytest.raises(ValidationError) as e: - await toolbox_tool.ainvoke({"param1": 123, "param2": "invalid"}) - assert "2 validation errors for test_tool" in str(e.value) - assert "param1\n Input should be a valid string" in str(e.value) - assert "param2\n Input should be a valid integer" in str(e.value) - - async def test_toolbox_tool_call_with_empty_input(self, toolbox_tool): - with pytest.raises(ValidationError) as e: - await toolbox_tool.ainvoke({}) - assert "2 validation errors for test_tool" in str(e.value) - assert "param1\n Field required" in str(e.value) - assert "param2\n Field required" in str(e.value) - - async def test_toolbox_tool_run_not_implemented(self, toolbox_tool): - with pytest.raises(NotImplementedError): - toolbox_tool._run() diff --git a/packages/toolbox-langchain/tests/test_client.py b/packages/toolbox-langchain/tests/test_client.py deleted file mode 100644 index 62999019..00000000 --- a/packages/toolbox-langchain/tests/test_client.py +++ /dev/null @@ -1,259 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from unittest.mock import Mock, patch - -import pytest -from pydantic import BaseModel - -from toolbox_langchain.client import ToolboxClient -from toolbox_langchain.tools import ToolboxTool - -URL = "http://test_url" - - -class TestToolboxClient: - @pytest.fixture() - def toolbox_client(self): - client = ToolboxClient(URL) - assert isinstance(client, ToolboxClient) - assert client._ToolboxClient__async_client is not None - - # Check that the background loop was created and started - assert client._ToolboxClient__loop is not None - assert client._ToolboxClient__loop.is_running() - - return client - - @patch("toolbox_langchain.client.AsyncToolboxClient.aload_tool") - def test_load_tool(self, mock_aload_tool, toolbox_client): - mock_tool = Mock(spec=ToolboxTool) - mock_tool.name = "mock-tool" - mock_tool.description = "mock description" - mock_tool.args_schema = BaseModel - mock_aload_tool.return_value = mock_tool - - tool = toolbox_client.load_tool("test_tool") - - assert tool.name == mock_tool.name - assert tool.description == mock_tool.description - assert tool.args_schema == mock_tool.args_schema - mock_aload_tool.assert_called_once_with("test_tool", {}, None, None, {}, True) - - @patch("toolbox_langchain.client.AsyncToolboxClient.aload_toolset") - def test_load_toolset(self, mock_aload_toolset, toolbox_client): - mock_tools = [Mock(spec=ToolboxTool), Mock(spec=ToolboxTool)] - mock_tools[0].name = "mock-tool-0" - mock_tools[0].description = "mock description 0" - mock_tools[0].args_schema = BaseModel - mock_tools[1].name = "mock-tool-1" - mock_tools[1].description = "mock description 1" - mock_tools[1].args_schema = BaseModel - mock_aload_toolset.return_value = mock_tools - - tools = toolbox_client.load_toolset() - assert len(tools) == len(mock_tools) - assert all( - a.name == b.name - and a.description == b.description - and a.args_schema == b.args_schema - for a, b in zip(tools, mock_tools) - ) - mock_aload_toolset.assert_called_once_with(None, {}, None, None, {}, True) - - @pytest.mark.asyncio - @patch("toolbox_langchain.client.AsyncToolboxClient.aload_tool") - async def test_aload_tool(self, mock_aload_tool, toolbox_client): - mock_tool = Mock(spec=ToolboxTool) - mock_tool.name = "mock-tool" - mock_tool.description = "mock description" - mock_tool.args_schema = BaseModel - mock_aload_tool.return_value = mock_tool - - tool = await toolbox_client.aload_tool("test_tool") - assert tool.name == mock_tool.name - assert tool.description == mock_tool.description - assert tool.args_schema == mock_tool.args_schema - mock_aload_tool.assert_called_once_with("test_tool", {}, None, None, {}, True) - - @pytest.mark.asyncio - @patch("toolbox_langchain.client.AsyncToolboxClient.aload_toolset") - async def test_aload_toolset(self, mock_aload_toolset, toolbox_client): - mock_tools = [Mock(spec=ToolboxTool), Mock(spec=ToolboxTool)] - mock_tools[0].name = "mock-tool-0" - mock_tools[0].description = "mock description 0" - mock_tools[0].args_schema = BaseModel - mock_tools[1].name = "mock-tool-1" - mock_tools[1].description = "mock description 1" - mock_tools[1].args_schema = BaseModel - mock_aload_toolset.return_value = mock_tools - - tools = await toolbox_client.aload_toolset() - assert len(tools) == len(mock_tools) - assert all( - a.name == b.name - and a.description == b.description - and a.args_schema == b.args_schema - for a, b in zip(tools, mock_tools) - ) - mock_aload_toolset.assert_called_once_with(None, {}, None, None, {}, True) - - @patch("toolbox_langchain.client.AsyncToolboxClient.aload_tool") - def test_load_tool_with_args(self, mock_aload_tool, toolbox_client): - mock_tool = Mock(spec=ToolboxTool) - mock_tool.name = "mock-tool" - mock_tool.description = "mock description" - mock_tool.args_schema = BaseModel - mock_aload_tool.return_value = mock_tool - auth_token_getters = {"token_getter1": lambda: "value1"} - auth_tokens = {"token1": lambda: "value2"} - auth_headers = {"header1": lambda: "value3"} - bound_params = {"param1": "value4"} - - tool = toolbox_client.load_tool( - "test_tool_name", - auth_token_getters=auth_token_getters, - auth_tokens=auth_tokens, - auth_headers=auth_headers, - bound_params=bound_params, - strict=False, - ) - - assert tool.name == mock_tool.name - assert tool.description == mock_tool.description - assert tool.args_schema == mock_tool.args_schema - mock_aload_tool.assert_called_once_with( - "test_tool_name", - auth_token_getters, - auth_tokens, - auth_headers, - bound_params, - False, - ) - - @patch("toolbox_langchain.client.AsyncToolboxClient.aload_toolset") - def test_load_toolset_with_args(self, mock_aload_toolset, toolbox_client): - mock_tools = [Mock(spec=ToolboxTool), Mock(spec=ToolboxTool)] - mock_tools[0].name = "mock-tool-0" - mock_tools[0].description = "mock description 0" - mock_tools[0].args_schema = BaseModel - mock_tools[1].name = "mock-tool-1" - mock_tools[1].description = "mock description 1" - mock_tools[1].args_schema = BaseModel - mock_aload_toolset.return_value = mock_tools - - auth_token_getters = {"token_getter1": lambda: "value1"} - auth_tokens = {"token1": lambda: "value2"} - auth_headers = {"header1": lambda: "value3"} - bound_params = {"param1": "value4"} - - tools = toolbox_client.load_toolset( - toolset_name="my_toolset", - auth_token_getters=auth_token_getters, - auth_tokens=auth_tokens, - auth_headers=auth_headers, - bound_params=bound_params, - strict=False, - ) - - assert len(tools) == len(mock_tools) - assert all( - a.name == b.name - and a.description == b.description - and a.args_schema == b.args_schema - for a, b in zip(tools, mock_tools) - ) - mock_aload_toolset.assert_called_once_with( - "my_toolset", - auth_token_getters, - auth_tokens, - auth_headers, - bound_params, - False, - ) - - @pytest.mark.asyncio - @patch("toolbox_langchain.client.AsyncToolboxClient.aload_tool") - async def test_aload_tool_with_args(self, mock_aload_tool, toolbox_client): - mock_tool = Mock(spec=ToolboxTool) - mock_tool.name = "mock-tool" - mock_tool.description = "mock description" - mock_tool.args_schema = BaseModel - mock_aload_tool.return_value = mock_tool - - auth_token_getters = {"token_getter1": lambda: "value1"} - auth_tokens = {"token1": lambda: "value2"} - auth_headers = {"header1": lambda: "value3"} - bound_params = {"param1": "value4"} - - tool = await toolbox_client.aload_tool( - "test_tool", - auth_token_getters, - auth_tokens, - auth_headers, - bound_params, - False, - ) - assert tool.name == mock_tool.name - assert tool.description == mock_tool.description - assert tool.args_schema == mock_tool.args_schema - mock_aload_tool.assert_called_once_with( - "test_tool", - auth_token_getters, - auth_tokens, - auth_headers, - bound_params, - False, - ) - - @pytest.mark.asyncio - @patch("toolbox_langchain.client.AsyncToolboxClient.aload_toolset") - async def test_aload_toolset_with_args(self, mock_aload_toolset, toolbox_client): - mock_tools = [Mock(spec=ToolboxTool), Mock(spec=ToolboxTool)] - mock_tools[0].name = "mock-tool-0" - mock_tools[0].description = "mock description 0" - mock_tools[0].args_schema = BaseModel - mock_tools[1].name = "mock-tool-1" - mock_tools[1].description = "mock description 1" - mock_tools[1].args_schema = BaseModel - mock_aload_toolset.return_value = mock_tools - - auth_token_getters = {"token_getter1": lambda: "value1"} - auth_tokens = {"token1": lambda: "value2"} - auth_headers = {"header1": lambda: "value3"} - bound_params = {"param1": "value4"} - - tools = await toolbox_client.aload_toolset( - "my_toolset", - auth_token_getters, - auth_tokens, - auth_headers, - bound_params, - False, - ) - assert len(tools) == len(mock_tools) - assert all( - a.name == b.name - and a.description == b.description - and a.args_schema == b.args_schema - for a, b in zip(tools, mock_tools) - ) - mock_aload_toolset.assert_called_once_with( - "my_toolset", - auth_token_getters, - auth_tokens, - auth_headers, - bound_params, - False, - ) diff --git a/packages/toolbox-langchain/tests/test_tools.py b/packages/toolbox-langchain/tests/test_tools.py deleted file mode 100644 index 751005af..00000000 --- a/packages/toolbox-langchain/tests/test_tools.py +++ /dev/null @@ -1,238 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from unittest.mock import Mock - -import pytest -from pydantic import BaseModel - -from toolbox_langchain.async_tools import AsyncToolboxTool -from toolbox_langchain.tools import ToolboxTool - - -class TestToolboxTool: - @pytest.fixture - def tool_schema(self): - return { - "description": "Test Tool Description", - "name": "test_tool", - "parameters": [ - {"name": "param1", "type": "string", "description": "Param 1"}, - {"name": "param2", "type": "integer", "description": "Param 2"}, - ], - } - - @pytest.fixture - def auth_tool_schema(self): - return { - "description": "Test Tool Description", - "name": "test_tool", - "parameters": [ - { - "name": "param1", - "type": "string", - "description": "Param 1", - "authSources": ["test-auth-source"], - }, - {"name": "param2", "type": "integer", "description": "Param 2"}, - ], - } - - @pytest.fixture(scope="function") - def mock_async_tool(self, tool_schema): - mock_async_tool = Mock(spec=AsyncToolboxTool) - mock_async_tool.name = "test_tool" - mock_async_tool.description = "test description" - mock_async_tool.args_schema = BaseModel - mock_async_tool._AsyncToolboxTool__name = "test_tool" - mock_async_tool._AsyncToolboxTool__schema = tool_schema - mock_async_tool._AsyncToolboxTool__url = "http://test_url" - mock_async_tool._AsyncToolboxTool__session = Mock() - mock_async_tool._AsyncToolboxTool__auth_token_getters = {} - mock_async_tool._AsyncToolboxTool__bound_params = {} - return mock_async_tool - - @pytest.fixture(scope="function") - def mock_async_auth_tool(self, auth_tool_schema): - mock_async_tool = Mock(spec=AsyncToolboxTool) - mock_async_tool.name = "test_tool" - mock_async_tool.description = "test description" - mock_async_tool.args_schema = BaseModel - mock_async_tool._AsyncToolboxTool__name = "test_tool" - mock_async_tool._AsyncToolboxTool__schema = auth_tool_schema - mock_async_tool._AsyncToolboxTool__url = "http://test_url" - mock_async_tool._AsyncToolboxTool__session = Mock() - mock_async_tool._AsyncToolboxTool__auth_token_getters = {} - mock_async_tool._AsyncToolboxTool__bound_params = {} - return mock_async_tool - - @pytest.fixture - def toolbox_tool(self, mock_async_tool): - return ToolboxTool( - async_tool=mock_async_tool, - loop=Mock(), - thread=Mock(), - ) - - @pytest.fixture - def auth_toolbox_tool(self, mock_async_auth_tool): - return ToolboxTool( - async_tool=mock_async_auth_tool, - loop=Mock(), - thread=Mock(), - ) - - def test_toolbox_tool_init(self, mock_async_tool): - tool = ToolboxTool( - async_tool=mock_async_tool, - loop=Mock(), - thread=Mock(), - ) - async_tool = tool._ToolboxTool__async_tool - assert async_tool.name == mock_async_tool.name - assert async_tool.description == mock_async_tool.description - assert async_tool.args_schema == mock_async_tool.args_schema - - @pytest.mark.parametrize( - "params, expected_bound_params", - [ - ({"param1": "bound-value"}, {"param1": "bound-value"}), - ({"param1": lambda: "bound-value"}, {"param1": lambda: "bound-value"}), - ( - {"param1": "bound-value", "param2": 123}, - {"param1": "bound-value", "param2": 123}, - ), - ], - ) - def test_toolbox_tool_bind_params( - self, - params, - expected_bound_params, - toolbox_tool, - mock_async_tool, - ): - mock_async_tool._AsyncToolboxTool__bound_params = expected_bound_params - mock_async_tool.bind_params.return_value = mock_async_tool - - tool = toolbox_tool.bind_params(params) - mock_async_tool.bind_params.assert_called_once_with(params, True) - assert isinstance(tool, ToolboxTool) - - for key, value in expected_bound_params.items(): - async_tool_bound_param_val = ( - tool._ToolboxTool__async_tool._AsyncToolboxTool__bound_params[key] - ) - if callable(value): - assert value() == async_tool_bound_param_val() - else: - assert value == async_tool_bound_param_val - - def test_toolbox_tool_bind_param(self, mock_async_tool, toolbox_tool): - expected_bound_param = {"param1": "bound-value"} - mock_async_tool._AsyncToolboxTool__bound_params = expected_bound_param - mock_async_tool.bind_param.return_value = mock_async_tool - - tool = toolbox_tool.bind_param("param1", "bound-value") - mock_async_tool.bind_param.assert_called_once_with( - "param1", "bound-value", True - ) - - assert ( - tool._ToolboxTool__async_tool._AsyncToolboxTool__bound_params - == expected_bound_param - ) - assert isinstance(tool, ToolboxTool) - - @pytest.mark.parametrize( - "auth_token_getters, expected_auth_token_getters", - [ - ( - {"test-auth-source": lambda: "test-token"}, - {"test-auth-source": lambda: "test-token"}, - ), - ( - { - "test-auth-source": lambda: "test-token", - "another-auth-source": lambda: "another-token", - }, - { - "test-auth-source": lambda: "test-token", - "another-auth-source": lambda: "another-token", - }, - ), - ], - ) - def test_toolbox_tool_add_auth_token_getters( - self, - auth_token_getters, - expected_auth_token_getters, - mock_async_auth_tool, - auth_toolbox_tool, - ): - auth_toolbox_tool._ToolboxTool__async_tool._AsyncToolboxTool__auth_token_getters = ( - expected_auth_token_getters - ) - auth_toolbox_tool._ToolboxTool__async_tool.add_auth_token_getters.return_value = ( - mock_async_auth_tool - ) - - tool = auth_toolbox_tool.add_auth_token_getters(auth_token_getters) - mock_async_auth_tool.add_auth_token_getters.assert_called_once_with( - auth_token_getters, True - ) - for source, getter in expected_auth_token_getters.items(): - assert ( - tool._ToolboxTool__async_tool._AsyncToolboxTool__auth_token_getters[ - source - ]() - == getter() - ) - assert isinstance(tool, ToolboxTool) - - def test_toolbox_tool_add_auth_token_getter( - self, mock_async_auth_tool, auth_toolbox_tool - ): - get_id_token = lambda: "test-token" - expected_auth_token_getters = {"test-auth-source": get_id_token} - auth_toolbox_tool._ToolboxTool__async_tool._AsyncToolboxTool__auth_token_getters = ( - expected_auth_token_getters - ) - auth_toolbox_tool._ToolboxTool__async_tool.add_auth_token_getter.return_value = ( - mock_async_auth_tool - ) - - tool = auth_toolbox_tool.add_auth_token_getter("test-auth-source", get_id_token) - mock_async_auth_tool.add_auth_token_getter.assert_called_once_with( - "test-auth-source", get_id_token, True - ) - - assert ( - tool._ToolboxTool__async_tool._AsyncToolboxTool__auth_token_getters[ - "test-auth-source" - ]() - == "test-token" - ) - assert isinstance(tool, ToolboxTool) - - def test_toolbox_tool_validate_auth_strict(self, auth_toolbox_tool): - auth_toolbox_tool._ToolboxTool__async_tool._arun = Mock( - side_effect=PermissionError( - "Parameter(s) `param1` of tool test_tool require authentication" - ) - ) - with pytest.raises(PermissionError) as e: - auth_toolbox_tool._run() - assert "Parameter(s) `param1` of tool test_tool require authentication" in str( - e.value - ) From c92e5dd401e26b06378f84cdc26bef8f93815e27 Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Thu, 8 May 2025 20:38:18 +0530 Subject: [PATCH 13/53] chore: Remove unused strict flag + fix default values + fix docstring --- .../toolbox-core/src/toolbox_core/sync_client.py | 1 - .../src/toolbox_langchain/async_client.py | 13 +++++++------ .../src/toolbox_langchain/tools.py | 3 +-- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/packages/toolbox-core/src/toolbox_core/sync_client.py b/packages/toolbox-core/src/toolbox_core/sync_client.py index 96d4bba7..7ec94a44 100644 --- a/packages/toolbox-core/src/toolbox_core/sync_client.py +++ b/packages/toolbox-core/src/toolbox_core/sync_client.py @@ -56,7 +56,6 @@ def __init__( async def create_client(): return ToolboxClient(url, client_headers=client_headers) - # Ignoring type since we're already checking the existence of a loop above. self.__async_client = asyncio.run_coroutine_threadsafe( create_client(), self.__class__.__loop ).result() diff --git a/packages/toolbox-langchain/src/toolbox_langchain/async_client.py b/packages/toolbox-langchain/src/toolbox_langchain/async_client.py index 2e1053a3..0b423f7a 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/async_client.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/async_client.py @@ -103,7 +103,7 @@ async def aload_toolset( auth_tokens: Optional[dict[str, Callable[[], str]]] = None, auth_headers: Optional[dict[str, Callable[[], str]]] = None, bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, - strict: bool = True, + strict: bool = False, ) -> list[AsyncToolboxTool]: """ Loads tools from the Toolbox service, optionally filtered by toolset @@ -118,9 +118,11 @@ async def aload_toolset( auth_headers: Deprecated. Use `auth_token_getters` instead. bound_params: An optional mapping of parameter names to their bound values. - strict: If True, raises a ValueError if any of the given bound - parameters are missing from the schema or require - authentication. If False, only issues a warning. + strict: If True, raises an error if *any* loaded tool instance fails + to utilize at least one provided parameter or auth token (if any + provided). If False (default), raises an error only if a + user-provided parameter or auth token cannot be applied to *any* + loaded tool across the set. Returns: A list of all tools loaded from the Toolbox. @@ -170,7 +172,6 @@ def load_tool( auth_tokens: Optional[dict[str, Callable[[], str]]] = None, auth_headers: Optional[dict[str, Callable[[], str]]] = None, bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, - strict: bool = True, ) -> AsyncToolboxTool: raise NotImplementedError("Synchronous methods not supported by async client.") @@ -181,6 +182,6 @@ def load_toolset( auth_tokens: Optional[dict[str, Callable[[], str]]] = None, auth_headers: Optional[dict[str, Callable[[], str]]] = None, bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, - strict: bool = True, + strict: bool = False, ) -> list[AsyncToolboxTool]: raise NotImplementedError("Synchronous methods not supported by async client.") diff --git a/packages/toolbox-langchain/src/toolbox_langchain/tools.py b/packages/toolbox-langchain/src/toolbox_langchain/tools.py index fb3d6ef0..21eb630a 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/tools.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/tools.py @@ -63,7 +63,7 @@ async def _arun(self, **kwargs: Any) -> dict[str, Any]: def add_auth_token_getters( - self, auth_token_getters: dict[str, Callable[[], str]], strict: bool = True + self, auth_token_getters: dict[str, Callable[[], str]] ) -> "ToolboxTool": """ Registers functions to retrieve ID tokens for the corresponding @@ -131,7 +131,6 @@ def bind_param( self, param_name: str, param_value: Union[Any, Callable[[], Any]], - strict: bool = True, ) -> "ToolboxTool": """ Registers a value or a function to retrieve the value for a given bound From d8fd735e32cafd57abd074fa3fa778c17de7d2b8 Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Fri, 9 May 2025 10:30:17 +0530 Subject: [PATCH 14/53] fix: Update package to be from git repo --- packages/toolbox-langchain/pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/packages/toolbox-langchain/pyproject.toml b/packages/toolbox-langchain/pyproject.toml index a66987c8..503e216a 100644 --- a/packages/toolbox-langchain/pyproject.toml +++ b/packages/toolbox-langchain/pyproject.toml @@ -9,7 +9,8 @@ authors = [ {name = "Google LLC", email = "googleapis-packages@google.com"} ] dependencies = [ - "toolbox-core>=0.1.0,<1.0.0", + # "toolbox-core>=0.1.0,<1.0.0", + "toolbox-core=git+https://github.com/googleapis/mcp-toolbox-sdk-python.git#egg=toolbox-core&subdirectory=packages/toolbox-core "langchain-core>=0.2.23,<1.0.0", "PyYAML>=6.0.1,<7.0.0", "pydantic>=2.7.0,<3.0.0", From 595eebed03c997a0e94f9fa175f5c066afa2d0e6 Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Fri, 9 May 2025 15:39:30 +0530 Subject: [PATCH 15/53] fix: Fix toolbox-core package local path --- packages/toolbox-langchain/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/toolbox-langchain/pyproject.toml b/packages/toolbox-langchain/pyproject.toml index 503e216a..53bf79fb 100644 --- a/packages/toolbox-langchain/pyproject.toml +++ b/packages/toolbox-langchain/pyproject.toml @@ -10,7 +10,7 @@ authors = [ ] dependencies = [ # "toolbox-core>=0.1.0,<1.0.0", - "toolbox-core=git+https://github.com/googleapis/mcp-toolbox-sdk-python.git#egg=toolbox-core&subdirectory=packages/toolbox-core + "toolbox-core @ file:../toolbox-core", "langchain-core>=0.2.23,<1.0.0", "PyYAML>=6.0.1,<7.0.0", "pydantic>=2.7.0,<3.0.0", From 5f587d41391d6fb36abcf8b7b0b6f84e3e01b0a6 Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Fri, 9 May 2025 15:50:09 +0530 Subject: [PATCH 16/53] fix: Fix local package path --- packages/toolbox-langchain/pyproject.toml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/packages/toolbox-langchain/pyproject.toml b/packages/toolbox-langchain/pyproject.toml index 53bf79fb..9800af52 100644 --- a/packages/toolbox-langchain/pyproject.toml +++ b/packages/toolbox-langchain/pyproject.toml @@ -9,8 +9,7 @@ authors = [ {name = "Google LLC", email = "googleapis-packages@google.com"} ] dependencies = [ - # "toolbox-core>=0.1.0,<1.0.0", - "toolbox-core @ file:../toolbox-core", + "toolbox-core @ file:./packages/toolbox-core", "langchain-core>=0.2.23,<1.0.0", "PyYAML>=6.0.1,<7.0.0", "pydantic>=2.7.0,<3.0.0", From 6f14d52b4f069e6218ea1811e647f7d777637f32 Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Fri, 9 May 2025 21:36:49 +0530 Subject: [PATCH 17/53] fix: Update git path --- packages/toolbox-langchain/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/toolbox-langchain/pyproject.toml b/packages/toolbox-langchain/pyproject.toml index 9800af52..5cf99cc6 100644 --- a/packages/toolbox-langchain/pyproject.toml +++ b/packages/toolbox-langchain/pyproject.toml @@ -9,7 +9,7 @@ authors = [ {name = "Google LLC", email = "googleapis-packages@google.com"} ] dependencies = [ - "toolbox-core @ file:./packages/toolbox-core", + "toolbox-core @ git+https://github.com/googleapis/mcp-toolbox-sdk-python.git@anubhav-lc-wraps-core#subdirectory=packages/toolbox-core", "langchain-core>=0.2.23,<1.0.0", "PyYAML>=6.0.1,<7.0.0", "pydantic>=2.7.0,<3.0.0", From 9d1ae03dd1b297f0c27b2c0be05ae18f1570abd2 Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Fri, 9 May 2025 22:03:22 +0530 Subject: [PATCH 18/53] fix: Fix tests --- packages/toolbox-langchain/tests/test_e2e.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/toolbox-langchain/tests/test_e2e.py b/packages/toolbox-langchain/tests/test_e2e.py index 689d8c40..6bb8e827 100644 --- a/packages/toolbox-langchain/tests/test_e2e.py +++ b/packages/toolbox-langchain/tests/test_e2e.py @@ -71,7 +71,7 @@ async def test_aload_toolset_specific( toolset = await toolbox.aload_toolset(toolset_name) assert len(toolset) == expected_length for tool in toolset: - name = tool._ToolboxTool__async_tool._AsyncToolboxTool__name + name = tool._ToolboxTool__core_sync_tool.__name__ assert name in expected_tools async def test_aload_toolset_all(self, toolbox): @@ -85,7 +85,7 @@ async def test_aload_toolset_all(self, toolbox): "get-row-by-content-auth", ] for tool in toolset: - name = tool._ToolboxTool__async_tool._AsyncToolboxTool__name + name = tool._ToolboxTool__core_sync_tool.__name__ assert name in tool_names async def test_run_tool_async(self, get_n_rows_tool): From dc3a4ff0c8069b8af986aca33385f3f3a2bd8c90 Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Fri, 9 May 2025 23:37:44 +0530 Subject: [PATCH 19/53] fix: Fix using correct object for fetching loop --- packages/toolbox-langchain/src/toolbox_langchain/tools.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/toolbox-langchain/src/toolbox_langchain/tools.py b/packages/toolbox-langchain/src/toolbox_langchain/tools.py index 21eb630a..8f49263d 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/tools.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/tools.py @@ -53,12 +53,12 @@ async def _arun(self, **kwargs: Any) -> dict[str, Any]: coro = self.__core_sync_tool._ToolboxSyncTool__async_tool(**kwargs) # If a loop has not been provided, attempt to run in current thread. - if not self.__core_sync_client._ToolboxSyncClient__loop: + if not self.__core_sync_tool._ToolboxSyncTool__loop: return await coro # Otherwise, run in the background thread. await asyncio.wrap_future( - asyncio.run_coroutine_threadsafe(coro, self.__core_sync_client._ToolboxSyncTool__loop) + asyncio.run_coroutine_threadsafe(coro, self.__core_sync_tool._ToolboxSyncTool__loop) ) From c8e61540361d0cd5aae6ca23416adc2b480601c4 Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Fri, 9 May 2025 23:43:23 +0530 Subject: [PATCH 20/53] fix: Return invoke result --- packages/toolbox-langchain/src/toolbox_langchain/tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/toolbox-langchain/src/toolbox_langchain/tools.py b/packages/toolbox-langchain/src/toolbox_langchain/tools.py index 8f49263d..b947edeb 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/tools.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/tools.py @@ -57,7 +57,7 @@ async def _arun(self, **kwargs: Any) -> dict[str, Any]: return await coro # Otherwise, run in the background thread. - await asyncio.wrap_future( + return await asyncio.wrap_future( asyncio.run_coroutine_threadsafe(coro, self.__core_sync_tool._ToolboxSyncTool__loop) ) From b05884010c6c89a786e27c7dd016fc31feb663aa Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Sat, 10 May 2025 00:07:20 +0530 Subject: [PATCH 21/53] fix: Integration test errors --- packages/toolbox-langchain/tests/test_e2e.py | 49 +++++++++++--------- 1 file changed, 26 insertions(+), 23 deletions(-) diff --git a/packages/toolbox-langchain/tests/test_e2e.py b/packages/toolbox-langchain/tests/test_e2e.py index 6bb8e827..eed872e4 100644 --- a/packages/toolbox-langchain/tests/test_e2e.py +++ b/packages/toolbox-langchain/tests/test_e2e.py @@ -114,11 +114,13 @@ async def test_run_tool_wrong_param_type(self, get_n_rows_tool): @pytest.mark.asyncio async def test_run_tool_unauth_with_auth(self, toolbox, auth_token2): """Tests running a tool that doesn't require auth, with auth provided.""" - tool = await toolbox.aload_tool( - "get-row-by-id", auth_token_getters={"my-test-auth": lambda: auth_token2} - ) - response = await tool.ainvoke({"id": "2"}) - assert "row2" in response + with pytest.raises( + ValueError, + match="Validation failed for tool 'get-row-by-id': unused auth tokens: my-test-auth.", + ): + await toolbox.aload_tool( + "get-row-by-id", auth_token_getters={"my-test-auth": lambda: auth_token2} + ) async def test_run_tool_no_auth(self, toolbox): """Tests running a tool requiring auth without providing auth.""" @@ -127,7 +129,7 @@ async def test_run_tool_no_auth(self, toolbox): ) with pytest.raises( PermissionError, - match="Tool get-row-by-id-auth requires authentication, but no valid authentication sources are registered. Please register the required sources before use.", + match="One or more of the following authn services are required to invoke this tool: my-test-auth", ): await tool.ainvoke({"id": "2"}) @@ -138,8 +140,8 @@ async def test_run_tool_wrong_auth(self, toolbox, auth_token2): ) auth_tool = tool.add_auth_token_getter("my-test-auth", lambda: auth_token2) with pytest.raises( - ToolException, - match="{'status': 'Unauthorized', 'error': 'tool invocation not authorized. Please make sure your specify correct auth headers'}", + Exception, + match="tool invocation not authorized. Please make sure your specify correct auth headers", ): await auth_tool.ainvoke({"id": "2"}) @@ -157,7 +159,7 @@ async def test_run_tool_param_auth_no_auth(self, toolbox): tool = await toolbox.aload_tool("get-row-by-email-auth") with pytest.raises( PermissionError, - match="Parameter\(s\) `email` of tool get-row-by-email-auth require authentication\, but no valid authentication sources are registered\. Please register the required sources before use\.", + match="One or more of the following authn services are required to invoke this tool: my-test-auth", ): await tool.ainvoke({"email": ""}) @@ -179,12 +181,11 @@ async def test_run_tool_param_auth_no_field(self, toolbox, auth_token1): auth_token_getters={"my-test-auth": lambda: auth_token1}, ) with pytest.raises( - ToolException, - match="{'status': 'Bad Request', 'error': 'provided parameters were invalid: error parsing authenticated parameter \"data\": no field named row_data in claims'}", + Exception, + match="provided parameters were invalid: error parsing authenticated parameter \"data\": no field named row_data in claims" ): await tool.ainvoke({}) - @pytest.mark.usefixtures("toolbox_server") class TestE2EClientSync: @pytest.fixture(scope="session") @@ -213,7 +214,7 @@ def test_load_toolset_specific( toolset = toolbox.load_toolset(toolset_name) assert len(toolset) == expected_length for tool in toolset: - name = tool._ToolboxTool__async_tool._AsyncToolboxTool__name + name = tool._ToolboxTool__core_sync_tool.__name__ assert name in expected_tools def test_aload_toolset_all(self, toolbox): @@ -227,7 +228,7 @@ def test_aload_toolset_all(self, toolbox): "get-row-by-content-auth", ] for tool in toolset: - name = tool._ToolboxTool__async_tool._AsyncToolboxTool__name + name = tool._ToolboxTool__core_sync_tool.__name__ assert name in tool_names @pytest.mark.asyncio @@ -256,11 +257,13 @@ def test_run_tool_wrong_param_type(self, get_n_rows_tool): #### Auth tests def test_run_tool_unauth_with_auth(self, toolbox, auth_token2): """Tests running a tool that doesn't require auth, with auth provided.""" - tool = toolbox.load_tool( - "get-row-by-id", auth_token_getters={"my-test-auth": lambda: auth_token2} - ) - response = tool.invoke({"id": "2"}) - assert "row2" in response + with pytest.raises( + ValueError, + match="Validation failed for tool 'get-row-by-id': unused auth tokens: my-test-auth.", + ): + toolbox.load_tool( + "get-row-by-id", auth_token_getters={"my-test-auth": lambda: auth_token2} + ) def test_run_tool_no_auth(self, toolbox): """Tests running a tool requiring auth without providing auth.""" @@ -269,7 +272,7 @@ def test_run_tool_no_auth(self, toolbox): ) with pytest.raises( PermissionError, - match="Tool get-row-by-id-auth requires authentication, but no valid authentication sources are registered. Please register the required sources before use.", + match="One or more of the following authn services are required to invoke this tool: my-test-auth", ): tool.invoke({"id": "2"}) @@ -281,7 +284,7 @@ def test_run_tool_wrong_auth(self, toolbox, auth_token2): auth_tool = tool.add_auth_token_getter("my-test-auth", lambda: auth_token2) with pytest.raises( ToolException, - match="{'status': 'Unauthorized', 'error': 'tool invocation not authorized. Please make sure your specify correct auth headers'}", + match="tool invocation not authorized. Please make sure your specify correct auth headers", ): auth_tool.invoke({"id": "2"}) @@ -299,7 +302,7 @@ def test_run_tool_param_auth_no_auth(self, toolbox): tool = toolbox.load_tool("get-row-by-email-auth") with pytest.raises( PermissionError, - match="Parameter\(s\) `email` of tool get-row-by-email-auth require authentication\, but no valid authentication sources are registered\. Please register the required sources before use\.", + match="One or more of the following authn services are required to invoke this tool: my-test-auth", ): tool.invoke({"email": ""}) @@ -322,6 +325,6 @@ def test_run_tool_param_auth_no_field(self, toolbox, auth_token1): ) with pytest.raises( ToolException, - match="{'status': 'Bad Request', 'error': 'provided parameters were invalid: error parsing authenticated parameter \"data\": no field named row_data in claims'}", + match="provided parameters were invalid: error parsing authenticated parameter \"data\": no field named row_data in claims", ): tool.invoke({}) From b83b8e34d739b5d5ce3a995c959fc67995e7d5f7 Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Sat, 10 May 2025 00:08:23 +0530 Subject: [PATCH 22/53] chore: Delint --- .../src/toolbox_langchain/async_client.py | 6 ++-- .../src/toolbox_langchain/async_tools.py | 3 -- .../src/toolbox_langchain/client.py | 32 +++++++++++++------ .../src/toolbox_langchain/tools.py | 11 ++++--- packages/toolbox-langchain/tests/test_e2e.py | 11 ++++--- 5 files changed, 38 insertions(+), 25 deletions(-) diff --git a/packages/toolbox-langchain/src/toolbox_langchain/async_client.py b/packages/toolbox-langchain/src/toolbox_langchain/async_client.py index 0b423f7a..95e384c8 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/async_client.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/async_client.py @@ -16,9 +16,9 @@ from warnings import warn from aiohttp import ClientSession +from toolbox_core.client import ToolboxClient as ToolboxCoreClient from .async_tools import AsyncToolboxTool -from toolbox_core.client import ToolboxClient as ToolboxCoreClient # This class is an internal implementation detail and is not exposed to the @@ -92,7 +92,7 @@ async def aload_tool( core_tool = await self.__core_client.load_tool( name=tool_name, auth_token_getters=auth_token_getters, - bound_params=bound_params + bound_params=bound_params, ) return AsyncToolboxTool(core_tool=core_tool) @@ -157,7 +157,7 @@ async def aload_toolset( name=toolset_name, auth_token_getters=auth_token_getters, bound_params=bound_params, - strict=strict + strict=strict, ) tools = [] diff --git a/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py b/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py index f2f26433..d3a9955c 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py @@ -18,7 +18,6 @@ from toolbox_core.tool import ToolboxTool as ToolboxCoreTool - # This class is an internal implementation detail and is not exposed to the # end-user. It should not be used directly by external code. Changes to this # class will not be considered breaking changes to the public API. @@ -64,8 +63,6 @@ async def _arun(self, **kwargs: Any) -> str: """ return await self.__core_tool(**kwargs) - - def add_auth_token_getters( self, auth_token_getters: dict[str, Callable[[], str]] ) -> "AsyncToolboxTool": diff --git a/packages/toolbox-langchain/src/toolbox_langchain/client.py b/packages/toolbox-langchain/src/toolbox_langchain/client.py index 646abbba..f36bdc7a 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/client.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/client.py @@ -13,13 +13,13 @@ # limitations under the License. import asyncio -from warnings import warn from typing import Any, Callable, Optional, Union +from warnings import warn -from .tools import ToolboxTool from toolbox_core.sync_client import ToolboxSyncClient as ToolboxCoreSyncClient from toolbox_core.sync_tool import ToolboxSyncTool +from .tools import ToolboxTool class ToolboxClient: @@ -88,7 +88,7 @@ async def aload_tool( coro = self.__core_sync_client._ToolboxSyncClient__async_client.load_tool( name=tool_name, auth_token_getters=auth_token_getters, - bound_params=bound_params + bound_params=bound_params, ) if not self.__core_sync_client._ToolboxSyncClient__loop: @@ -97,10 +97,16 @@ async def aload_tool( else: # Otherwise, run in the background thread. core_tool = await asyncio.wrap_future( - asyncio.run_coroutine_threadsafe(coro, self.__core_sync_client._ToolboxSyncClient__loop) + asyncio.run_coroutine_threadsafe( + coro, self.__core_sync_client._ToolboxSyncClient__loop + ) ) - core_sync_tool = ToolboxSyncTool(core_tool, self.__core_sync_client._ToolboxSyncClient__loop, self.__core_sync_client._ToolboxSyncClient__thread) + core_sync_tool = ToolboxSyncTool( + core_tool, + self.__core_sync_client._ToolboxSyncClient__loop, + self.__core_sync_client._ToolboxSyncClient__thread, + ) return ToolboxTool(core_sync_tool=core_sync_tool) async def aload_toolset( @@ -164,7 +170,7 @@ async def aload_toolset( name=toolset_name, auth_token_getters=auth_token_getters, bound_params=bound_params, - strict=strict + strict=strict, ) if not self.__core_sync_client._ToolboxSyncClient__loop: @@ -173,11 +179,17 @@ async def aload_toolset( else: # Otherwise, run in the background thread. core_tools = await asyncio.wrap_future( - asyncio.run_coroutine_threadsafe(coro, self.__core_sync_client._ToolboxSyncClient__loop) + asyncio.run_coroutine_threadsafe( + coro, self.__core_sync_client._ToolboxSyncClient__loop + ) ) core_sync_tools = [ - ToolboxSyncTool(core_tool, self.__core_sync_client._ToolboxSyncClient__loop, self.__core_sync_client._ToolboxSyncClient__thread) + ToolboxSyncTool( + core_tool, + self.__core_sync_client._ToolboxSyncClient__loop, + self.__core_sync_client._ToolboxSyncClient__thread, + ) for core_tool in core_tools ] tools = [] @@ -237,7 +249,7 @@ def load_tool( core_sync_tool = self.__core_sync_client.load_tool( name=tool_name, auth_token_getters=auth_token_getters, - bound_params=bound_params + bound_params=bound_params, ) return ToolboxTool(core_sync_tool=core_sync_tool) @@ -302,7 +314,7 @@ def load_toolset( name=toolset_name, auth_token_getters=auth_token_getters, bound_params=bound_params, - strict=strict + strict=strict, ) tools = [] diff --git a/packages/toolbox-langchain/src/toolbox_langchain/tools.py b/packages/toolbox-langchain/src/toolbox_langchain/tools.py index b947edeb..505ff04c 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/tools.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/tools.py @@ -19,7 +19,6 @@ from toolbox_core.sync_tool import ToolboxSyncTool as ToolboxCoreSyncTool - class ToolboxTool(BaseTool): """ A subclass of LangChain's BaseTool that supports features specific to @@ -58,10 +57,11 @@ async def _arun(self, **kwargs: Any) -> dict[str, Any]: # Otherwise, run in the background thread. return await asyncio.wrap_future( - asyncio.run_coroutine_threadsafe(coro, self.__core_sync_tool._ToolboxSyncTool__loop) + asyncio.run_coroutine_threadsafe( + coro, self.__core_sync_tool._ToolboxSyncTool__loop + ) ) - def add_auth_token_getters( self, auth_token_getters: dict[str, Callable[[], str]] ) -> "ToolboxTool": @@ -81,10 +81,11 @@ def add_auth_token_getters( ValueError: If any of the provided auth parameters is already registered. """ - new_core_sync_tool = self.__core_sync_tool.add_auth_token_getters(auth_token_getters) + new_core_sync_tool = self.__core_sync_tool.add_auth_token_getters( + auth_token_getters + ) return ToolboxTool(core_sync_tool=new_core_sync_tool) - def add_auth_token_getter( self, auth_source: str, get_id_token: Callable[[], str] ) -> "ToolboxTool": diff --git a/packages/toolbox-langchain/tests/test_e2e.py b/packages/toolbox-langchain/tests/test_e2e.py index eed872e4..78606928 100644 --- a/packages/toolbox-langchain/tests/test_e2e.py +++ b/packages/toolbox-langchain/tests/test_e2e.py @@ -119,7 +119,8 @@ async def test_run_tool_unauth_with_auth(self, toolbox, auth_token2): match="Validation failed for tool 'get-row-by-id': unused auth tokens: my-test-auth.", ): await toolbox.aload_tool( - "get-row-by-id", auth_token_getters={"my-test-auth": lambda: auth_token2} + "get-row-by-id", + auth_token_getters={"my-test-auth": lambda: auth_token2}, ) async def test_run_tool_no_auth(self, toolbox): @@ -182,10 +183,11 @@ async def test_run_tool_param_auth_no_field(self, toolbox, auth_token1): ) with pytest.raises( Exception, - match="provided parameters were invalid: error parsing authenticated parameter \"data\": no field named row_data in claims" + match='provided parameters were invalid: error parsing authenticated parameter "data": no field named row_data in claims', ): await tool.ainvoke({}) + @pytest.mark.usefixtures("toolbox_server") class TestE2EClientSync: @pytest.fixture(scope="session") @@ -262,7 +264,8 @@ def test_run_tool_unauth_with_auth(self, toolbox, auth_token2): match="Validation failed for tool 'get-row-by-id': unused auth tokens: my-test-auth.", ): toolbox.load_tool( - "get-row-by-id", auth_token_getters={"my-test-auth": lambda: auth_token2} + "get-row-by-id", + auth_token_getters={"my-test-auth": lambda: auth_token2}, ) def test_run_tool_no_auth(self, toolbox): @@ -325,6 +328,6 @@ def test_run_tool_param_auth_no_field(self, toolbox, auth_token1): ) with pytest.raises( ToolException, - match="provided parameters were invalid: error parsing authenticated parameter \"data\": no field named row_data in claims", + match='provided parameters were invalid: error parsing authenticated parameter "data": no field named row_data in claims', ): tool.invoke({}) From b771dc4ee4e20ec4d1d97ee91d6e66f6f2d80b85 Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Sat, 10 May 2025 00:31:52 +0530 Subject: [PATCH 23/53] fix: Fix integration test --- packages/toolbox-langchain/tests/test_e2e.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/toolbox-langchain/tests/test_e2e.py b/packages/toolbox-langchain/tests/test_e2e.py index 78606928..8792efdd 100644 --- a/packages/toolbox-langchain/tests/test_e2e.py +++ b/packages/toolbox-langchain/tests/test_e2e.py @@ -286,7 +286,7 @@ def test_run_tool_wrong_auth(self, toolbox, auth_token2): ) auth_tool = tool.add_auth_token_getter("my-test-auth", lambda: auth_token2) with pytest.raises( - ToolException, + Exception, match="tool invocation not authorized. Please make sure your specify correct auth headers", ): auth_tool.invoke({"id": "2"}) From c2a4e64613c71b1bf6041c7d815d19c867e69ece Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Sat, 10 May 2025 00:34:16 +0530 Subject: [PATCH 24/53] fix: Fix integration tests --- packages/toolbox-langchain/tests/test_e2e.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/packages/toolbox-langchain/tests/test_e2e.py b/packages/toolbox-langchain/tests/test_e2e.py index 8792efdd..7c9b417f 100644 --- a/packages/toolbox-langchain/tests/test_e2e.py +++ b/packages/toolbox-langchain/tests/test_e2e.py @@ -36,7 +36,6 @@ import pytest import pytest_asyncio -from langchain_core.tools import ToolException from pydantic import ValidationError from toolbox_langchain.client import ToolboxClient @@ -327,7 +326,7 @@ def test_run_tool_param_auth_no_field(self, toolbox, auth_token1): auth_token_getters={"my-test-auth": lambda: auth_token1}, ) with pytest.raises( - ToolException, + Exception, match='provided parameters were invalid: error parsing authenticated parameter "data": no field named row_data in claims', ): tool.invoke({}) From 2d917a03c0b8bf5aaf8b62df2fa5f0c8ea992859 Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Sat, 10 May 2025 00:37:18 +0530 Subject: [PATCH 25/53] chore: Add unit tests previously deleted --- .../tests/test_async_client.py | 193 ++++++++++++ .../tests/test_async_tools.py | 274 ++++++++++++++++++ .../toolbox-langchain/tests/test_client.py | 259 +++++++++++++++++ .../toolbox-langchain/tests/test_tools.py | 238 +++++++++++++++ 4 files changed, 964 insertions(+) create mode 100644 packages/toolbox-langchain/tests/test_async_client.py create mode 100644 packages/toolbox-langchain/tests/test_async_tools.py create mode 100644 packages/toolbox-langchain/tests/test_client.py create mode 100644 packages/toolbox-langchain/tests/test_tools.py diff --git a/packages/toolbox-langchain/tests/test_async_client.py b/packages/toolbox-langchain/tests/test_async_client.py new file mode 100644 index 00000000..7b3d38c9 --- /dev/null +++ b/packages/toolbox-langchain/tests/test_async_client.py @@ -0,0 +1,193 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import AsyncMock, patch +from warnings import catch_warnings, simplefilter + +import pytest +from aiohttp import ClientSession + +from toolbox_langchain.async_client import AsyncToolboxClient +from toolbox_langchain.async_tools import AsyncToolboxTool +from toolbox_core.protocol import ManifestSchema + +URL = "http://test_url" +MANIFEST_JSON = { + "serverVersion": "1.0.0", + "tools": { + "test_tool_1": { + "description": "Test Tool 1 Description", + "parameters": [ + { + "name": "param1", + "type": "string", + "description": "Param 1", + } + ], + }, + "test_tool_2": { + "description": "Test Tool 2 Description", + "parameters": [ + { + "name": "param2", + "type": "integer", + "description": "Param 2", + } + ], + }, + }, +} + + +@pytest.mark.asyncio +class TestAsyncToolboxClient: + @pytest.fixture() + def manifest_schema(self): + return ManifestSchema(**MANIFEST_JSON) + + @pytest.fixture() + def mock_session(self): + return AsyncMock(spec=ClientSession) + + @pytest.fixture() + def mock_client(self, mock_session): + return AsyncToolboxClient(URL, session=mock_session) + + async def test_create_with_existing_session(self, mock_client, mock_session): + assert mock_client._AsyncToolboxClient__session == mock_session + + @patch("toolbox_langchain.async_client._load_manifest") + async def test_aload_tool( + self, mock_load_manifest, mock_client, mock_session, manifest_schema + ): + tool_name = "test_tool_1" + mock_load_manifest.return_value = manifest_schema + + tool = await mock_client.aload_tool(tool_name) + + mock_load_manifest.assert_called_once_with( + f"{URL}/api/tool/{tool_name}", mock_session + ) + assert isinstance(tool, AsyncToolboxTool) + assert tool.name == tool_name + + @patch("toolbox_langchain.async_client._load_manifest") + async def test_aload_tool_auth_headers_deprecated( + self, mock_load_manifest, mock_client, manifest_schema + ): + tool_name = "test_tool_1" + mock_manifest = manifest_schema + mock_load_manifest.return_value = mock_manifest + with catch_warnings(record=True) as w: + simplefilter("always") + await mock_client.aload_tool( + tool_name, auth_headers={"Authorization": lambda: "Bearer token"} + ) + assert len(w) == 1 + assert issubclass(w[-1].category, DeprecationWarning) + assert "auth_headers" in str(w[-1].message) + + @patch("toolbox_langchain.async_client._load_manifest") + async def test_aload_tool_auth_headers_and_tokens( + self, mock_load_manifest, mock_client, manifest_schema + ): + tool_name = "test_tool_1" + mock_manifest = manifest_schema + mock_load_manifest.return_value = mock_manifest + with catch_warnings(record=True) as w: + simplefilter("always") + await mock_client.aload_tool( + tool_name, + auth_headers={"Authorization": lambda: "Bearer token"}, + auth_token_getters={"test": lambda: "token"}, + ) + assert len(w) == 1 + assert issubclass(w[-1].category, DeprecationWarning) + assert "auth_headers" in str(w[-1].message) + + @patch("toolbox_langchain.async_client._load_manifest") + async def test_aload_toolset( + self, mock_load_manifest, mock_client, mock_session, manifest_schema + ): + mock_manifest = manifest_schema + mock_load_manifest.return_value = mock_manifest + tools = await mock_client.aload_toolset() + + mock_load_manifest.assert_called_once_with(f"{URL}/api/toolset/", mock_session) + assert len(tools) == 2 + for tool in tools: + assert isinstance(tool, AsyncToolboxTool) + assert tool.name in ["test_tool_1", "test_tool_2"] + + @patch("toolbox_langchain.async_client._load_manifest") + async def test_aload_toolset_with_toolset_name( + self, mock_load_manifest, mock_client, mock_session, manifest_schema + ): + toolset_name = "test_toolset_1" + mock_manifest = manifest_schema + mock_load_manifest.return_value = mock_manifest + tools = await mock_client.aload_toolset(toolset_name=toolset_name) + + mock_load_manifest.assert_called_once_with( + f"{URL}/api/toolset/{toolset_name}", mock_session + ) + assert len(tools) == 2 + for tool in tools: + assert isinstance(tool, AsyncToolboxTool) + assert tool.name in ["test_tool_1", "test_tool_2"] + + @patch("toolbox_langchain.async_client._load_manifest") + async def test_aload_toolset_auth_headers_deprecated( + self, mock_load_manifest, mock_client, manifest_schema + ): + mock_manifest = manifest_schema + mock_load_manifest.return_value = mock_manifest + with catch_warnings(record=True) as w: + simplefilter("always") + await mock_client.aload_toolset( + auth_headers={"Authorization": lambda: "Bearer token"} + ) + assert len(w) == 1 + assert issubclass(w[-1].category, DeprecationWarning) + assert "auth_headers" in str(w[-1].message) + + @patch("toolbox_langchain.async_client._load_manifest") + async def test_aload_toolset_auth_headers_and_tokens( + self, mock_load_manifest, mock_client, manifest_schema + ): + mock_manifest = manifest_schema + mock_load_manifest.return_value = mock_manifest + with catch_warnings(record=True) as w: + simplefilter("always") + await mock_client.aload_toolset( + auth_headers={"Authorization": lambda: "Bearer token"}, + auth_token_getters={"test": lambda: "token"}, + ) + assert len(w) == 1 + assert issubclass(w[-1].category, DeprecationWarning) + assert "auth_headers" in str(w[-1].message) + + async def test_load_tool_not_implemented(self, mock_client): + with pytest.raises(NotImplementedError) as excinfo: + mock_client.load_tool("test_tool") + assert "Synchronous methods not supported by async client." in str( + excinfo.value + ) + + async def test_load_toolset_not_implemented(self, mock_client): + with pytest.raises(NotImplementedError) as excinfo: + mock_client.load_toolset() + assert "Synchronous methods not supported by async client." in str( + excinfo.value + ) diff --git a/packages/toolbox-langchain/tests/test_async_tools.py b/packages/toolbox-langchain/tests/test_async_tools.py new file mode 100644 index 00000000..e23aee85 --- /dev/null +++ b/packages/toolbox-langchain/tests/test_async_tools.py @@ -0,0 +1,274 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import AsyncMock, Mock, patch + +import pytest +import pytest_asyncio +from pydantic import ValidationError + +from toolbox_langchain.async_tools import AsyncToolboxTool + + +@pytest.mark.asyncio +class TestAsyncToolboxTool: + @pytest.fixture + def tool_schema(self): + return { + "description": "Test Tool Description", + "parameters": [ + {"name": "param1", "type": "string", "description": "Param 1"}, + {"name": "param2", "type": "integer", "description": "Param 2"}, + ], + } + + @pytest.fixture + def auth_tool_schema(self): + return { + "description": "Test Tool Description", + "parameters": [ + { + "name": "param1", + "type": "string", + "description": "Param 1", + "authSources": ["test-auth-source"], + }, + {"name": "param2", "type": "integer", "description": "Param 2"}, + ], + } + + @pytest_asyncio.fixture + @patch("aiohttp.ClientSession") + async def toolbox_tool(self, MockClientSession, tool_schema): + mock_session = MockClientSession.return_value + mock_session.post.return_value.__aenter__.return_value.raise_for_status = Mock() + mock_session.post.return_value.__aenter__.return_value.json = AsyncMock( + return_value={"result": "test-result"} + ) + tool = AsyncToolboxTool( + name="test_tool", + schema=tool_schema, + url="http://test_url", + session=mock_session, + ) + return tool + + @pytest_asyncio.fixture + @patch("aiohttp.ClientSession") + async def auth_toolbox_tool(self, MockClientSession, auth_tool_schema): + mock_session = MockClientSession.return_value + mock_session.post.return_value.__aenter__.return_value.raise_for_status = Mock() + mock_session.post.return_value.__aenter__.return_value.json = AsyncMock( + return_value={"result": "test-result"} + ) + with pytest.warns( + UserWarning, + match=r"Parameter\(s\) `param1` of tool test_tool require authentication", + ): + tool = AsyncToolboxTool( + name="test_tool", + schema=auth_tool_schema, + url="https://test-url", + session=mock_session, + ) + return tool + + @patch("aiohttp.ClientSession") + async def test_toolbox_tool_init(self, MockClientSession, tool_schema): + mock_session = MockClientSession.return_value + tool = AsyncToolboxTool( + name="test_tool", + schema=tool_schema, + url="https://test-url", + session=mock_session, + ) + assert tool.name == "test_tool" + assert tool.description == "Test Tool Description" + + @pytest.mark.parametrize( + "params, expected_bound_params", + [ + ({"param1": "bound-value"}, {"param1": "bound-value"}), + ({"param1": lambda: "bound-value"}, {"param1": lambda: "bound-value"}), + ( + {"param1": "bound-value", "param2": 123}, + {"param1": "bound-value", "param2": 123}, + ), + ], + ) + async def test_toolbox_tool_bind_params( + self, toolbox_tool, params, expected_bound_params + ): + tool = toolbox_tool.bind_params(params) + for key, value in expected_bound_params.items(): + if callable(value): + assert value() == tool._AsyncToolboxTool__bound_params[key]() + else: + assert value == tool._AsyncToolboxTool__bound_params[key] + + @pytest.mark.parametrize("strict", [True, False]) + async def test_toolbox_tool_bind_params_invalid(self, toolbox_tool, strict): + if strict: + with pytest.raises(ValueError) as e: + tool = toolbox_tool.bind_params( + {"param3": "bound-value"}, strict=strict + ) + assert "Parameter(s) param3 missing and cannot be bound." in str(e.value) + else: + with pytest.warns(UserWarning) as record: + tool = toolbox_tool.bind_params( + {"param3": "bound-value"}, strict=strict + ) + assert len(record) == 1 + assert "Parameter(s) param3 missing and cannot be bound." in str( + record[0].message + ) + + async def test_toolbox_tool_bind_params_duplicate(self, toolbox_tool): + tool = toolbox_tool.bind_params({"param1": "bound-value"}) + with pytest.raises(ValueError) as e: + tool = tool.bind_params({"param1": "bound-value"}) + assert "Parameter(s) `param1` already bound in tool `test_tool`." in str( + e.value + ) + + async def test_toolbox_tool_bind_params_invalid_params(self, auth_toolbox_tool): + with pytest.raises(ValueError) as e: + auth_toolbox_tool.bind_params({"param1": "bound-value"}) + assert "Parameter(s) param1 already authenticated and cannot be bound." in str( + e.value + ) + + @pytest.mark.parametrize( + "auth_token_getters, expected_auth_token_getters", + [ + ( + {"test-auth-source": lambda: "test-token"}, + {"test-auth-source": lambda: "test-token"}, + ), + ( + { + "test-auth-source": lambda: "test-token", + "another-auth-source": lambda: "another-token", + }, + { + "test-auth-source": lambda: "test-token", + "another-auth-source": lambda: "another-token", + }, + ), + ], + ) + async def test_toolbox_tool_add_auth_token_getters( + self, auth_toolbox_tool, auth_token_getters, expected_auth_token_getters + ): + tool = auth_toolbox_tool.add_auth_token_getters(auth_token_getters) + for source, getter in expected_auth_token_getters.items(): + assert tool._AsyncToolboxTool__auth_token_getters[source]() == getter() + + async def test_toolbox_tool_add_auth_token_getters_duplicate( + self, auth_toolbox_tool + ): + tool = auth_toolbox_tool.add_auth_token_getters( + {"test-auth-source": lambda: "test-token"} + ) + with pytest.raises(ValueError) as e: + tool = tool.add_auth_token_getters( + {"test-auth-source": lambda: "test-token"} + ) + assert ( + "Authentication source(s) `test-auth-source` already registered in tool `test_tool`." + in str(e.value) + ) + + async def test_toolbox_tool_validate_auth_strict(self, auth_toolbox_tool): + with pytest.raises(PermissionError) as e: + auth_toolbox_tool._AsyncToolboxTool__validate_auth(strict=True) + assert "Parameter(s) `param1` of tool test_tool require authentication" in str( + e.value + ) + + async def test_toolbox_tool_call(self, toolbox_tool): + result = await toolbox_tool.ainvoke({"param1": "test-value", "param2": 123}) + assert result == "test-result" + toolbox_tool._AsyncToolboxTool__session.post.assert_called_once_with( + "http://test_url/api/tool/test_tool/invoke", + json={"param1": "test-value", "param2": 123}, + headers={}, + ) + + @pytest.mark.parametrize( + "bound_param, expected_value", + [ + ({"param1": "bound-value"}, "bound-value"), + ({"param1": lambda: "dynamic-value"}, "dynamic-value"), + ], + ) + async def test_toolbox_tool_call_with_bound_params( + self, toolbox_tool, bound_param, expected_value + ): + tool = toolbox_tool.bind_params(bound_param) + result = await tool.ainvoke({"param2": 123}) + assert result == "test-result" + toolbox_tool._AsyncToolboxTool__session.post.assert_called_once_with( + "http://test_url/api/tool/test_tool/invoke", + json={"param1": expected_value, "param2": 123}, + headers={}, + ) + + async def test_toolbox_tool_call_with_auth_tokens(self, auth_toolbox_tool): + tool = auth_toolbox_tool.add_auth_token_getters( + {"test-auth-source": lambda: "test-token"} + ) + result = await tool.ainvoke({"param2": 123}) + assert result == "test-result" + auth_toolbox_tool._AsyncToolboxTool__session.post.assert_called_once_with( + "https://test-url/api/tool/test_tool/invoke", + json={"param2": 123}, + headers={"test-auth-source_token": "test-token"}, + ) + + async def test_toolbox_tool_call_with_auth_tokens_insecure(self, auth_toolbox_tool): + with pytest.warns( + UserWarning, + match="Sending ID token over HTTP. User data may be exposed. Use HTTPS for secure communication.", + ): + auth_toolbox_tool._AsyncToolboxTool__url = "http://test-url" + tool = auth_toolbox_tool.add_auth_token_getters( + {"test-auth-source": lambda: "test-token"} + ) + result = await tool.ainvoke({"param2": 123}) + assert result == "test-result" + auth_toolbox_tool._AsyncToolboxTool__session.post.assert_called_once_with( + "http://test-url/api/tool/test_tool/invoke", + json={"param2": 123}, + headers={"test-auth-source_token": "test-token"}, + ) + + async def test_toolbox_tool_call_with_invalid_input(self, toolbox_tool): + with pytest.raises(ValidationError) as e: + await toolbox_tool.ainvoke({"param1": 123, "param2": "invalid"}) + assert "2 validation errors for test_tool" in str(e.value) + assert "param1\n Input should be a valid string" in str(e.value) + assert "param2\n Input should be a valid integer" in str(e.value) + + async def test_toolbox_tool_call_with_empty_input(self, toolbox_tool): + with pytest.raises(ValidationError) as e: + await toolbox_tool.ainvoke({}) + assert "2 validation errors for test_tool" in str(e.value) + assert "param1\n Field required" in str(e.value) + assert "param2\n Field required" in str(e.value) + + async def test_toolbox_tool_run_not_implemented(self, toolbox_tool): + with pytest.raises(NotImplementedError): + toolbox_tool._run() diff --git a/packages/toolbox-langchain/tests/test_client.py b/packages/toolbox-langchain/tests/test_client.py new file mode 100644 index 00000000..62999019 --- /dev/null +++ b/packages/toolbox-langchain/tests/test_client.py @@ -0,0 +1,259 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import Mock, patch + +import pytest +from pydantic import BaseModel + +from toolbox_langchain.client import ToolboxClient +from toolbox_langchain.tools import ToolboxTool + +URL = "http://test_url" + + +class TestToolboxClient: + @pytest.fixture() + def toolbox_client(self): + client = ToolboxClient(URL) + assert isinstance(client, ToolboxClient) + assert client._ToolboxClient__async_client is not None + + # Check that the background loop was created and started + assert client._ToolboxClient__loop is not None + assert client._ToolboxClient__loop.is_running() + + return client + + @patch("toolbox_langchain.client.AsyncToolboxClient.aload_tool") + def test_load_tool(self, mock_aload_tool, toolbox_client): + mock_tool = Mock(spec=ToolboxTool) + mock_tool.name = "mock-tool" + mock_tool.description = "mock description" + mock_tool.args_schema = BaseModel + mock_aload_tool.return_value = mock_tool + + tool = toolbox_client.load_tool("test_tool") + + assert tool.name == mock_tool.name + assert tool.description == mock_tool.description + assert tool.args_schema == mock_tool.args_schema + mock_aload_tool.assert_called_once_with("test_tool", {}, None, None, {}, True) + + @patch("toolbox_langchain.client.AsyncToolboxClient.aload_toolset") + def test_load_toolset(self, mock_aload_toolset, toolbox_client): + mock_tools = [Mock(spec=ToolboxTool), Mock(spec=ToolboxTool)] + mock_tools[0].name = "mock-tool-0" + mock_tools[0].description = "mock description 0" + mock_tools[0].args_schema = BaseModel + mock_tools[1].name = "mock-tool-1" + mock_tools[1].description = "mock description 1" + mock_tools[1].args_schema = BaseModel + mock_aload_toolset.return_value = mock_tools + + tools = toolbox_client.load_toolset() + assert len(tools) == len(mock_tools) + assert all( + a.name == b.name + and a.description == b.description + and a.args_schema == b.args_schema + for a, b in zip(tools, mock_tools) + ) + mock_aload_toolset.assert_called_once_with(None, {}, None, None, {}, True) + + @pytest.mark.asyncio + @patch("toolbox_langchain.client.AsyncToolboxClient.aload_tool") + async def test_aload_tool(self, mock_aload_tool, toolbox_client): + mock_tool = Mock(spec=ToolboxTool) + mock_tool.name = "mock-tool" + mock_tool.description = "mock description" + mock_tool.args_schema = BaseModel + mock_aload_tool.return_value = mock_tool + + tool = await toolbox_client.aload_tool("test_tool") + assert tool.name == mock_tool.name + assert tool.description == mock_tool.description + assert tool.args_schema == mock_tool.args_schema + mock_aload_tool.assert_called_once_with("test_tool", {}, None, None, {}, True) + + @pytest.mark.asyncio + @patch("toolbox_langchain.client.AsyncToolboxClient.aload_toolset") + async def test_aload_toolset(self, mock_aload_toolset, toolbox_client): + mock_tools = [Mock(spec=ToolboxTool), Mock(spec=ToolboxTool)] + mock_tools[0].name = "mock-tool-0" + mock_tools[0].description = "mock description 0" + mock_tools[0].args_schema = BaseModel + mock_tools[1].name = "mock-tool-1" + mock_tools[1].description = "mock description 1" + mock_tools[1].args_schema = BaseModel + mock_aload_toolset.return_value = mock_tools + + tools = await toolbox_client.aload_toolset() + assert len(tools) == len(mock_tools) + assert all( + a.name == b.name + and a.description == b.description + and a.args_schema == b.args_schema + for a, b in zip(tools, mock_tools) + ) + mock_aload_toolset.assert_called_once_with(None, {}, None, None, {}, True) + + @patch("toolbox_langchain.client.AsyncToolboxClient.aload_tool") + def test_load_tool_with_args(self, mock_aload_tool, toolbox_client): + mock_tool = Mock(spec=ToolboxTool) + mock_tool.name = "mock-tool" + mock_tool.description = "mock description" + mock_tool.args_schema = BaseModel + mock_aload_tool.return_value = mock_tool + auth_token_getters = {"token_getter1": lambda: "value1"} + auth_tokens = {"token1": lambda: "value2"} + auth_headers = {"header1": lambda: "value3"} + bound_params = {"param1": "value4"} + + tool = toolbox_client.load_tool( + "test_tool_name", + auth_token_getters=auth_token_getters, + auth_tokens=auth_tokens, + auth_headers=auth_headers, + bound_params=bound_params, + strict=False, + ) + + assert tool.name == mock_tool.name + assert tool.description == mock_tool.description + assert tool.args_schema == mock_tool.args_schema + mock_aload_tool.assert_called_once_with( + "test_tool_name", + auth_token_getters, + auth_tokens, + auth_headers, + bound_params, + False, + ) + + @patch("toolbox_langchain.client.AsyncToolboxClient.aload_toolset") + def test_load_toolset_with_args(self, mock_aload_toolset, toolbox_client): + mock_tools = [Mock(spec=ToolboxTool), Mock(spec=ToolboxTool)] + mock_tools[0].name = "mock-tool-0" + mock_tools[0].description = "mock description 0" + mock_tools[0].args_schema = BaseModel + mock_tools[1].name = "mock-tool-1" + mock_tools[1].description = "mock description 1" + mock_tools[1].args_schema = BaseModel + mock_aload_toolset.return_value = mock_tools + + auth_token_getters = {"token_getter1": lambda: "value1"} + auth_tokens = {"token1": lambda: "value2"} + auth_headers = {"header1": lambda: "value3"} + bound_params = {"param1": "value4"} + + tools = toolbox_client.load_toolset( + toolset_name="my_toolset", + auth_token_getters=auth_token_getters, + auth_tokens=auth_tokens, + auth_headers=auth_headers, + bound_params=bound_params, + strict=False, + ) + + assert len(tools) == len(mock_tools) + assert all( + a.name == b.name + and a.description == b.description + and a.args_schema == b.args_schema + for a, b in zip(tools, mock_tools) + ) + mock_aload_toolset.assert_called_once_with( + "my_toolset", + auth_token_getters, + auth_tokens, + auth_headers, + bound_params, + False, + ) + + @pytest.mark.asyncio + @patch("toolbox_langchain.client.AsyncToolboxClient.aload_tool") + async def test_aload_tool_with_args(self, mock_aload_tool, toolbox_client): + mock_tool = Mock(spec=ToolboxTool) + mock_tool.name = "mock-tool" + mock_tool.description = "mock description" + mock_tool.args_schema = BaseModel + mock_aload_tool.return_value = mock_tool + + auth_token_getters = {"token_getter1": lambda: "value1"} + auth_tokens = {"token1": lambda: "value2"} + auth_headers = {"header1": lambda: "value3"} + bound_params = {"param1": "value4"} + + tool = await toolbox_client.aload_tool( + "test_tool", + auth_token_getters, + auth_tokens, + auth_headers, + bound_params, + False, + ) + assert tool.name == mock_tool.name + assert tool.description == mock_tool.description + assert tool.args_schema == mock_tool.args_schema + mock_aload_tool.assert_called_once_with( + "test_tool", + auth_token_getters, + auth_tokens, + auth_headers, + bound_params, + False, + ) + + @pytest.mark.asyncio + @patch("toolbox_langchain.client.AsyncToolboxClient.aload_toolset") + async def test_aload_toolset_with_args(self, mock_aload_toolset, toolbox_client): + mock_tools = [Mock(spec=ToolboxTool), Mock(spec=ToolboxTool)] + mock_tools[0].name = "mock-tool-0" + mock_tools[0].description = "mock description 0" + mock_tools[0].args_schema = BaseModel + mock_tools[1].name = "mock-tool-1" + mock_tools[1].description = "mock description 1" + mock_tools[1].args_schema = BaseModel + mock_aload_toolset.return_value = mock_tools + + auth_token_getters = {"token_getter1": lambda: "value1"} + auth_tokens = {"token1": lambda: "value2"} + auth_headers = {"header1": lambda: "value3"} + bound_params = {"param1": "value4"} + + tools = await toolbox_client.aload_toolset( + "my_toolset", + auth_token_getters, + auth_tokens, + auth_headers, + bound_params, + False, + ) + assert len(tools) == len(mock_tools) + assert all( + a.name == b.name + and a.description == b.description + and a.args_schema == b.args_schema + for a, b in zip(tools, mock_tools) + ) + mock_aload_toolset.assert_called_once_with( + "my_toolset", + auth_token_getters, + auth_tokens, + auth_headers, + bound_params, + False, + ) diff --git a/packages/toolbox-langchain/tests/test_tools.py b/packages/toolbox-langchain/tests/test_tools.py new file mode 100644 index 00000000..751005af --- /dev/null +++ b/packages/toolbox-langchain/tests/test_tools.py @@ -0,0 +1,238 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import Mock + +import pytest +from pydantic import BaseModel + +from toolbox_langchain.async_tools import AsyncToolboxTool +from toolbox_langchain.tools import ToolboxTool + + +class TestToolboxTool: + @pytest.fixture + def tool_schema(self): + return { + "description": "Test Tool Description", + "name": "test_tool", + "parameters": [ + {"name": "param1", "type": "string", "description": "Param 1"}, + {"name": "param2", "type": "integer", "description": "Param 2"}, + ], + } + + @pytest.fixture + def auth_tool_schema(self): + return { + "description": "Test Tool Description", + "name": "test_tool", + "parameters": [ + { + "name": "param1", + "type": "string", + "description": "Param 1", + "authSources": ["test-auth-source"], + }, + {"name": "param2", "type": "integer", "description": "Param 2"}, + ], + } + + @pytest.fixture(scope="function") + def mock_async_tool(self, tool_schema): + mock_async_tool = Mock(spec=AsyncToolboxTool) + mock_async_tool.name = "test_tool" + mock_async_tool.description = "test description" + mock_async_tool.args_schema = BaseModel + mock_async_tool._AsyncToolboxTool__name = "test_tool" + mock_async_tool._AsyncToolboxTool__schema = tool_schema + mock_async_tool._AsyncToolboxTool__url = "http://test_url" + mock_async_tool._AsyncToolboxTool__session = Mock() + mock_async_tool._AsyncToolboxTool__auth_token_getters = {} + mock_async_tool._AsyncToolboxTool__bound_params = {} + return mock_async_tool + + @pytest.fixture(scope="function") + def mock_async_auth_tool(self, auth_tool_schema): + mock_async_tool = Mock(spec=AsyncToolboxTool) + mock_async_tool.name = "test_tool" + mock_async_tool.description = "test description" + mock_async_tool.args_schema = BaseModel + mock_async_tool._AsyncToolboxTool__name = "test_tool" + mock_async_tool._AsyncToolboxTool__schema = auth_tool_schema + mock_async_tool._AsyncToolboxTool__url = "http://test_url" + mock_async_tool._AsyncToolboxTool__session = Mock() + mock_async_tool._AsyncToolboxTool__auth_token_getters = {} + mock_async_tool._AsyncToolboxTool__bound_params = {} + return mock_async_tool + + @pytest.fixture + def toolbox_tool(self, mock_async_tool): + return ToolboxTool( + async_tool=mock_async_tool, + loop=Mock(), + thread=Mock(), + ) + + @pytest.fixture + def auth_toolbox_tool(self, mock_async_auth_tool): + return ToolboxTool( + async_tool=mock_async_auth_tool, + loop=Mock(), + thread=Mock(), + ) + + def test_toolbox_tool_init(self, mock_async_tool): + tool = ToolboxTool( + async_tool=mock_async_tool, + loop=Mock(), + thread=Mock(), + ) + async_tool = tool._ToolboxTool__async_tool + assert async_tool.name == mock_async_tool.name + assert async_tool.description == mock_async_tool.description + assert async_tool.args_schema == mock_async_tool.args_schema + + @pytest.mark.parametrize( + "params, expected_bound_params", + [ + ({"param1": "bound-value"}, {"param1": "bound-value"}), + ({"param1": lambda: "bound-value"}, {"param1": lambda: "bound-value"}), + ( + {"param1": "bound-value", "param2": 123}, + {"param1": "bound-value", "param2": 123}, + ), + ], + ) + def test_toolbox_tool_bind_params( + self, + params, + expected_bound_params, + toolbox_tool, + mock_async_tool, + ): + mock_async_tool._AsyncToolboxTool__bound_params = expected_bound_params + mock_async_tool.bind_params.return_value = mock_async_tool + + tool = toolbox_tool.bind_params(params) + mock_async_tool.bind_params.assert_called_once_with(params, True) + assert isinstance(tool, ToolboxTool) + + for key, value in expected_bound_params.items(): + async_tool_bound_param_val = ( + tool._ToolboxTool__async_tool._AsyncToolboxTool__bound_params[key] + ) + if callable(value): + assert value() == async_tool_bound_param_val() + else: + assert value == async_tool_bound_param_val + + def test_toolbox_tool_bind_param(self, mock_async_tool, toolbox_tool): + expected_bound_param = {"param1": "bound-value"} + mock_async_tool._AsyncToolboxTool__bound_params = expected_bound_param + mock_async_tool.bind_param.return_value = mock_async_tool + + tool = toolbox_tool.bind_param("param1", "bound-value") + mock_async_tool.bind_param.assert_called_once_with( + "param1", "bound-value", True + ) + + assert ( + tool._ToolboxTool__async_tool._AsyncToolboxTool__bound_params + == expected_bound_param + ) + assert isinstance(tool, ToolboxTool) + + @pytest.mark.parametrize( + "auth_token_getters, expected_auth_token_getters", + [ + ( + {"test-auth-source": lambda: "test-token"}, + {"test-auth-source": lambda: "test-token"}, + ), + ( + { + "test-auth-source": lambda: "test-token", + "another-auth-source": lambda: "another-token", + }, + { + "test-auth-source": lambda: "test-token", + "another-auth-source": lambda: "another-token", + }, + ), + ], + ) + def test_toolbox_tool_add_auth_token_getters( + self, + auth_token_getters, + expected_auth_token_getters, + mock_async_auth_tool, + auth_toolbox_tool, + ): + auth_toolbox_tool._ToolboxTool__async_tool._AsyncToolboxTool__auth_token_getters = ( + expected_auth_token_getters + ) + auth_toolbox_tool._ToolboxTool__async_tool.add_auth_token_getters.return_value = ( + mock_async_auth_tool + ) + + tool = auth_toolbox_tool.add_auth_token_getters(auth_token_getters) + mock_async_auth_tool.add_auth_token_getters.assert_called_once_with( + auth_token_getters, True + ) + for source, getter in expected_auth_token_getters.items(): + assert ( + tool._ToolboxTool__async_tool._AsyncToolboxTool__auth_token_getters[ + source + ]() + == getter() + ) + assert isinstance(tool, ToolboxTool) + + def test_toolbox_tool_add_auth_token_getter( + self, mock_async_auth_tool, auth_toolbox_tool + ): + get_id_token = lambda: "test-token" + expected_auth_token_getters = {"test-auth-source": get_id_token} + auth_toolbox_tool._ToolboxTool__async_tool._AsyncToolboxTool__auth_token_getters = ( + expected_auth_token_getters + ) + auth_toolbox_tool._ToolboxTool__async_tool.add_auth_token_getter.return_value = ( + mock_async_auth_tool + ) + + tool = auth_toolbox_tool.add_auth_token_getter("test-auth-source", get_id_token) + mock_async_auth_tool.add_auth_token_getter.assert_called_once_with( + "test-auth-source", get_id_token, True + ) + + assert ( + tool._ToolboxTool__async_tool._AsyncToolboxTool__auth_token_getters[ + "test-auth-source" + ]() + == "test-token" + ) + assert isinstance(tool, ToolboxTool) + + def test_toolbox_tool_validate_auth_strict(self, auth_toolbox_tool): + auth_toolbox_tool._ToolboxTool__async_tool._arun = Mock( + side_effect=PermissionError( + "Parameter(s) `param1` of tool test_tool require authentication" + ) + ) + with pytest.raises(PermissionError) as e: + auth_toolbox_tool._run() + assert "Parameter(s) `param1` of tool test_tool require authentication" in str( + e.value + ) From 395b2a19f20ef152b0f551ae73787a77937b767f Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Sat, 10 May 2025 01:00:52 +0530 Subject: [PATCH 26/53] chore: Delint --- packages/toolbox-langchain/tests/test_async_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/toolbox-langchain/tests/test_async_client.py b/packages/toolbox-langchain/tests/test_async_client.py index 7b3d38c9..015a5fa3 100644 --- a/packages/toolbox-langchain/tests/test_async_client.py +++ b/packages/toolbox-langchain/tests/test_async_client.py @@ -17,10 +17,10 @@ import pytest from aiohttp import ClientSession +from toolbox_core.protocol import ManifestSchema from toolbox_langchain.async_client import AsyncToolboxClient from toolbox_langchain.async_tools import AsyncToolboxTool -from toolbox_core.protocol import ManifestSchema URL = "http://test_url" MANIFEST_JSON = { From 6cc15d2ccdcbfa94ef692e2e45e47563c35df2f3 Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Sat, 10 May 2025 01:17:20 +0530 Subject: [PATCH 27/53] fix: Fix using correct protected member variables --- .../src/toolbox_langchain/client.py | 20 +++++++++---------- .../src/toolbox_langchain/tools.py | 8 ++++---- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/packages/toolbox-langchain/src/toolbox_langchain/client.py b/packages/toolbox-langchain/src/toolbox_langchain/client.py index f36bdc7a..018da726 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/client.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/client.py @@ -85,27 +85,27 @@ async def aload_tool( ) auth_token_getters = auth_tokens - coro = self.__core_sync_client._ToolboxSyncClient__async_client.load_tool( + coro = self.__core_sync_client._async_client.load_tool( name=tool_name, auth_token_getters=auth_token_getters, bound_params=bound_params, ) - if not self.__core_sync_client._ToolboxSyncClient__loop: + if not self.__core_sync_client._loop: # If a loop has not been provided, attempt to run in current thread. core_tool = await coro else: # Otherwise, run in the background thread. core_tool = await asyncio.wrap_future( asyncio.run_coroutine_threadsafe( - coro, self.__core_sync_client._ToolboxSyncClient__loop + coro, self.__core_sync_client._loop ) ) core_sync_tool = ToolboxSyncTool( core_tool, - self.__core_sync_client._ToolboxSyncClient__loop, - self.__core_sync_client._ToolboxSyncClient__thread, + self.__core_sync_client._loop, + self.__core_sync_client._thread, ) return ToolboxTool(core_sync_tool=core_sync_tool) @@ -166,29 +166,29 @@ async def aload_toolset( ) auth_token_getters = auth_tokens - coro = self.__core_sync_client._ToolboxSyncClient__async_client.load_toolset( + coro = self.__core_sync_client._async_client.load_toolset( name=toolset_name, auth_token_getters=auth_token_getters, bound_params=bound_params, strict=strict, ) - if not self.__core_sync_client._ToolboxSyncClient__loop: + if not self.__core_sync_client._loop: # If a loop has not been provided, attempt to run in current thread. core_tools = await coro else: # Otherwise, run in the background thread. core_tools = await asyncio.wrap_future( asyncio.run_coroutine_threadsafe( - coro, self.__core_sync_client._ToolboxSyncClient__loop + coro, self.__core_sync_client._loop ) ) core_sync_tools = [ ToolboxSyncTool( core_tool, - self.__core_sync_client._ToolboxSyncClient__loop, - self.__core_sync_client._ToolboxSyncClient__thread, + self.__core_sync_client._loop, + self.__core_sync_client._thread, ) for core_tool in core_tools ] diff --git a/packages/toolbox-langchain/src/toolbox_langchain/tools.py b/packages/toolbox-langchain/src/toolbox_langchain/tools.py index 505ff04c..b3c7b135 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/tools.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/tools.py @@ -41,7 +41,7 @@ def __init__( super().__init__( name=core_sync_tool.__name__, description=core_sync_tool.__doc__, - args_schema=core_sync_tool._ToolboxSyncTool__async_tool._ToolboxTool__pydantic_model, + args_schema=core_sync_tool._async_tool._pydantic_model, ) self.__core_sync_tool = core_sync_tool @@ -49,16 +49,16 @@ def _run(self, **kwargs: Any) -> dict[str, Any]: return self.__core_sync_tool(**kwargs) async def _arun(self, **kwargs: Any) -> dict[str, Any]: - coro = self.__core_sync_tool._ToolboxSyncTool__async_tool(**kwargs) + coro = self.__core_sync_tool._async_tool(**kwargs) # If a loop has not been provided, attempt to run in current thread. - if not self.__core_sync_tool._ToolboxSyncTool__loop: + if not self.__core_sync_tool._loop: return await coro # Otherwise, run in the background thread. return await asyncio.wrap_future( asyncio.run_coroutine_threadsafe( - coro, self.__core_sync_tool._ToolboxSyncTool__loop + coro, self.__core_sync_tool._loop ) ) From 162538936cd253ad6368daf4ae2d789d275cc068 Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Sat, 10 May 2025 01:18:51 +0530 Subject: [PATCH 28/53] chore: Delint --- .../toolbox-langchain/src/toolbox_langchain/client.py | 8 ++------ packages/toolbox-langchain/src/toolbox_langchain/tools.py | 4 +--- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/packages/toolbox-langchain/src/toolbox_langchain/client.py b/packages/toolbox-langchain/src/toolbox_langchain/client.py index 018da726..5255840a 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/client.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/client.py @@ -97,9 +97,7 @@ async def aload_tool( else: # Otherwise, run in the background thread. core_tool = await asyncio.wrap_future( - asyncio.run_coroutine_threadsafe( - coro, self.__core_sync_client._loop - ) + asyncio.run_coroutine_threadsafe(coro, self.__core_sync_client._loop) ) core_sync_tool = ToolboxSyncTool( @@ -179,9 +177,7 @@ async def aload_toolset( else: # Otherwise, run in the background thread. core_tools = await asyncio.wrap_future( - asyncio.run_coroutine_threadsafe( - coro, self.__core_sync_client._loop - ) + asyncio.run_coroutine_threadsafe(coro, self.__core_sync_client._loop) ) core_sync_tools = [ diff --git a/packages/toolbox-langchain/src/toolbox_langchain/tools.py b/packages/toolbox-langchain/src/toolbox_langchain/tools.py index b3c7b135..df97677a 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/tools.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/tools.py @@ -57,9 +57,7 @@ async def _arun(self, **kwargs: Any) -> dict[str, Any]: # Otherwise, run in the background thread. return await asyncio.wrap_future( - asyncio.run_coroutine_threadsafe( - coro, self.__core_sync_tool._loop - ) + asyncio.run_coroutine_threadsafe(coro, self.__core_sync_tool._loop) ) def add_auth_token_getters( From 6c224ac4e4bdadd39a84da8bcc5d8962b421be97 Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Sat, 10 May 2025 01:33:42 +0530 Subject: [PATCH 29/53] chore: Fix types --- .../toolbox-langchain/src/toolbox_langchain/async_tools.py | 2 +- packages/toolbox-langchain/src/toolbox_langchain/tools.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py b/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py index d3a9955c..06fdc6fc 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py @@ -43,7 +43,7 @@ def __init__( super().__init__( name=core_tool.__name__, description=core_tool.__doc__, - args_schema=core_tool._ToolboxTool__pydantic_model, + args_schema=core_tool._pydantic_model, ) self.__core_tool = core_tool diff --git a/packages/toolbox-langchain/src/toolbox_langchain/tools.py b/packages/toolbox-langchain/src/toolbox_langchain/tools.py index df97677a..d18529ea 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/tools.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/tools.py @@ -45,10 +45,10 @@ def __init__( ) self.__core_sync_tool = core_sync_tool - def _run(self, **kwargs: Any) -> dict[str, Any]: + def _run(self, **kwargs: Any) -> str: return self.__core_sync_tool(**kwargs) - async def _arun(self, **kwargs: Any) -> dict[str, Any]: + async def _arun(self, **kwargs: Any) -> str: coro = self.__core_sync_tool._async_tool(**kwargs) # If a loop has not been provided, attempt to run in current thread. From f8cfe3a70b303618cd713dd487d9db34a03ae269 Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Sat, 10 May 2025 01:37:20 +0530 Subject: [PATCH 30/53] fix: Ensure bg loop/thread not null --- packages/toolbox-langchain/src/toolbox_langchain/client.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/packages/toolbox-langchain/src/toolbox_langchain/client.py b/packages/toolbox-langchain/src/toolbox_langchain/client.py index 5255840a..5125212e 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/client.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/client.py @@ -100,6 +100,9 @@ async def aload_tool( asyncio.run_coroutine_threadsafe(coro, self.__core_sync_client._loop) ) + if not self.__core_sync_client._loop or not self.__core_sync_client._thread: + raise ValueError("Background loop or thread cannot be None.") + core_sync_tool = ToolboxSyncTool( core_tool, self.__core_sync_client._loop, From fb4d742d6ec31e99a9e3ccc8f90b1438aae00f05 Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Sat, 10 May 2025 01:38:06 +0530 Subject: [PATCH 31/53] fix: Check bg loop/thread value --- packages/toolbox-langchain/src/toolbox_langchain/client.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/packages/toolbox-langchain/src/toolbox_langchain/client.py b/packages/toolbox-langchain/src/toolbox_langchain/client.py index 5125212e..2f708f9b 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/client.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/client.py @@ -183,6 +183,9 @@ async def aload_toolset( asyncio.run_coroutine_threadsafe(coro, self.__core_sync_client._loop) ) + if not self.__core_sync_client._loop or not self.__core_sync_client._thread: + raise ValueError("Background loop or thread cannot be None.") + core_sync_tools = [ ToolboxSyncTool( core_tool, From 36d093bfdf21b5b21b121ff0b6b8ad7d1e8a4991 Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Sat, 10 May 2025 02:45:41 +0530 Subject: [PATCH 32/53] fix: Revert warnings to prefer auth_tokens over auth_headers --- .../src/toolbox_langchain/client.py | 64 +++++++++---------- 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/packages/toolbox-langchain/src/toolbox_langchain/client.py b/packages/toolbox-langchain/src/toolbox_langchain/client.py index 2f708f9b..72317fd3 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/client.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/client.py @@ -59,31 +59,31 @@ async def aload_tool( Returns: A tool loaded from the Toolbox. """ - if auth_headers: + if auth_tokens: if auth_token_getters: warn( - "Both `auth_token_getters` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_token_getters` will be used.", + "Both `auth_token_getters` and `auth_tokens` are provided. `auth_tokens` is deprecated, and `auth_token_getters` will be used.", DeprecationWarning, ) else: warn( - "Argument `auth_headers` is deprecated. Use `auth_token_getters` instead.", + "Argument `auth_tokens` is deprecated. Use `auth_token_getters` instead.", DeprecationWarning, ) - auth_token_getters = auth_headers + auth_token_getters = auth_tokens - if auth_tokens: + if auth_headers: if auth_token_getters: warn( - "Both `auth_token_getters` and `auth_tokens` are provided. `auth_tokens` is deprecated, and `auth_token_getters` will be used.", + "Both `auth_token_getters` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_token_getters` will be used.", DeprecationWarning, ) else: warn( - "Argument `auth_tokens` is deprecated. Use `auth_token_getters` instead.", + "Argument `auth_headers` is deprecated. Use `auth_token_getters` instead.", DeprecationWarning, ) - auth_token_getters = auth_tokens + auth_token_getters = auth_headers coro = self.__core_sync_client._async_client.load_tool( name=tool_name, @@ -141,31 +141,31 @@ async def aload_toolset( Returns: A list of all tools loaded from the Toolbox. """ - if auth_headers: + if auth_tokens: if auth_token_getters: warn( - "Both `auth_token_getters` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_token_getters` will be used.", + "Both `auth_token_getters` and `auth_tokens` are provided. `auth_tokens` is deprecated, and `auth_token_getters` will be used.", DeprecationWarning, ) else: warn( - "Argument `auth_headers` is deprecated. Use `auth_token_getters` instead.", + "Argument `auth_tokens` is deprecated. Use `auth_token_getters` instead.", DeprecationWarning, ) - auth_token_getters = auth_headers + auth_token_getters = auth_tokens - if auth_tokens: + if auth_headers: if auth_token_getters: warn( - "Both `auth_token_getters` and `auth_tokens` are provided. `auth_tokens` is deprecated, and `auth_token_getters` will be used.", + "Both `auth_token_getters` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_token_getters` will be used.", DeprecationWarning, ) else: warn( - "Argument `auth_tokens` is deprecated. Use `auth_token_getters` instead.", + "Argument `auth_headers` is deprecated. Use `auth_token_getters` instead.", DeprecationWarning, ) - auth_token_getters = auth_tokens + auth_token_getters = auth_headers coro = self.__core_sync_client._async_client.load_toolset( name=toolset_name, @@ -222,31 +222,31 @@ def load_tool( Returns: A tool loaded from the Toolbox. """ - if auth_headers: + if auth_tokens: if auth_token_getters: warn( - "Both `auth_token_getters` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_token_getters` will be used.", + "Both `auth_token_getters` and `auth_tokens` are provided. `auth_tokens` is deprecated, and `auth_token_getters` will be used.", DeprecationWarning, ) else: warn( - "Argument `auth_headers` is deprecated. Use `auth_token_getters` instead.", + "Argument `auth_tokens` is deprecated. Use `auth_token_getters` instead.", DeprecationWarning, ) - auth_token_getters = auth_headers + auth_token_getters = auth_tokens - if auth_tokens: + if auth_headers: if auth_token_getters: warn( - "Both `auth_token_getters` and `auth_tokens` are provided. `auth_tokens` is deprecated, and `auth_token_getters` will be used.", + "Both `auth_token_getters` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_token_getters` will be used.", DeprecationWarning, ) else: warn( - "Argument `auth_tokens` is deprecated. Use `auth_token_getters` instead.", + "Argument `auth_headers` is deprecated. Use `auth_token_getters` instead.", DeprecationWarning, ) - auth_token_getters = auth_tokens + auth_token_getters = auth_headers core_sync_tool = self.__core_sync_client.load_tool( name=tool_name, @@ -286,31 +286,31 @@ def load_toolset( Returns: A list of all tools loaded from the Toolbox. """ - if auth_headers: + if auth_tokens: if auth_token_getters: warn( - "Both `auth_token_getters` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_token_getters` will be used.", + "Both `auth_token_getters` and `auth_tokens` are provided. `auth_tokens` is deprecated, and `auth_token_getters` will be used.", DeprecationWarning, ) else: warn( - "Argument `auth_headers` is deprecated. Use `auth_token_getters` instead.", + "Argument `auth_tokens` is deprecated. Use `auth_token_getters` instead.", DeprecationWarning, ) - auth_token_getters = auth_headers + auth_token_getters = auth_tokens - if auth_tokens: + if auth_headers: if auth_token_getters: warn( - "Both `auth_token_getters` and `auth_tokens` are provided. `auth_tokens` is deprecated, and `auth_token_getters` will be used.", + "Both `auth_token_getters` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_token_getters` will be used.", DeprecationWarning, ) else: warn( - "Argument `auth_tokens` is deprecated. Use `auth_token_getters` instead.", + "Argument `auth_headers` is deprecated. Use `auth_token_getters` instead.", DeprecationWarning, ) - auth_token_getters = auth_tokens + auth_token_getters = auth_headers core_sync_tools = self.__core_sync_client.load_toolset( name=toolset_name, From 248b1db341265496100b2bd0db88d122f69048b2 Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Sat, 10 May 2025 03:15:19 +0530 Subject: [PATCH 33/53] chore: Update unittests --- .../tests/test_async_client.py | 173 +++++-- .../tests/test_async_tools.py | 328 ++++++++----- .../toolbox-langchain/tests/test_client.py | 442 ++++++++++-------- .../toolbox-langchain/tests/test_tools.py | 232 +++++---- 4 files changed, 692 insertions(+), 483 deletions(-) diff --git a/packages/toolbox-langchain/tests/test_async_client.py b/packages/toolbox-langchain/tests/test_async_client.py index 015a5fa3..988d3974 100644 --- a/packages/toolbox-langchain/tests/test_async_client.py +++ b/packages/toolbox-langchain/tests/test_async_client.py @@ -17,7 +17,11 @@ import pytest from aiohttp import ClientSession +from toolbox_core.client import ToolboxClient as ToolboxCoreClient from toolbox_core.protocol import ManifestSchema +from toolbox_core.protocol import ParameterSchema as CoreParameterSchema +from toolbox_core.tool import ToolboxTool as ToolboxCoreTool +from toolbox_core.utils import params_to_pydantic_model from toolbox_langchain.async_client import AsyncToolboxClient from toolbox_langchain.async_tools import AsyncToolboxTool @@ -60,123 +64,200 @@ def manifest_schema(self): def mock_session(self): return AsyncMock(spec=ClientSession) + @pytest.fixture + def mock_core_client_instance(self, manifest_schema, mock_session): + mock = AsyncMock(spec=ToolboxCoreClient) + + async def mock_load_tool_impl(name, auth_token_getters, bound_params): + tool_schema_dict = MANIFEST_JSON["tools"].get(name) + if not tool_schema_dict: + raise ValueError(f"Tool '{name}' not in mock manifest_dict") + + core_params = [ + CoreParameterSchema(**p) for p in tool_schema_dict["parameters"] + ] + # Return a mock that looks like toolbox_core.tool.ToolboxTool + core_tool_mock = AsyncMock(spec=ToolboxCoreTool) + core_tool_mock.__name__ = name + core_tool_mock.__doc__ = tool_schema_dict["description"] + core_tool_mock._pydantic_model = params_to_pydantic_model(name, core_params) + # Add other necessary attributes or method mocks if AsyncToolboxTool uses them + return core_tool_mock + + mock.load_tool = AsyncMock(side_effect=mock_load_tool_impl) + + async def mock_load_toolset_impl( + name, auth_token_getters, bound_params, strict + ): + core_tools_list = [] + for tool_name_iter, tool_schema_dict in MANIFEST_JSON["tools"].items(): + core_params = [ + CoreParameterSchema(**p) for p in tool_schema_dict["parameters"] + ] + core_tool_mock = AsyncMock(spec=ToolboxCoreTool) + core_tool_mock.__name__ = tool_name_iter + core_tool_mock.__doc__ = tool_schema_dict["description"] + core_tool_mock._pydantic_model = params_to_pydantic_model( + tool_name_iter, core_params + ) + core_tools_list.append(core_tool_mock) + return core_tools_list + + mock.load_toolset = AsyncMock(side_effect=mock_load_toolset_impl) + # Mock the session attribute if it's directly accessed by AsyncToolboxClient tests + mock._ToolboxClient__session = mock_session + return mock + @pytest.fixture() - def mock_client(self, mock_session): - return AsyncToolboxClient(URL, session=mock_session) + def mock_client(self, mock_session, mock_core_client_instance): + # Patch the ToolboxCoreClient constructor used by AsyncToolboxClient + with patch( + "toolbox_langchain.async_client.ToolboxCoreClient", + return_value=mock_core_client_instance, + ): + client = AsyncToolboxClient(URL, session=mock_session) + # Ensure the mocked core client is used + client._AsyncToolboxClient__core_client = mock_core_client_instance + return client async def test_create_with_existing_session(self, mock_client, mock_session): - assert mock_client._AsyncToolboxClient__session == mock_session + # AsyncToolboxClient stores the core_client, which stores the session + assert ( + mock_client._AsyncToolboxClient__core_client._ToolboxClient__session + == mock_session + ) - @patch("toolbox_langchain.async_client._load_manifest") async def test_aload_tool( - self, mock_load_manifest, mock_client, mock_session, manifest_schema + self, + mock_client, + manifest_schema, # mock_session removed as it's part of mock_core_client_instance ): tool_name = "test_tool_1" - mock_load_manifest.return_value = manifest_schema + # manifest_schema is used by mock_core_client_instance fixture to provide tool details tool = await mock_client.aload_tool(tool_name) - mock_load_manifest.assert_called_once_with( - f"{URL}/api/tool/{tool_name}", mock_session + # Assert that the core client's load_tool was called correctly + mock_client._AsyncToolboxClient__core_client.load_tool.assert_called_once_with( + name=tool_name, auth_token_getters={}, bound_params={} ) assert isinstance(tool, AsyncToolboxTool) - assert tool.name == tool_name + assert ( + tool.name == tool_name + ) # AsyncToolboxTool gets its name from the core_tool - @patch("toolbox_langchain.async_client._load_manifest") async def test_aload_tool_auth_headers_deprecated( - self, mock_load_manifest, mock_client, manifest_schema + self, mock_client, manifest_schema ): tool_name = "test_tool_1" - mock_manifest = manifest_schema - mock_load_manifest.return_value = mock_manifest + auth_lambda = lambda: "Bearer token" # Define lambda once with catch_warnings(record=True) as w: simplefilter("always") await mock_client.aload_tool( - tool_name, auth_headers={"Authorization": lambda: "Bearer token"} + tool_name, + auth_headers={"Authorization": auth_lambda}, # Use the defined lambda ) assert len(w) == 1 assert issubclass(w[-1].category, DeprecationWarning) assert "auth_headers" in str(w[-1].message) - @patch("toolbox_langchain.async_client._load_manifest") + mock_client._AsyncToolboxClient__core_client.load_tool.assert_called_once_with( + name=tool_name, + auth_token_getters={"Authorization": auth_lambda}, + bound_params={}, + ) + async def test_aload_tool_auth_headers_and_tokens( - self, mock_load_manifest, mock_client, manifest_schema + self, mock_client, manifest_schema ): tool_name = "test_tool_1" - mock_manifest = manifest_schema - mock_load_manifest.return_value = mock_manifest + auth_getters = {"test": lambda: "token"} + auth_headers_lambda = lambda: "Bearer token" # Define lambda once + with catch_warnings(record=True) as w: simplefilter("always") await mock_client.aload_tool( tool_name, - auth_headers={"Authorization": lambda: "Bearer token"}, - auth_token_getters={"test": lambda: "token"}, + auth_headers={ + "Authorization": auth_headers_lambda + }, # Use defined lambda + auth_token_getters=auth_getters, ) - assert len(w) == 1 + assert ( + len(w) == 1 + ) # Only one warning because auth_token_getters takes precedence assert issubclass(w[-1].category, DeprecationWarning) - assert "auth_headers" in str(w[-1].message) + assert "auth_headers" in str(w[-1].message) # Warning for auth_headers + + mock_client._AsyncToolboxClient__core_client.load_tool.assert_called_once_with( + name=tool_name, auth_token_getters=auth_getters, bound_params={} + ) - @patch("toolbox_langchain.async_client._load_manifest") async def test_aload_toolset( - self, mock_load_manifest, mock_client, mock_session, manifest_schema + self, mock_client, manifest_schema # mock_session removed ): - mock_manifest = manifest_schema - mock_load_manifest.return_value = mock_manifest tools = await mock_client.aload_toolset() - mock_load_manifest.assert_called_once_with(f"{URL}/api/toolset/", mock_session) - assert len(tools) == 2 + mock_client._AsyncToolboxClient__core_client.load_toolset.assert_called_once_with( + name=None, auth_token_getters={}, bound_params={}, strict=False + ) + assert len(tools) == 2 # Based on MANIFEST_JSON for tool in tools: assert isinstance(tool, AsyncToolboxTool) assert tool.name in ["test_tool_1", "test_tool_2"] - @patch("toolbox_langchain.async_client._load_manifest") async def test_aload_toolset_with_toolset_name( - self, mock_load_manifest, mock_client, mock_session, manifest_schema + self, mock_client, manifest_schema # mock_session removed ): - toolset_name = "test_toolset_1" - mock_manifest = manifest_schema - mock_load_manifest.return_value = mock_manifest + toolset_name = "test_toolset_1" # This name isn't in MANIFEST_JSON, but load_toolset mock doesn't filter by it tools = await mock_client.aload_toolset(toolset_name=toolset_name) - mock_load_manifest.assert_called_once_with( - f"{URL}/api/toolset/{toolset_name}", mock_session + mock_client._AsyncToolboxClient__core_client.load_toolset.assert_called_once_with( + name=toolset_name, auth_token_getters={}, bound_params={}, strict=False ) assert len(tools) == 2 for tool in tools: assert isinstance(tool, AsyncToolboxTool) assert tool.name in ["test_tool_1", "test_tool_2"] - @patch("toolbox_langchain.async_client._load_manifest") async def test_aload_toolset_auth_headers_deprecated( - self, mock_load_manifest, mock_client, manifest_schema + self, mock_client, manifest_schema ): - mock_manifest = manifest_schema - mock_load_manifest.return_value = mock_manifest + auth_lambda = lambda: "Bearer token" # Define lambda once with catch_warnings(record=True) as w: simplefilter("always") await mock_client.aload_toolset( - auth_headers={"Authorization": lambda: "Bearer token"} + auth_headers={"Authorization": auth_lambda} # Use defined lambda ) assert len(w) == 1 assert issubclass(w[-1].category, DeprecationWarning) assert "auth_headers" in str(w[-1].message) + mock_client._AsyncToolboxClient__core_client.load_toolset.assert_called_once_with( + name=None, + auth_token_getters={"Authorization": auth_lambda}, + bound_params={}, + strict=False, + ) - @patch("toolbox_langchain.async_client._load_manifest") async def test_aload_toolset_auth_headers_and_tokens( - self, mock_load_manifest, mock_client, manifest_schema + self, mock_client, manifest_schema ): - mock_manifest = manifest_schema - mock_load_manifest.return_value = mock_manifest + auth_getters = {"test": lambda: "token"} + auth_headers_lambda = lambda: "Bearer token" # Define lambda once with catch_warnings(record=True) as w: simplefilter("always") await mock_client.aload_toolset( - auth_headers={"Authorization": lambda: "Bearer token"}, - auth_token_getters={"test": lambda: "token"}, + auth_headers={ + "Authorization": auth_headers_lambda + }, # Use defined lambda + auth_token_getters=auth_getters, ) assert len(w) == 1 assert issubclass(w[-1].category, DeprecationWarning) assert "auth_headers" in str(w[-1].message) + mock_client._AsyncToolboxClient__core_client.load_toolset.assert_called_once_with( + name=None, auth_token_getters=auth_getters, bound_params={}, strict=False + ) async def test_load_tool_not_implemented(self, mock_client): with pytest.raises(NotImplementedError) as excinfo: diff --git a/packages/toolbox-langchain/tests/test_async_tools.py b/packages/toolbox-langchain/tests/test_async_tools.py index e23aee85..88efcd05 100644 --- a/packages/toolbox-langchain/tests/test_async_tools.py +++ b/packages/toolbox-langchain/tests/test_async_tools.py @@ -12,11 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import types # For MappingProxyType from unittest.mock import AsyncMock, Mock, patch import pytest import pytest_asyncio from pydantic import ValidationError +from toolbox_core.protocol import ParameterSchema as CoreParameterSchema +from toolbox_core.tool import ToolboxTool as ToolboxCoreTool from toolbox_langchain.async_tools import AsyncToolboxTool @@ -24,7 +27,7 @@ @pytest.mark.asyncio class TestAsyncToolboxTool: @pytest.fixture - def tool_schema(self): + def tool_schema_dict(self): return { "description": "Test Tool Description", "parameters": [ @@ -34,9 +37,10 @@ def tool_schema(self): } @pytest.fixture - def auth_tool_schema(self): + def auth_tool_schema_dict(self): return { "description": "Test Tool Description", + "authRequired": ["test-auth-source"], "parameters": [ { "name": "param1", @@ -48,133 +52,193 @@ def auth_tool_schema(self): ], } + def _create_core_tool_from_dict( + self, session, name, schema_dict, url, initial_auth_getters=None + ): + core_params_schemas = [ + CoreParameterSchema(**p) for p in schema_dict["parameters"] + ] + + tool_constructor_params = [] + required_authn_for_core = {} + for p_schema in core_params_schemas: + if p_schema.authSources: + required_authn_for_core[p_schema.name] = p_schema.authSources + else: + tool_constructor_params.append(p_schema) + + return ToolboxCoreTool( + session=session, + base_url=url, + name=name, + description=schema_dict["description"], + params=tool_constructor_params, + required_authn_params=types.MappingProxyType(required_authn_for_core), + required_authz_tokens=schema_dict.get("authRequired", []), + auth_service_token_getters=types.MappingProxyType( + initial_auth_getters or {} + ), + bound_params=types.MappingProxyType({}), + client_headers=types.MappingProxyType({}), + ) + @pytest_asyncio.fixture @patch("aiohttp.ClientSession") - async def toolbox_tool(self, MockClientSession, tool_schema): + async def toolbox_tool(self, MockClientSession, tool_schema_dict): mock_session = MockClientSession.return_value - mock_session.post.return_value.__aenter__.return_value.raise_for_status = Mock() - mock_session.post.return_value.__aenter__.return_value.json = AsyncMock( - return_value={"result": "test-result"} - ) - tool = AsyncToolboxTool( + mock_response = mock_session.post.return_value.__aenter__.return_value + mock_response.raise_for_status = Mock() + mock_response.json = AsyncMock(return_value={"result": "test-result"}) + mock_response.status = 200 # *** Fix: Set status for the mock response *** + + core_tool_instance = self._create_core_tool_from_dict( + session=mock_session, name="test_tool", - schema=tool_schema, + schema_dict=tool_schema_dict, url="http://test_url", - session=mock_session, ) + tool = AsyncToolboxTool(core_tool=core_tool_instance) return tool @pytest_asyncio.fixture @patch("aiohttp.ClientSession") - async def auth_toolbox_tool(self, MockClientSession, auth_tool_schema): + async def auth_toolbox_tool(self, MockClientSession, auth_tool_schema_dict): mock_session = MockClientSession.return_value - mock_session.post.return_value.__aenter__.return_value.raise_for_status = Mock() - mock_session.post.return_value.__aenter__.return_value.json = AsyncMock( - return_value={"result": "test-result"} + mock_response = mock_session.post.return_value.__aenter__.return_value + mock_response.raise_for_status = Mock() + mock_response.json = AsyncMock(return_value={"result": "test-result"}) + mock_response.status = 200 # *** Fix: Set status for the mock response *** + + core_tool_instance = self._create_core_tool_from_dict( + session=mock_session, + name="test_tool", + schema_dict=auth_tool_schema_dict, + url="https://test-url", ) - with pytest.warns( - UserWarning, - match=r"Parameter\(s\) `param1` of tool test_tool require authentication", - ): - tool = AsyncToolboxTool( - name="test_tool", - schema=auth_tool_schema, - url="https://test-url", - session=mock_session, - ) + tool = AsyncToolboxTool(core_tool=core_tool_instance) return tool @patch("aiohttp.ClientSession") - async def test_toolbox_tool_init(self, MockClientSession, tool_schema): + async def test_toolbox_tool_init(self, MockClientSession, tool_schema_dict): mock_session = MockClientSession.return_value - tool = AsyncToolboxTool( + mock_response = mock_session.post.return_value.__aenter__.return_value + mock_response.status = 200 + core_tool_instance = self._create_core_tool_from_dict( + session=mock_session, name="test_tool", - schema=tool_schema, + schema_dict=tool_schema_dict, url="https://test-url", - session=mock_session, ) + tool = AsyncToolboxTool(core_tool=core_tool_instance) assert tool.name == "test_tool" - assert tool.description == "Test Tool Description" + assert tool.description == core_tool_instance.__doc__ @pytest.mark.parametrize( - "params, expected_bound_params", + "params_to_bind", [ - ({"param1": "bound-value"}, {"param1": "bound-value"}), - ({"param1": lambda: "bound-value"}, {"param1": lambda: "bound-value"}), - ( - {"param1": "bound-value", "param2": 123}, - {"param1": "bound-value", "param2": 123}, - ), + ({"param1": "bound-value"}), + ({"param1": lambda: "bound-value"}), + ({"param1": "bound-value", "param2": 123}), ], ) - async def test_toolbox_tool_bind_params( - self, toolbox_tool, params, expected_bound_params - ): - tool = toolbox_tool.bind_params(params) - for key, value in expected_bound_params.items(): - if callable(value): - assert value() == tool._AsyncToolboxTool__bound_params[key]() - else: - assert value == tool._AsyncToolboxTool__bound_params[key] - - @pytest.mark.parametrize("strict", [True, False]) - async def test_toolbox_tool_bind_params_invalid(self, toolbox_tool, strict): - if strict: - with pytest.raises(ValueError) as e: - tool = toolbox_tool.bind_params( - {"param3": "bound-value"}, strict=strict - ) - assert "Parameter(s) param3 missing and cannot be bound." in str(e.value) - else: - with pytest.warns(UserWarning) as record: - tool = toolbox_tool.bind_params( - {"param3": "bound-value"}, strict=strict - ) - assert len(record) == 1 - assert "Parameter(s) param3 missing and cannot be bound." in str( - record[0].message + async def test_toolbox_tool_bind_params(self, toolbox_tool, params_to_bind): + original_core_tool = toolbox_tool._AsyncToolboxTool__core_tool + with patch.object( + original_core_tool, "bind_params", wraps=original_core_tool.bind_params + ) as mock_core_bind_params: + new_langchain_tool = toolbox_tool.bind_params(params_to_bind) + mock_core_bind_params.assert_called_once_with(params_to_bind) + assert isinstance( + new_langchain_tool._AsyncToolboxTool__core_tool, ToolboxCoreTool ) + new_core_tool_signature_params = ( + new_langchain_tool._AsyncToolboxTool__core_tool.__signature__.parameters + ) + for bound_param_name in params_to_bind.keys(): + assert bound_param_name not in new_core_tool_signature_params + + async def test_toolbox_tool_bind_params_invalid(self, toolbox_tool): + with pytest.raises( + ValueError, match="unable to bind parameters: no parameter named param3" + ): + toolbox_tool.bind_params({"param3": "bound-value"}) async def test_toolbox_tool_bind_params_duplicate(self, toolbox_tool): tool = toolbox_tool.bind_params({"param1": "bound-value"}) - with pytest.raises(ValueError) as e: - tool = tool.bind_params({"param1": "bound-value"}) - assert "Parameter(s) `param1` already bound in tool `test_tool`." in str( - e.value - ) + with pytest.raises( + ValueError, + match="cannot re-bind parameter: parameter 'param1' is already bound", + ): + tool.bind_params({"param1": "bound-value"}) async def test_toolbox_tool_bind_params_invalid_params(self, auth_toolbox_tool): - with pytest.raises(ValueError) as e: + auth_core_tool = auth_toolbox_tool._AsyncToolboxTool__core_tool + # Verify that 'param1' is not in the list of bindable parameters for the core tool + # because it requires authentication. + assert "param1" not in [p.name for p in auth_core_tool._ToolboxTool__params] + with pytest.raises( + ValueError, match="unable to bind parameters: no parameter named param1" + ): auth_toolbox_tool.bind_params({"param1": "bound-value"}) - assert "Parameter(s) param1 already authenticated and cannot be bound." in str( - e.value + + async def test_toolbox_tool_add_valid_auth_token_getter(self, auth_toolbox_tool): + get_token_lambda = lambda: "test-token-value" + original_core_tool = auth_toolbox_tool._AsyncToolboxTool__core_tool + with patch.object( + original_core_tool, + "add_auth_token_getters", + wraps=original_core_tool.add_auth_token_getters, + ) as mock_core_add_getters: + tool = auth_toolbox_tool.add_auth_token_getters( + {"test-auth-source": get_token_lambda} + ) + mock_core_add_getters.assert_called_once_with( + {"test-auth-source": get_token_lambda} + ) + core_tool_after_add = tool._AsyncToolboxTool__core_tool + assert ( + "test-auth-source" + in core_tool_after_add._ToolboxTool__auth_service_token_getters + ) + assert ( + core_tool_after_add._ToolboxTool__auth_service_token_getters[ + "test-auth-source" + ] + is get_token_lambda + ) + assert not core_tool_after_add._ToolboxTool__required_authn_params.get( + "param1" + ) + assert ( + "test-auth-source" + not in core_tool_after_add._ToolboxTool__required_authz_tokens + ) + + async def test_toolbox_tool_add_unused_auth_token_getter_raises_error( + self, auth_toolbox_tool + ): + unused_lambda = lambda: "another-token" + with pytest.raises(ValueError) as excinfo: + auth_toolbox_tool.add_auth_token_getters( + {"another-auth-source": unused_lambda} + ) + assert ( + "Authentication source(s) `another-auth-source` unused by tool `test_tool`" + in str(excinfo.value) ) - @pytest.mark.parametrize( - "auth_token_getters, expected_auth_token_getters", - [ - ( - {"test-auth-source": lambda: "test-token"}, - {"test-auth-source": lambda: "test-token"}, - ), - ( + valid_lambda = lambda: "test-token" + with pytest.raises(ValueError) as excinfo_mixed: + auth_toolbox_tool.add_auth_token_getters( { - "test-auth-source": lambda: "test-token", - "another-auth-source": lambda: "another-token", - }, - { - "test-auth-source": lambda: "test-token", - "another-auth-source": lambda: "another-token", - }, - ), - ], - ) - async def test_toolbox_tool_add_auth_token_getters( - self, auth_toolbox_tool, auth_token_getters, expected_auth_token_getters - ): - tool = auth_toolbox_tool.add_auth_token_getters(auth_token_getters) - for source, getter in expected_auth_token_getters.items(): - assert tool._AsyncToolboxTool__auth_token_getters[source]() == getter() + "test-auth-source": valid_lambda, + "another-auth-source": unused_lambda, + } + ) + assert ( + "Authentication source(s) `another-auth-source` unused by tool `test_tool`" + in str(excinfo_mixed.value) + ) async def test_toolbox_tool_add_auth_token_getters_duplicate( self, auth_toolbox_tool @@ -182,45 +246,44 @@ async def test_toolbox_tool_add_auth_token_getters_duplicate( tool = auth_toolbox_tool.add_auth_token_getters( {"test-auth-source": lambda: "test-token"} ) - with pytest.raises(ValueError) as e: - tool = tool.add_auth_token_getters( - {"test-auth-source": lambda: "test-token"} - ) - assert ( - "Authentication source(s) `test-auth-source` already registered in tool `test_tool`." - in str(e.value) - ) + with pytest.raises( + ValueError, + match="Authentication source\\(s\\) `test-auth-source` already registered in tool `test_tool`\\.", + ): + tool.add_auth_token_getters({"test-auth-source": lambda: "test-token"}) - async def test_toolbox_tool_validate_auth_strict(self, auth_toolbox_tool): - with pytest.raises(PermissionError) as e: - auth_toolbox_tool._AsyncToolboxTool__validate_auth(strict=True) - assert "Parameter(s) `param1` of tool test_tool require authentication" in str( - e.value - ) + async def test_toolbox_tool_call_requires_auth_strict(self, auth_toolbox_tool): + with pytest.raises( + PermissionError, + match="One or more of the following authn services are required to invoke this tool: test-auth-source", + ): + await auth_toolbox_tool.ainvoke({"param2": 123}) async def test_toolbox_tool_call(self, toolbox_tool): result = await toolbox_tool.ainvoke({"param1": "test-value", "param2": 123}) assert result == "test-result" - toolbox_tool._AsyncToolboxTool__session.post.assert_called_once_with( + core_tool = toolbox_tool._AsyncToolboxTool__core_tool + core_tool._ToolboxTool__session.post.assert_called_once_with( "http://test_url/api/tool/test_tool/invoke", json={"param1": "test-value", "param2": 123}, headers={}, ) @pytest.mark.parametrize( - "bound_param, expected_value", + "bound_param_map, expected_value", [ ({"param1": "bound-value"}, "bound-value"), ({"param1": lambda: "dynamic-value"}, "dynamic-value"), ], ) async def test_toolbox_tool_call_with_bound_params( - self, toolbox_tool, bound_param, expected_value + self, toolbox_tool, bound_param_map, expected_value ): - tool = toolbox_tool.bind_params(bound_param) + tool = toolbox_tool.bind_params(bound_param_map) result = await tool.ainvoke({"param2": 123}) assert result == "test-result" - toolbox_tool._AsyncToolboxTool__session.post.assert_called_once_with( + core_tool = tool._AsyncToolboxTool__core_tool + core_tool._ToolboxTool__session.post.assert_called_once_with( "http://test_url/api/tool/test_tool/invoke", json={"param1": expected_value, "param2": 123}, headers={}, @@ -232,24 +295,51 @@ async def test_toolbox_tool_call_with_auth_tokens(self, auth_toolbox_tool): ) result = await tool.ainvoke({"param2": 123}) assert result == "test-result" - auth_toolbox_tool._AsyncToolboxTool__session.post.assert_called_once_with( + core_tool = tool._AsyncToolboxTool__core_tool + core_tool._ToolboxTool__session.post.assert_called_once_with( "https://test-url/api/tool/test_tool/invoke", json={"param2": 123}, headers={"test-auth-source_token": "test-token"}, ) - async def test_toolbox_tool_call_with_auth_tokens_insecure(self, auth_toolbox_tool): + async def test_toolbox_tool_call_with_auth_tokens_insecure( + self, auth_toolbox_tool, auth_tool_schema_dict + ): # Add auth_tool_schema_dict fixture + core_tool_of_auth_tool = auth_toolbox_tool._AsyncToolboxTool__core_tool + mock_session = core_tool_of_auth_tool._ToolboxTool__session + + # *** Fix: Use the injected fixture value auth_tool_schema_dict *** + insecure_core_tool = self._create_core_tool_from_dict( + session=mock_session, + name="test_tool", + schema_dict=auth_tool_schema_dict, # Use the fixture value here + url="http://test-url", + ) + insecure_auth_langchain_tool = AsyncToolboxTool(core_tool=insecure_core_tool) + with pytest.warns( UserWarning, match="Sending ID token over HTTP. User data may be exposed. Use HTTPS for secure communication.", ): - auth_toolbox_tool._AsyncToolboxTool__url = "http://test-url" - tool = auth_toolbox_tool.add_auth_token_getters( + tool_with_getter = insecure_auth_langchain_tool.add_auth_token_getters( {"test-auth-source": lambda: "test-token"} ) - result = await tool.ainvoke({"param2": 123}) + result = await tool_with_getter.ainvoke({"param2": 123}) assert result == "test-result" - auth_toolbox_tool._AsyncToolboxTool__session.post.assert_called_once_with( + + modified_core_tool_in_new_tool = ( + tool_with_getter._AsyncToolboxTool__core_tool + ) + assert ( + modified_core_tool_in_new_tool._ToolboxTool__base_url + == "http://test-url" + ) + assert ( + modified_core_tool_in_new_tool._ToolboxTool__url + == "http://test-url/api/tool/test_tool/invoke" + ) + + modified_core_tool_in_new_tool._ToolboxTool__session.post.assert_called_once_with( "http://test-url/api/tool/test_tool/invoke", json={"param2": 123}, headers={"test-auth-source_token": "test-token"}, diff --git a/packages/toolbox-langchain/tests/test_client.py b/packages/toolbox-langchain/tests/test_client.py index 62999019..bae8123e 100644 --- a/packages/toolbox-langchain/tests/test_client.py +++ b/packages/toolbox-langchain/tests/test_client.py @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import Mock, patch +from unittest.mock import AsyncMock, Mock, patch import pytest from pydantic import BaseModel +from toolbox_core.sync_tool import ToolboxSyncTool as ToolboxCoreSyncTool # For spec +from toolbox_core.tool import ToolboxTool as ToolboxCoreTool # For spec from toolbox_langchain.client import ToolboxClient from toolbox_langchain.tools import ToolboxTool @@ -28,232 +30,292 @@ class TestToolboxClient: def toolbox_client(self): client = ToolboxClient(URL) assert isinstance(client, ToolboxClient) - assert client._ToolboxClient__async_client is not None + assert client._ToolboxClient__core_sync_client is not None + assert client._ToolboxClient__core_sync_client._async_client is not None + assert client._ToolboxClient__core_sync_client._loop is not None + assert client._ToolboxClient__core_sync_client._loop.is_running() + assert client._ToolboxClient__core_sync_client._thread is not None + assert client._ToolboxClient__core_sync_client._thread.is_alive() + return client - # Check that the background loop was created and started - assert client._ToolboxClient__loop is not None - assert client._ToolboxClient__loop.is_running() + @patch("toolbox_core.sync_client.ToolboxSyncClient.load_tool") + def test_load_tool(self, mock_core_load_tool, toolbox_client): + mock_core_sync_tool_instance = Mock( + spec=ToolboxCoreSyncTool + ) # Spec with Core Sync Tool + mock_core_sync_tool_instance.__name__ = "mock-core-sync-tool" + mock_core_sync_tool_instance.__doc__ = "mock core sync description" - return client + mock_underlying_async_tool = Mock( + spec=ToolboxCoreTool + ) # Core Async Tool for pydantic model + mock_underlying_async_tool._pydantic_model = BaseModel + mock_core_sync_tool_instance._async_tool = mock_underlying_async_tool + + mock_core_load_tool.return_value = mock_core_sync_tool_instance + + langchain_tool = toolbox_client.load_tool("test_tool") + + assert isinstance(langchain_tool, ToolboxTool) + assert langchain_tool.name == mock_core_sync_tool_instance.__name__ + assert langchain_tool.description == mock_core_sync_tool_instance.__doc__ + assert langchain_tool.args_schema == mock_underlying_async_tool._pydantic_model + + mock_core_load_tool.assert_called_once_with( + name="test_tool", auth_token_getters={}, bound_params={} + ) + + @patch("toolbox_core.sync_client.ToolboxSyncClient.load_toolset") + def test_load_toolset(self, mock_core_load_toolset, toolbox_client): + mock_core_sync_tool_instance1 = Mock(spec=ToolboxCoreSyncTool) + mock_core_sync_tool_instance1.__name__ = "mock-core-sync-tool-0" + mock_core_sync_tool_instance1.__doc__ = "desc 0" + mock_async_tool0 = Mock(spec=ToolboxCoreTool) + mock_async_tool0._pydantic_model = BaseModel + mock_core_sync_tool_instance1._async_tool = mock_async_tool0 + + mock_core_sync_tool_instance2 = Mock(spec=ToolboxCoreSyncTool) + mock_core_sync_tool_instance2.__name__ = "mock-core-sync-tool-1" + mock_core_sync_tool_instance2.__doc__ = "desc 1" + mock_async_tool1 = Mock(spec=ToolboxCoreTool) + mock_async_tool1._pydantic_model = BaseModel + mock_core_sync_tool_instance2._async_tool = mock_async_tool1 + + mock_core_load_toolset.return_value = [ + mock_core_sync_tool_instance1, + mock_core_sync_tool_instance2, + ] + + langchain_tools = toolbox_client.load_toolset() + assert len(langchain_tools) == 2 + assert isinstance(langchain_tools[0], ToolboxTool) + assert isinstance(langchain_tools[1], ToolboxTool) + assert langchain_tools[0].name == "mock-core-sync-tool-0" + assert langchain_tools[1].name == "mock-core-sync-tool-1" - @patch("toolbox_langchain.client.AsyncToolboxClient.aload_tool") - def test_load_tool(self, mock_aload_tool, toolbox_client): - mock_tool = Mock(spec=ToolboxTool) - mock_tool.name = "mock-tool" - mock_tool.description = "mock description" - mock_tool.args_schema = BaseModel - mock_aload_tool.return_value = mock_tool - - tool = toolbox_client.load_tool("test_tool") - - assert tool.name == mock_tool.name - assert tool.description == mock_tool.description - assert tool.args_schema == mock_tool.args_schema - mock_aload_tool.assert_called_once_with("test_tool", {}, None, None, {}, True) - - @patch("toolbox_langchain.client.AsyncToolboxClient.aload_toolset") - def test_load_toolset(self, mock_aload_toolset, toolbox_client): - mock_tools = [Mock(spec=ToolboxTool), Mock(spec=ToolboxTool)] - mock_tools[0].name = "mock-tool-0" - mock_tools[0].description = "mock description 0" - mock_tools[0].args_schema = BaseModel - mock_tools[1].name = "mock-tool-1" - mock_tools[1].description = "mock description 1" - mock_tools[1].args_schema = BaseModel - mock_aload_toolset.return_value = mock_tools - - tools = toolbox_client.load_toolset() - assert len(tools) == len(mock_tools) - assert all( - a.name == b.name - and a.description == b.description - and a.args_schema == b.args_schema - for a, b in zip(tools, mock_tools) + mock_core_load_toolset.assert_called_once_with( + name=None, auth_token_getters={}, bound_params={}, strict=False ) - mock_aload_toolset.assert_called_once_with(None, {}, None, None, {}, True) @pytest.mark.asyncio - @patch("toolbox_langchain.client.AsyncToolboxClient.aload_tool") - async def test_aload_tool(self, mock_aload_tool, toolbox_client): - mock_tool = Mock(spec=ToolboxTool) - mock_tool.name = "mock-tool" - mock_tool.description = "mock description" - mock_tool.args_schema = BaseModel - mock_aload_tool.return_value = mock_tool - - tool = await toolbox_client.aload_tool("test_tool") - assert tool.name == mock_tool.name - assert tool.description == mock_tool.description - assert tool.args_schema == mock_tool.args_schema - mock_aload_tool.assert_called_once_with("test_tool", {}, None, None, {}, True) + @patch("toolbox_core.client.ToolboxClient.load_tool") + async def test_aload_tool(self, mock_core_aload_tool, toolbox_client): + mock_core_tool_instance = AsyncMock( + spec=ToolboxCoreTool + ) # *** Use AsyncMock for async method return *** + mock_core_tool_instance.__name__ = "mock-core-async-tool" + mock_core_tool_instance.__doc__ = "mock core async description" + mock_core_tool_instance._pydantic_model = BaseModel + mock_core_aload_tool.return_value = mock_core_tool_instance + + langchain_tool = await toolbox_client.aload_tool("test_tool") + + assert isinstance(langchain_tool, ToolboxTool) + assert langchain_tool.name == mock_core_tool_instance.__name__ + assert langchain_tool.description == mock_core_tool_instance.__doc__ + + toolbox_client._ToolboxClient__core_sync_client._async_client.load_tool.assert_called_once_with( + name="test_tool", auth_token_getters={}, bound_params={} + ) @pytest.mark.asyncio - @patch("toolbox_langchain.client.AsyncToolboxClient.aload_toolset") - async def test_aload_toolset(self, mock_aload_toolset, toolbox_client): - mock_tools = [Mock(spec=ToolboxTool), Mock(spec=ToolboxTool)] - mock_tools[0].name = "mock-tool-0" - mock_tools[0].description = "mock description 0" - mock_tools[0].args_schema = BaseModel - mock_tools[1].name = "mock-tool-1" - mock_tools[1].description = "mock description 1" - mock_tools[1].args_schema = BaseModel - mock_aload_toolset.return_value = mock_tools - - tools = await toolbox_client.aload_toolset() - assert len(tools) == len(mock_tools) - assert all( - a.name == b.name - and a.description == b.description - and a.args_schema == b.args_schema - for a, b in zip(tools, mock_tools) + @patch("toolbox_core.client.ToolboxClient.load_toolset") + async def test_aload_toolset(self, mock_core_aload_toolset, toolbox_client): + mock_core_tool_instance1 = AsyncMock( + spec=ToolboxCoreTool + ) # *** Use AsyncMock *** + mock_core_tool_instance1.__name__ = "mock-core-async-tool-0" + mock_core_tool_instance1.__doc__ = "desc 0" + mock_core_tool_instance1._pydantic_model = BaseModel + + mock_core_tool_instance2 = AsyncMock( + spec=ToolboxCoreTool + ) # *** Use AsyncMock *** + mock_core_tool_instance2.__name__ = "mock-core-async-tool-1" + mock_core_tool_instance2.__doc__ = "desc 1" + mock_core_tool_instance2._pydantic_model = BaseModel + + mock_core_aload_toolset.return_value = [ + mock_core_tool_instance1, + mock_core_tool_instance2, + ] + + langchain_tools = await toolbox_client.aload_toolset() + assert len(langchain_tools) == 2 + assert isinstance(langchain_tools[0], ToolboxTool) + assert isinstance(langchain_tools[1], ToolboxTool) + + toolbox_client._ToolboxClient__core_sync_client._async_client.load_toolset.assert_called_once_with( + name=None, auth_token_getters={}, bound_params={}, strict=False ) - mock_aload_toolset.assert_called_once_with(None, {}, None, None, {}, True) - - @patch("toolbox_langchain.client.AsyncToolboxClient.aload_tool") - def test_load_tool_with_args(self, mock_aload_tool, toolbox_client): - mock_tool = Mock(spec=ToolboxTool) - mock_tool.name = "mock-tool" - mock_tool.description = "mock description" - mock_tool.args_schema = BaseModel - mock_aload_tool.return_value = mock_tool + + @patch("toolbox_core.sync_client.ToolboxSyncClient.load_tool") + def test_load_tool_with_args(self, mock_core_load_tool, toolbox_client): + mock_core_sync_tool_instance = Mock(spec=ToolboxCoreSyncTool) + mock_core_sync_tool_instance.__name__ = "mock-tool" + mock_async_tool = Mock(spec=ToolboxCoreTool) + mock_async_tool._pydantic_model = BaseModel + mock_core_sync_tool_instance._async_tool = mock_async_tool + mock_core_load_tool.return_value = mock_core_sync_tool_instance + auth_token_getters = {"token_getter1": lambda: "value1"} - auth_tokens = {"token1": lambda: "value2"} - auth_headers = {"header1": lambda: "value3"} + auth_tokens_deprecated = {"token_deprecated": lambda: "value_dep"} + auth_headers_deprecated = {"header_deprecated": lambda: "value_head_dep"} bound_params = {"param1": "value4"} - tool = toolbox_client.load_tool( - "test_tool_name", + # Test case where auth_token_getters takes precedence + with pytest.warns(DeprecationWarning) as record: + tool = toolbox_client.load_tool( + "test_tool_name", + auth_token_getters=auth_token_getters, + auth_tokens=auth_tokens_deprecated, + auth_headers=auth_headers_deprecated, + bound_params=bound_params, + ) + # Expect two warnings: one for auth_tokens, one for auth_headers + assert len(record) == 2 + messages = [str(r.message) for r in record] + assert any("auth_tokens` is deprecated" in m for m in messages) + assert any("auth_headers` is deprecated" in m for m in messages) + + assert isinstance(tool, ToolboxTool) + mock_core_load_tool.assert_called_with( # Use called_with for flexibility if called multiple times in setup + name="test_tool_name", auth_token_getters=auth_token_getters, - auth_tokens=auth_tokens, - auth_headers=auth_headers, bound_params=bound_params, - strict=False, ) + mock_core_load_tool.reset_mock() # Reset for next test case + + # Test case where auth_tokens is used (auth_token_getters is None) + with pytest.warns(DeprecationWarning, match="auth_tokens` is deprecated"): + toolbox_client.load_tool( + "test_tool_name_2", + auth_tokens=auth_tokens_deprecated, + auth_headers=auth_headers_deprecated, # This will also warn + bound_params=bound_params, + ) + mock_core_load_tool.assert_called_with( + name="test_tool_name_2", + auth_token_getters=auth_tokens_deprecated, # auth_tokens becomes auth_token_getters + bound_params=bound_params, + ) + mock_core_load_tool.reset_mock() - assert tool.name == mock_tool.name - assert tool.description == mock_tool.description - assert tool.args_schema == mock_tool.args_schema - mock_aload_tool.assert_called_once_with( - "test_tool_name", - auth_token_getters, - auth_tokens, - auth_headers, - bound_params, - False, + # Test case where auth_headers is used (auth_token_getters and auth_tokens are None) + with pytest.warns(DeprecationWarning, match="auth_headers` is deprecated"): + toolbox_client.load_tool( + "test_tool_name_3", + auth_headers=auth_headers_deprecated, + bound_params=bound_params, + ) + mock_core_load_tool.assert_called_with( + name="test_tool_name_3", + auth_token_getters=auth_headers_deprecated, # auth_headers becomes auth_token_getters + bound_params=bound_params, ) - @patch("toolbox_langchain.client.AsyncToolboxClient.aload_toolset") - def test_load_toolset_with_args(self, mock_aload_toolset, toolbox_client): - mock_tools = [Mock(spec=ToolboxTool), Mock(spec=ToolboxTool)] - mock_tools[0].name = "mock-tool-0" - mock_tools[0].description = "mock description 0" - mock_tools[0].args_schema = BaseModel - mock_tools[1].name = "mock-tool-1" - mock_tools[1].description = "mock description 1" - mock_tools[1].args_schema = BaseModel - mock_aload_toolset.return_value = mock_tools + @patch("toolbox_core.sync_client.ToolboxSyncClient.load_toolset") + def test_load_toolset_with_args(self, mock_core_load_toolset, toolbox_client): + mock_core_sync_tool_instance = Mock(spec=ToolboxCoreSyncTool) + mock_core_sync_tool_instance.__name__ = "mock-tool-0" + mock_async_tool = Mock(spec=ToolboxCoreTool) + mock_async_tool._pydantic_model = BaseModel + mock_core_sync_tool_instance._async_tool = mock_async_tool + mock_core_load_toolset.return_value = [mock_core_sync_tool_instance] auth_token_getters = {"token_getter1": lambda: "value1"} - auth_tokens = {"token1": lambda: "value2"} - auth_headers = {"header1": lambda: "value3"} + auth_tokens_deprecated = {"token_deprecated": lambda: "value_dep"} + auth_headers_deprecated = {"header_deprecated": lambda: "value_head_dep"} bound_params = {"param1": "value4"} - tools = toolbox_client.load_toolset( - toolset_name="my_toolset", + with pytest.warns(DeprecationWarning) as record: # Expect 2 warnings + tools = toolbox_client.load_toolset( + toolset_name="my_toolset", + auth_token_getters=auth_token_getters, + auth_tokens=auth_tokens_deprecated, + auth_headers=auth_headers_deprecated, + bound_params=bound_params, + strict=False, + ) + assert len(record) == 2 + messages = [str(r.message) for r in record] + assert any("auth_tokens` is deprecated" in m for m in messages) + assert any("auth_headers` is deprecated" in m for m in messages) + + assert len(tools) == 1 + mock_core_load_toolset.assert_called_with( + name="my_toolset", auth_token_getters=auth_token_getters, - auth_tokens=auth_tokens, - auth_headers=auth_headers, bound_params=bound_params, strict=False, ) - assert len(tools) == len(mock_tools) - assert all( - a.name == b.name - and a.description == b.description - and a.args_schema == b.args_schema - for a, b in zip(tools, mock_tools) - ) - mock_aload_toolset.assert_called_once_with( - "my_toolset", - auth_token_getters, - auth_tokens, - auth_headers, - bound_params, - False, - ) - @pytest.mark.asyncio - @patch("toolbox_langchain.client.AsyncToolboxClient.aload_tool") - async def test_aload_tool_with_args(self, mock_aload_tool, toolbox_client): - mock_tool = Mock(spec=ToolboxTool) - mock_tool.name = "mock-tool" - mock_tool.description = "mock description" - mock_tool.args_schema = BaseModel - mock_aload_tool.return_value = mock_tool + @patch("toolbox_core.client.ToolboxClient.load_tool") + async def test_aload_tool_with_args(self, mock_core_aload_tool, toolbox_client): + mock_core_tool_instance = AsyncMock(spec=ToolboxCoreTool) + mock_core_tool_instance.__name__ = "mock-tool" + mock_core_tool_instance._pydantic_model = BaseModel + mock_core_aload_tool.return_value = mock_core_tool_instance auth_token_getters = {"token_getter1": lambda: "value1"} - auth_tokens = {"token1": lambda: "value2"} - auth_headers = {"header1": lambda: "value3"} + auth_tokens_deprecated = {"token_deprecated": lambda: "value_dep"} + auth_headers_deprecated = {"header_deprecated": lambda: "value_head_dep"} bound_params = {"param1": "value4"} - tool = await toolbox_client.aload_tool( - "test_tool", - auth_token_getters, - auth_tokens, - auth_headers, - bound_params, - False, - ) - assert tool.name == mock_tool.name - assert tool.description == mock_tool.description - assert tool.args_schema == mock_tool.args_schema - mock_aload_tool.assert_called_once_with( - "test_tool", - auth_token_getters, - auth_tokens, - auth_headers, - bound_params, - False, + with pytest.warns(DeprecationWarning) as record: # Expect 2 warnings + tool = await toolbox_client.aload_tool( + "test_tool", + auth_token_getters=auth_token_getters, + auth_tokens=auth_tokens_deprecated, + auth_headers=auth_headers_deprecated, + bound_params=bound_params, + ) + assert len(record) == 2 + messages = [str(r.message) for r in record] + assert any("auth_tokens` is deprecated" in m for m in messages) + assert any("auth_headers` is deprecated" in m for m in messages) + + assert isinstance(tool, ToolboxTool) + toolbox_client._ToolboxClient__core_sync_client._async_client.load_tool.assert_called_with( + name="test_tool", + auth_token_getters=auth_token_getters, + bound_params=bound_params, ) @pytest.mark.asyncio - @patch("toolbox_langchain.client.AsyncToolboxClient.aload_toolset") - async def test_aload_toolset_with_args(self, mock_aload_toolset, toolbox_client): - mock_tools = [Mock(spec=ToolboxTool), Mock(spec=ToolboxTool)] - mock_tools[0].name = "mock-tool-0" - mock_tools[0].description = "mock description 0" - mock_tools[0].args_schema = BaseModel - mock_tools[1].name = "mock-tool-1" - mock_tools[1].description = "mock description 1" - mock_tools[1].args_schema = BaseModel - mock_aload_toolset.return_value = mock_tools + @patch("toolbox_core.client.ToolboxClient.load_toolset") + async def test_aload_toolset_with_args( + self, mock_core_aload_toolset, toolbox_client + ): + mock_core_tool_instance = AsyncMock(spec=ToolboxCoreTool) + mock_core_tool_instance.__name__ = "mock-tool-0" + mock_core_tool_instance._pydantic_model = BaseModel + mock_core_aload_toolset.return_value = [mock_core_tool_instance] auth_token_getters = {"token_getter1": lambda: "value1"} - auth_tokens = {"token1": lambda: "value2"} - auth_headers = {"header1": lambda: "value3"} + auth_tokens_deprecated = {"token_deprecated": lambda: "value_dep"} + auth_headers_deprecated = {"header_deprecated": lambda: "value_head_dep"} bound_params = {"param1": "value4"} - tools = await toolbox_client.aload_toolset( - "my_toolset", - auth_token_getters, - auth_tokens, - auth_headers, - bound_params, - False, - ) - assert len(tools) == len(mock_tools) - assert all( - a.name == b.name - and a.description == b.description - and a.args_schema == b.args_schema - for a, b in zip(tools, mock_tools) - ) - mock_aload_toolset.assert_called_once_with( - "my_toolset", - auth_token_getters, - auth_tokens, - auth_headers, - bound_params, - False, + with pytest.warns(DeprecationWarning) as record: # Expect 2 warnings + tools = await toolbox_client.aload_toolset( + "my_toolset", + auth_token_getters=auth_token_getters, + auth_tokens=auth_tokens_deprecated, + auth_headers=auth_headers_deprecated, + bound_params=bound_params, + strict=False, + ) + assert len(record) == 2 + messages = [str(r.message) for r in record] + assert any("auth_tokens` is deprecated" in m for m in messages) + assert any("auth_headers` is deprecated" in m for m in messages) + + assert len(tools) == 1 + toolbox_client._ToolboxClient__core_sync_client._async_client.load_toolset.assert_called_with( + name="my_toolset", + auth_token_getters=auth_token_getters, + bound_params=bound_params, + strict=False, ) diff --git a/packages/toolbox-langchain/tests/test_tools.py b/packages/toolbox-langchain/tests/test_tools.py index 751005af..090f0f55 100644 --- a/packages/toolbox-langchain/tests/test_tools.py +++ b/packages/toolbox-langchain/tests/test_tools.py @@ -16,17 +16,17 @@ import pytest from pydantic import BaseModel +from toolbox_core.sync_tool import ToolboxSyncTool as ToolboxCoreSyncTool +from toolbox_core.tool import ToolboxTool as ToolboxCoreTool -from toolbox_langchain.async_tools import AsyncToolboxTool from toolbox_langchain.tools import ToolboxTool class TestToolboxTool: @pytest.fixture - def tool_schema(self): + def tool_schema_dict(self): return { "description": "Test Tool Description", - "name": "test_tool", "parameters": [ {"name": "param1", "type": "string", "description": "Param 1"}, {"name": "param2", "type": "integer", "description": "Param 2"}, @@ -34,10 +34,10 @@ def tool_schema(self): } @pytest.fixture - def auth_tool_schema(self): + def auth_tool_schema_dict(self): return { - "description": "Test Tool Description", - "name": "test_tool", + "description": "Test Auth Tool Description", + "authRequired": ["test-auth-source"], "parameters": [ { "name": "param1", @@ -50,62 +50,66 @@ def auth_tool_schema(self): } @pytest.fixture(scope="function") - def mock_async_tool(self, tool_schema): - mock_async_tool = Mock(spec=AsyncToolboxTool) - mock_async_tool.name = "test_tool" - mock_async_tool.description = "test description" - mock_async_tool.args_schema = BaseModel - mock_async_tool._AsyncToolboxTool__name = "test_tool" - mock_async_tool._AsyncToolboxTool__schema = tool_schema - mock_async_tool._AsyncToolboxTool__url = "http://test_url" - mock_async_tool._AsyncToolboxTool__session = Mock() - mock_async_tool._AsyncToolboxTool__auth_token_getters = {} - mock_async_tool._AsyncToolboxTool__bound_params = {} - return mock_async_tool + def mock_core_async_tool(self, tool_schema_dict): + mock = Mock(spec=ToolboxCoreTool) + mock.__name__ = "test_tool" + mock.__doc__ = tool_schema_dict["description"] + mock._pydantic_model = BaseModel + return mock @pytest.fixture(scope="function") - def mock_async_auth_tool(self, auth_tool_schema): - mock_async_tool = Mock(spec=AsyncToolboxTool) - mock_async_tool.name = "test_tool" - mock_async_tool.description = "test description" - mock_async_tool.args_schema = BaseModel - mock_async_tool._AsyncToolboxTool__name = "test_tool" - mock_async_tool._AsyncToolboxTool__schema = auth_tool_schema - mock_async_tool._AsyncToolboxTool__url = "http://test_url" - mock_async_tool._AsyncToolboxTool__session = Mock() - mock_async_tool._AsyncToolboxTool__auth_token_getters = {} - mock_async_tool._AsyncToolboxTool__bound_params = {} - return mock_async_tool + def mock_core_async_auth_tool(self, auth_tool_schema_dict): + mock = Mock(spec=ToolboxCoreTool) + mock.__name__ = "test_auth_tool" + mock.__doc__ = auth_tool_schema_dict["description"] + mock._pydantic_model = BaseModel + return mock @pytest.fixture - def toolbox_tool(self, mock_async_tool): - return ToolboxTool( - async_tool=mock_async_tool, - loop=Mock(), - thread=Mock(), - ) + def mock_core_sync_tool(self, mock_core_async_tool): + sync_mock = Mock(spec=ToolboxCoreSyncTool) + sync_mock.__name__ = mock_core_async_tool.__name__ + sync_mock.__doc__ = mock_core_async_tool.__doc__ + sync_mock._async_tool = mock_core_async_tool + sync_mock.add_auth_token_getters = Mock(return_value=sync_mock) + sync_mock.bind_params = Mock(return_value=sync_mock) + sync_mock.bind_param = Mock( + return_value=sync_mock + ) # Keep this if bind_param exists on core, otherwise remove + sync_mock.__call__ = Mock(return_value="mocked_sync_call_result") + return sync_mock @pytest.fixture - def auth_toolbox_tool(self, mock_async_auth_tool): - return ToolboxTool( - async_tool=mock_async_auth_tool, - loop=Mock(), - thread=Mock(), - ) + def mock_core_sync_auth_tool(self, mock_core_async_auth_tool): + sync_mock = Mock(spec=ToolboxCoreSyncTool) + sync_mock.__name__ = mock_core_async_auth_tool.__name__ + sync_mock.__doc__ = mock_core_async_auth_tool.__doc__ + sync_mock._async_tool = mock_core_async_auth_tool + sync_mock.add_auth_token_getters = Mock(return_value=sync_mock) + sync_mock.bind_params = Mock(return_value=sync_mock) + sync_mock.bind_param = Mock( + return_value=sync_mock + ) # Keep this if bind_param exists on core + sync_mock.__call__ = Mock(return_value="mocked_auth_sync_call_result") + return sync_mock - def test_toolbox_tool_init(self, mock_async_tool): - tool = ToolboxTool( - async_tool=mock_async_tool, - loop=Mock(), - thread=Mock(), - ) - async_tool = tool._ToolboxTool__async_tool - assert async_tool.name == mock_async_tool.name - assert async_tool.description == mock_async_tool.description - assert async_tool.args_schema == mock_async_tool.args_schema + @pytest.fixture + def toolbox_tool(self, mock_core_sync_tool): + return ToolboxTool(core_sync_tool=mock_core_sync_tool) + + @pytest.fixture + def auth_toolbox_tool(self, mock_core_sync_auth_tool): + return ToolboxTool(core_sync_tool=mock_core_sync_auth_tool) + + def test_toolbox_tool_init(self, mock_core_sync_tool): + tool = ToolboxTool(core_sync_tool=mock_core_sync_tool) + core_sync_tool_in_tool = tool._ToolboxTool__core_sync_tool + assert core_sync_tool_in_tool.__name__ == mock_core_sync_tool.__name__ + assert core_sync_tool_in_tool.__doc__ == mock_core_sync_tool.__doc__ + assert tool.args_schema == mock_core_sync_tool._async_tool._pydantic_model @pytest.mark.parametrize( - "params, expected_bound_params", + "params, expected_bound_params_on_core", [ ({"param1": "bound-value"}, {"param1": "bound-value"}), ({"param1": lambda: "bound-value"}, {"param1": lambda: "bound-value"}), @@ -118,44 +122,35 @@ def test_toolbox_tool_init(self, mock_async_tool): def test_toolbox_tool_bind_params( self, params, - expected_bound_params, + expected_bound_params_on_core, toolbox_tool, - mock_async_tool, + mock_core_sync_tool, ): - mock_async_tool._AsyncToolboxTool__bound_params = expected_bound_params - mock_async_tool.bind_params.return_value = mock_async_tool - - tool = toolbox_tool.bind_params(params) - mock_async_tool.bind_params.assert_called_once_with(params, True) - assert isinstance(tool, ToolboxTool) - - for key, value in expected_bound_params.items(): - async_tool_bound_param_val = ( - tool._ToolboxTool__async_tool._AsyncToolboxTool__bound_params[key] - ) - if callable(value): - assert value() == async_tool_bound_param_val() - else: - assert value == async_tool_bound_param_val - - def test_toolbox_tool_bind_param(self, mock_async_tool, toolbox_tool): - expected_bound_param = {"param1": "bound-value"} - mock_async_tool._AsyncToolboxTool__bound_params = expected_bound_param - mock_async_tool.bind_param.return_value = mock_async_tool - - tool = toolbox_tool.bind_param("param1", "bound-value") - mock_async_tool.bind_param.assert_called_once_with( - "param1", "bound-value", True + mock_core_sync_tool.bind_params.return_value = mock_core_sync_tool + new_langchain_tool = toolbox_tool.bind_params(params) + mock_core_sync_tool.bind_params.assert_called_once_with(params) + assert isinstance(new_langchain_tool, ToolboxTool) + assert ( + new_langchain_tool._ToolboxTool__core_sync_tool + == mock_core_sync_tool.bind_params.return_value ) + def test_toolbox_tool_bind_param(self, toolbox_tool, mock_core_sync_tool): + # ToolboxTool.bind_param calls core_sync_tool.bind_params + mock_core_sync_tool.bind_params.return_value = mock_core_sync_tool + new_langchain_tool = toolbox_tool.bind_param("param1", "bound-value") + # *** Fix: Assert that bind_params is called on the core tool *** + mock_core_sync_tool.bind_params.assert_called_once_with( + {"param1": "bound-value"} + ) + assert isinstance(new_langchain_tool, ToolboxTool) assert ( - tool._ToolboxTool__async_tool._AsyncToolboxTool__bound_params - == expected_bound_param + new_langchain_tool._ToolboxTool__core_sync_tool + == mock_core_sync_tool.bind_params.return_value ) - assert isinstance(tool, ToolboxTool) @pytest.mark.parametrize( - "auth_token_getters, expected_auth_token_getters", + "auth_token_getters, expected_auth_getters_on_core", [ ( {"test-auth-source": lambda: "test-token"}, @@ -176,63 +171,44 @@ def test_toolbox_tool_bind_param(self, mock_async_tool, toolbox_tool): def test_toolbox_tool_add_auth_token_getters( self, auth_token_getters, - expected_auth_token_getters, - mock_async_auth_tool, + expected_auth_getters_on_core, auth_toolbox_tool, + mock_core_sync_auth_tool, ): - auth_toolbox_tool._ToolboxTool__async_tool._AsyncToolboxTool__auth_token_getters = ( - expected_auth_token_getters + mock_core_sync_auth_tool.add_auth_token_getters.return_value = ( + mock_core_sync_auth_tool ) - auth_toolbox_tool._ToolboxTool__async_tool.add_auth_token_getters.return_value = ( - mock_async_auth_tool + new_langchain_tool = auth_toolbox_tool.add_auth_token_getters( + auth_token_getters ) - - tool = auth_toolbox_tool.add_auth_token_getters(auth_token_getters) - mock_async_auth_tool.add_auth_token_getters.assert_called_once_with( - auth_token_getters, True + mock_core_sync_auth_tool.add_auth_token_getters.assert_called_once_with( + auth_token_getters + ) + assert isinstance(new_langchain_tool, ToolboxTool) + assert ( + new_langchain_tool._ToolboxTool__core_sync_tool + == mock_core_sync_auth_tool.add_auth_token_getters.return_value ) - for source, getter in expected_auth_token_getters.items(): - assert ( - tool._ToolboxTool__async_tool._AsyncToolboxTool__auth_token_getters[ - source - ]() - == getter() - ) - assert isinstance(tool, ToolboxTool) def test_toolbox_tool_add_auth_token_getter( - self, mock_async_auth_tool, auth_toolbox_tool + self, auth_toolbox_tool, mock_core_sync_auth_tool ): get_id_token = lambda: "test-token" - expected_auth_token_getters = {"test-auth-source": get_id_token} - auth_toolbox_tool._ToolboxTool__async_tool._AsyncToolboxTool__auth_token_getters = ( - expected_auth_token_getters - ) - auth_toolbox_tool._ToolboxTool__async_tool.add_auth_token_getter.return_value = ( - mock_async_auth_tool + # ToolboxTool.add_auth_token_getter calls core_sync_tool.add_auth_token_getters + mock_core_sync_auth_tool.add_auth_token_getters.return_value = ( + mock_core_sync_auth_tool ) - tool = auth_toolbox_tool.add_auth_token_getter("test-auth-source", get_id_token) - mock_async_auth_tool.add_auth_token_getter.assert_called_once_with( - "test-auth-source", get_id_token, True + new_langchain_tool = auth_toolbox_tool.add_auth_token_getter( + "test-auth-source", get_id_token ) - assert ( - tool._ToolboxTool__async_tool._AsyncToolboxTool__auth_token_getters[ - "test-auth-source" - ]() - == "test-token" - ) - assert isinstance(tool, ToolboxTool) - - def test_toolbox_tool_validate_auth_strict(self, auth_toolbox_tool): - auth_toolbox_tool._ToolboxTool__async_tool._arun = Mock( - side_effect=PermissionError( - "Parameter(s) `param1` of tool test_tool require authentication" - ) + # *** Fix: Assert that add_auth_token_getters is called on the core tool *** + mock_core_sync_auth_tool.add_auth_token_getters.assert_called_once_with( + {"test-auth-source": get_id_token} ) - with pytest.raises(PermissionError) as e: - auth_toolbox_tool._run() - assert "Parameter(s) `param1` of tool test_tool require authentication" in str( - e.value + assert isinstance(new_langchain_tool, ToolboxTool) + assert ( + new_langchain_tool._ToolboxTool__core_sync_tool + == mock_core_sync_auth_tool.add_auth_token_getters.return_value ) From 8e2ed2964601715a8301219cd5c29c48a8434959 Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Sat, 10 May 2025 10:08:39 +0530 Subject: [PATCH 34/53] docs: Improve docstrings --- packages/toolbox-langchain/src/toolbox_langchain/async_tools.py | 2 +- packages/toolbox-langchain/src/toolbox_langchain/tools.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py b/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py index 06fdc6fc..8bbcf500 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py @@ -99,7 +99,7 @@ def add_auth_token_getter( Returns: A new ToolboxTool instance that is a deep copy of the current - instance, with added auth token. + instance, with added auth token getter. Raises: ValueError: If the provided auth parameter is already registered. diff --git a/packages/toolbox-langchain/src/toolbox_langchain/tools.py b/packages/toolbox-langchain/src/toolbox_langchain/tools.py index d18529ea..29a99bff 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/tools.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/tools.py @@ -97,7 +97,7 @@ def add_auth_token_getter( Returns: A new ToolboxTool instance that is a deep copy of the current - instance, with added auth token. + instance, with added auth token getter. Raises: ValueError: If the provided auth parameter is already registered. From 3bccd7abed8ebfc4aa10a6e63fea5205aa273fe2 Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Sat, 10 May 2025 10:30:58 +0530 Subject: [PATCH 35/53] chore: Add TODO note --- packages/toolbox-langchain/pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/packages/toolbox-langchain/pyproject.toml b/packages/toolbox-langchain/pyproject.toml index 5cf99cc6..c052ef2e 100644 --- a/packages/toolbox-langchain/pyproject.toml +++ b/packages/toolbox-langchain/pyproject.toml @@ -9,6 +9,7 @@ authors = [ {name = "Google LLC", email = "googleapis-packages@google.com"} ] dependencies = [ + # TODO: Replace with actual package dependencies (eg. "toolbox-core>=0.2.0,<1.0.0") "toolbox-core @ git+https://github.com/googleapis/mcp-toolbox-sdk-python.git@anubhav-lc-wraps-core#subdirectory=packages/toolbox-core", "langchain-core>=0.2.23,<1.0.0", "PyYAML>=6.0.1,<7.0.0", From b169c4596857e6dc7eb9bc9bd51f8235bd451b1b Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Sat, 10 May 2025 10:31:24 +0530 Subject: [PATCH 36/53] chore: Improve TODO note --- packages/toolbox-langchain/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/toolbox-langchain/pyproject.toml b/packages/toolbox-langchain/pyproject.toml index c052ef2e..629745c9 100644 --- a/packages/toolbox-langchain/pyproject.toml +++ b/packages/toolbox-langchain/pyproject.toml @@ -9,7 +9,7 @@ authors = [ {name = "Google LLC", email = "googleapis-packages@google.com"} ] dependencies = [ - # TODO: Replace with actual package dependencies (eg. "toolbox-core>=0.2.0,<1.0.0") + # TODO: Replace with actual package dependency (eg. "toolbox-core>=0.2.0,<1.0.0") "toolbox-core @ git+https://github.com/googleapis/mcp-toolbox-sdk-python.git@anubhav-lc-wraps-core#subdirectory=packages/toolbox-core", "langchain-core>=0.2.23,<1.0.0", "PyYAML>=6.0.1,<7.0.0", From 28c62abbd11e14ea3450b39e51a73aec4acf8ed0 Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Wed, 14 May 2025 18:40:45 +0530 Subject: [PATCH 37/53] fix: Fix integration test --- .../tests/test_async_tools.py | 60 +++++++++---------- 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/packages/toolbox-langchain/tests/test_async_tools.py b/packages/toolbox-langchain/tests/test_async_tools.py index 88efcd05..d63b90a2 100644 --- a/packages/toolbox-langchain/tests/test_async_tools.py +++ b/packages/toolbox-langchain/tests/test_async_tools.py @@ -304,46 +304,46 @@ async def test_toolbox_tool_call_with_auth_tokens(self, auth_toolbox_tool): async def test_toolbox_tool_call_with_auth_tokens_insecure( self, auth_toolbox_tool, auth_tool_schema_dict - ): # Add auth_tool_schema_dict fixture + ): core_tool_of_auth_tool = auth_toolbox_tool._AsyncToolboxTool__core_tool mock_session = core_tool_of_auth_tool._ToolboxTool__session - # *** Fix: Use the injected fixture value auth_tool_schema_dict *** - insecure_core_tool = self._create_core_tool_from_dict( - session=mock_session, - name="test_tool", - schema_dict=auth_tool_schema_dict, # Use the fixture value here - url="http://test-url", - ) - insecure_auth_langchain_tool = AsyncToolboxTool(core_tool=insecure_core_tool) - with pytest.warns( UserWarning, match="Sending ID token over HTTP. User data may be exposed. Use HTTPS for secure communication.", ): - tool_with_getter = insecure_auth_langchain_tool.add_auth_token_getters( - {"test-auth-source": lambda: "test-token"} + insecure_core_tool = self._create_core_tool_from_dict( + session=mock_session, + name="test_tool", + schema_dict=auth_tool_schema_dict, + url="http://test-url", ) - result = await tool_with_getter.ainvoke({"param2": 123}) - assert result == "test-result" - modified_core_tool_in_new_tool = ( - tool_with_getter._AsyncToolboxTool__core_tool - ) - assert ( - modified_core_tool_in_new_tool._ToolboxTool__base_url - == "http://test-url" - ) - assert ( - modified_core_tool_in_new_tool._ToolboxTool__url - == "http://test-url/api/tool/test_tool/invoke" - ) + insecure_auth_langchain_tool = AsyncToolboxTool(core_tool=insecure_core_tool) - modified_core_tool_in_new_tool._ToolboxTool__session.post.assert_called_once_with( - "http://test-url/api/tool/test_tool/invoke", - json={"param2": 123}, - headers={"test-auth-source_token": "test-token"}, - ) + tool_with_getter = insecure_auth_langchain_tool.add_auth_token_getters( + {"test-auth-source": lambda: "test-token"} + ) + result = await tool_with_getter.ainvoke({"param2": 123}) + assert result == "test-result" + + modified_core_tool_in_new_tool = ( + tool_with_getter._AsyncToolboxTool__core_tool + ) + assert ( + modified_core_tool_in_new_tool._ToolboxTool__base_url + == "http://test-url" + ) + assert ( + modified_core_tool_in_new_tool._ToolboxTool__url + == "http://test-url/api/tool/test_tool/invoke" + ) + + modified_core_tool_in_new_tool._ToolboxTool__session.post.assert_called_once_with( + "http://test-url/api/tool/test_tool/invoke", + json={"param2": 123}, + headers={"test-auth-source_token": "test-token"}, + ) async def test_toolbox_tool_call_with_invalid_input(self, toolbox_tool): with pytest.raises(ValidationError) as e: From d5c57df1681669ba3d8cb4f7d4dcf484e10c5798 Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Wed, 14 May 2025 18:42:26 +0530 Subject: [PATCH 38/53] chore: Delint --- packages/toolbox-langchain/tests/test_async_tools.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/packages/toolbox-langchain/tests/test_async_tools.py b/packages/toolbox-langchain/tests/test_async_tools.py index d63b90a2..96bd7660 100644 --- a/packages/toolbox-langchain/tests/test_async_tools.py +++ b/packages/toolbox-langchain/tests/test_async_tools.py @@ -327,12 +327,9 @@ async def test_toolbox_tool_call_with_auth_tokens_insecure( result = await tool_with_getter.ainvoke({"param2": 123}) assert result == "test-result" - modified_core_tool_in_new_tool = ( - tool_with_getter._AsyncToolboxTool__core_tool - ) + modified_core_tool_in_new_tool = tool_with_getter._AsyncToolboxTool__core_tool assert ( - modified_core_tool_in_new_tool._ToolboxTool__base_url - == "http://test-url" + modified_core_tool_in_new_tool._ToolboxTool__base_url == "http://test-url" ) assert ( modified_core_tool_in_new_tool._ToolboxTool__url From ba29e2d635d4360d36a71e43b7f9262dc4a92507 Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Wed, 14 May 2025 18:52:57 +0530 Subject: [PATCH 39/53] chore: Rename internal member variable names to be more concise --- .../src/toolbox_langchain/client.py | 38 +++++----- .../src/toolbox_langchain/tools.py | 28 ++++---- .../toolbox-langchain/tests/test_client.py | 70 +++++++++---------- packages/toolbox-langchain/tests/test_e2e.py | 12 ++-- .../toolbox-langchain/tests/test_tools.py | 48 ++++++------- 5 files changed, 98 insertions(+), 98 deletions(-) diff --git a/packages/toolbox-langchain/src/toolbox_langchain/client.py b/packages/toolbox-langchain/src/toolbox_langchain/client.py index 72317fd3..d26eede8 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/client.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/client.py @@ -34,7 +34,7 @@ def __init__( Args: url: The base URL of the Toolbox service. """ - self.__core_sync_client = ToolboxCoreSyncClient(url=url) + self.__core_client = ToolboxCoreSyncClient(url=url) async def aload_tool( self, @@ -85,30 +85,30 @@ async def aload_tool( ) auth_token_getters = auth_headers - coro = self.__core_sync_client._async_client.load_tool( + coro = self.__core_client._async_client.load_tool( name=tool_name, auth_token_getters=auth_token_getters, bound_params=bound_params, ) - if not self.__core_sync_client._loop: + if not self.__core_client._loop: # If a loop has not been provided, attempt to run in current thread. core_tool = await coro else: # Otherwise, run in the background thread. core_tool = await asyncio.wrap_future( - asyncio.run_coroutine_threadsafe(coro, self.__core_sync_client._loop) + asyncio.run_coroutine_threadsafe(coro, self.__core_client._loop) ) - if not self.__core_sync_client._loop or not self.__core_sync_client._thread: + if not self.__core_client._loop or not self.__core_client._thread: raise ValueError("Background loop or thread cannot be None.") core_sync_tool = ToolboxSyncTool( core_tool, - self.__core_sync_client._loop, - self.__core_sync_client._thread, + self.__core_client._loop, + self.__core_client._thread, ) - return ToolboxTool(core_sync_tool=core_sync_tool) + return ToolboxTool(core_tool=core_sync_tool) async def aload_toolset( self, @@ -167,36 +167,36 @@ async def aload_toolset( ) auth_token_getters = auth_headers - coro = self.__core_sync_client._async_client.load_toolset( + coro = self.__core_client._async_client.load_toolset( name=toolset_name, auth_token_getters=auth_token_getters, bound_params=bound_params, strict=strict, ) - if not self.__core_sync_client._loop: + if not self.__core_client._loop: # If a loop has not been provided, attempt to run in current thread. core_tools = await coro else: # Otherwise, run in the background thread. core_tools = await asyncio.wrap_future( - asyncio.run_coroutine_threadsafe(coro, self.__core_sync_client._loop) + asyncio.run_coroutine_threadsafe(coro, self.__core_client._loop) ) - if not self.__core_sync_client._loop or not self.__core_sync_client._thread: + if not self.__core_client._loop or not self.__core_client._thread: raise ValueError("Background loop or thread cannot be None.") core_sync_tools = [ ToolboxSyncTool( core_tool, - self.__core_sync_client._loop, - self.__core_sync_client._thread, + self.__core_client._loop, + self.__core_client._thread, ) for core_tool in core_tools ] tools = [] for core_sync_tool in core_sync_tools: - tools.append(ToolboxTool(core_sync_tool=core_sync_tool)) + tools.append(ToolboxTool(core_tool=core_sync_tool)) return tools def load_tool( @@ -248,12 +248,12 @@ def load_tool( ) auth_token_getters = auth_headers - core_sync_tool = self.__core_sync_client.load_tool( + core_sync_tool = self.__core_client.load_tool( name=tool_name, auth_token_getters=auth_token_getters, bound_params=bound_params, ) - return ToolboxTool(core_sync_tool=core_sync_tool) + return ToolboxTool(core_tool=core_sync_tool) def load_toolset( self, @@ -312,7 +312,7 @@ def load_toolset( ) auth_token_getters = auth_headers - core_sync_tools = self.__core_sync_client.load_toolset( + core_sync_tools = self.__core_client.load_toolset( name=toolset_name, auth_token_getters=auth_token_getters, bound_params=bound_params, @@ -321,5 +321,5 @@ def load_toolset( tools = [] for core_sync_tool in core_sync_tools: - tools.append(ToolboxTool(core_sync_tool=core_sync_tool)) + tools.append(ToolboxTool(core_tool=core_sync_tool)) return tools diff --git a/packages/toolbox-langchain/src/toolbox_langchain/tools.py b/packages/toolbox-langchain/src/toolbox_langchain/tools.py index 29a99bff..659c5985 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/tools.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/tools.py @@ -27,37 +27,37 @@ class ToolboxTool(BaseTool): def __init__( self, - core_sync_tool: ToolboxCoreSyncTool, + core_tool: ToolboxCoreSyncTool, ) -> None: """ Initializes a ToolboxTool instance. Args: - core_sync_tool: The underlying core sync ToolboxTool instance. + core_tool: The underlying core sync ToolboxTool instance. """ # Due to how pydantic works, we must initialize the underlying # BaseTool class before assigning values to member variables. super().__init__( - name=core_sync_tool.__name__, - description=core_sync_tool.__doc__, - args_schema=core_sync_tool._async_tool._pydantic_model, + name=core_tool.__name__, + description=core_tool.__doc__, + args_schema=core_tool._async_tool._pydantic_model, ) - self.__core_sync_tool = core_sync_tool + self.__core_tool = core_tool def _run(self, **kwargs: Any) -> str: - return self.__core_sync_tool(**kwargs) + return self.__core_tool(**kwargs) async def _arun(self, **kwargs: Any) -> str: - coro = self.__core_sync_tool._async_tool(**kwargs) + coro = self.__core_tool._async_tool(**kwargs) # If a loop has not been provided, attempt to run in current thread. - if not self.__core_sync_tool._loop: + if not self.__core_tool._loop: return await coro # Otherwise, run in the background thread. return await asyncio.wrap_future( - asyncio.run_coroutine_threadsafe(coro, self.__core_sync_tool._loop) + asyncio.run_coroutine_threadsafe(coro, self.__core_tool._loop) ) def add_auth_token_getters( @@ -79,10 +79,10 @@ def add_auth_token_getters( ValueError: If any of the provided auth parameters is already registered. """ - new_core_sync_tool = self.__core_sync_tool.add_auth_token_getters( + new_core_tool = self.__core_tool.add_auth_token_getters( auth_token_getters ) - return ToolboxTool(core_sync_tool=new_core_sync_tool) + return ToolboxTool(core_tool=new_core_tool) def add_auth_token_getter( self, auth_source: str, get_id_token: Callable[[], str] @@ -123,8 +123,8 @@ def bind_params( Raises: ValueError: If any of the provided bound params is already bound. """ - new_core_sync_tool = self.__core_sync_tool.bind_params(bound_params) - return ToolboxTool(core_sync_tool=new_core_sync_tool) + new_core_tool = self.__core_tool.bind_params(bound_params) + return ToolboxTool(core_tool=new_core_tool) def bind_param( self, diff --git a/packages/toolbox-langchain/tests/test_client.py b/packages/toolbox-langchain/tests/test_client.py index bae8123e..d7eb62a8 100644 --- a/packages/toolbox-langchain/tests/test_client.py +++ b/packages/toolbox-langchain/tests/test_client.py @@ -30,35 +30,35 @@ class TestToolboxClient: def toolbox_client(self): client = ToolboxClient(URL) assert isinstance(client, ToolboxClient) - assert client._ToolboxClient__core_sync_client is not None - assert client._ToolboxClient__core_sync_client._async_client is not None - assert client._ToolboxClient__core_sync_client._loop is not None - assert client._ToolboxClient__core_sync_client._loop.is_running() - assert client._ToolboxClient__core_sync_client._thread is not None - assert client._ToolboxClient__core_sync_client._thread.is_alive() + assert client._ToolboxClient__core_client is not None + assert client._ToolboxClient__core_client._async_client is not None + assert client._ToolboxClient__core_client._loop is not None + assert client._ToolboxClient__core_client._loop.is_running() + assert client._ToolboxClient__core_client._thread is not None + assert client._ToolboxClient__core_client._thread.is_alive() return client @patch("toolbox_core.sync_client.ToolboxSyncClient.load_tool") def test_load_tool(self, mock_core_load_tool, toolbox_client): - mock_core_sync_tool_instance = Mock( + mock_core_tool_instance = Mock( spec=ToolboxCoreSyncTool ) # Spec with Core Sync Tool - mock_core_sync_tool_instance.__name__ = "mock-core-sync-tool" - mock_core_sync_tool_instance.__doc__ = "mock core sync description" + mock_core_tool_instance.__name__ = "mock-core-sync-tool" + mock_core_tool_instance.__doc__ = "mock core sync description" mock_underlying_async_tool = Mock( spec=ToolboxCoreTool ) # Core Async Tool for pydantic model mock_underlying_async_tool._pydantic_model = BaseModel - mock_core_sync_tool_instance._async_tool = mock_underlying_async_tool + mock_core_tool_instance._async_tool = mock_underlying_async_tool - mock_core_load_tool.return_value = mock_core_sync_tool_instance + mock_core_load_tool.return_value = mock_core_tool_instance langchain_tool = toolbox_client.load_tool("test_tool") assert isinstance(langchain_tool, ToolboxTool) - assert langchain_tool.name == mock_core_sync_tool_instance.__name__ - assert langchain_tool.description == mock_core_sync_tool_instance.__doc__ + assert langchain_tool.name == mock_core_tool_instance.__name__ + assert langchain_tool.description == mock_core_tool_instance.__doc__ assert langchain_tool.args_schema == mock_underlying_async_tool._pydantic_model mock_core_load_tool.assert_called_once_with( @@ -67,23 +67,23 @@ def test_load_tool(self, mock_core_load_tool, toolbox_client): @patch("toolbox_core.sync_client.ToolboxSyncClient.load_toolset") def test_load_toolset(self, mock_core_load_toolset, toolbox_client): - mock_core_sync_tool_instance1 = Mock(spec=ToolboxCoreSyncTool) - mock_core_sync_tool_instance1.__name__ = "mock-core-sync-tool-0" - mock_core_sync_tool_instance1.__doc__ = "desc 0" + mock_core_tool_instance1 = Mock(spec=ToolboxCoreSyncTool) + mock_core_tool_instance1.__name__ = "mock-core-sync-tool-0" + mock_core_tool_instance1.__doc__ = "desc 0" mock_async_tool0 = Mock(spec=ToolboxCoreTool) mock_async_tool0._pydantic_model = BaseModel - mock_core_sync_tool_instance1._async_tool = mock_async_tool0 + mock_core_tool_instance1._async_tool = mock_async_tool0 - mock_core_sync_tool_instance2 = Mock(spec=ToolboxCoreSyncTool) - mock_core_sync_tool_instance2.__name__ = "mock-core-sync-tool-1" - mock_core_sync_tool_instance2.__doc__ = "desc 1" + mock_core_tool_instance2 = Mock(spec=ToolboxCoreSyncTool) + mock_core_tool_instance2.__name__ = "mock-core-sync-tool-1" + mock_core_tool_instance2.__doc__ = "desc 1" mock_async_tool1 = Mock(spec=ToolboxCoreTool) mock_async_tool1._pydantic_model = BaseModel - mock_core_sync_tool_instance2._async_tool = mock_async_tool1 + mock_core_tool_instance2._async_tool = mock_async_tool1 mock_core_load_toolset.return_value = [ - mock_core_sync_tool_instance1, - mock_core_sync_tool_instance2, + mock_core_tool_instance1, + mock_core_tool_instance2, ] langchain_tools = toolbox_client.load_toolset() @@ -114,7 +114,7 @@ async def test_aload_tool(self, mock_core_aload_tool, toolbox_client): assert langchain_tool.name == mock_core_tool_instance.__name__ assert langchain_tool.description == mock_core_tool_instance.__doc__ - toolbox_client._ToolboxClient__core_sync_client._async_client.load_tool.assert_called_once_with( + toolbox_client._ToolboxClient__core_client._async_client.load_tool.assert_called_once_with( name="test_tool", auth_token_getters={}, bound_params={} ) @@ -145,18 +145,18 @@ async def test_aload_toolset(self, mock_core_aload_toolset, toolbox_client): assert isinstance(langchain_tools[0], ToolboxTool) assert isinstance(langchain_tools[1], ToolboxTool) - toolbox_client._ToolboxClient__core_sync_client._async_client.load_toolset.assert_called_once_with( + toolbox_client._ToolboxClient__core_client._async_client.load_toolset.assert_called_once_with( name=None, auth_token_getters={}, bound_params={}, strict=False ) @patch("toolbox_core.sync_client.ToolboxSyncClient.load_tool") def test_load_tool_with_args(self, mock_core_load_tool, toolbox_client): - mock_core_sync_tool_instance = Mock(spec=ToolboxCoreSyncTool) - mock_core_sync_tool_instance.__name__ = "mock-tool" + mock_core_tool_instance = Mock(spec=ToolboxCoreSyncTool) + mock_core_tool_instance.__name__ = "mock-tool" mock_async_tool = Mock(spec=ToolboxCoreTool) mock_async_tool._pydantic_model = BaseModel - mock_core_sync_tool_instance._async_tool = mock_async_tool - mock_core_load_tool.return_value = mock_core_sync_tool_instance + mock_core_tool_instance._async_tool = mock_async_tool + mock_core_load_tool.return_value = mock_core_tool_instance auth_token_getters = {"token_getter1": lambda: "value1"} auth_tokens_deprecated = {"token_deprecated": lambda: "value_dep"} @@ -216,12 +216,12 @@ def test_load_tool_with_args(self, mock_core_load_tool, toolbox_client): @patch("toolbox_core.sync_client.ToolboxSyncClient.load_toolset") def test_load_toolset_with_args(self, mock_core_load_toolset, toolbox_client): - mock_core_sync_tool_instance = Mock(spec=ToolboxCoreSyncTool) - mock_core_sync_tool_instance.__name__ = "mock-tool-0" + mock_core_tool_instance = Mock(spec=ToolboxCoreSyncTool) + mock_core_tool_instance.__name__ = "mock-tool-0" mock_async_tool = Mock(spec=ToolboxCoreTool) mock_async_tool._pydantic_model = BaseModel - mock_core_sync_tool_instance._async_tool = mock_async_tool - mock_core_load_toolset.return_value = [mock_core_sync_tool_instance] + mock_core_tool_instance._async_tool = mock_async_tool + mock_core_load_toolset.return_value = [mock_core_tool_instance] auth_token_getters = {"token_getter1": lambda: "value1"} auth_tokens_deprecated = {"token_deprecated": lambda: "value_dep"} @@ -277,7 +277,7 @@ async def test_aload_tool_with_args(self, mock_core_aload_tool, toolbox_client): assert any("auth_headers` is deprecated" in m for m in messages) assert isinstance(tool, ToolboxTool) - toolbox_client._ToolboxClient__core_sync_client._async_client.load_tool.assert_called_with( + toolbox_client._ToolboxClient__core_client._async_client.load_tool.assert_called_with( name="test_tool", auth_token_getters=auth_token_getters, bound_params=bound_params, @@ -313,7 +313,7 @@ async def test_aload_toolset_with_args( assert any("auth_headers` is deprecated" in m for m in messages) assert len(tools) == 1 - toolbox_client._ToolboxClient__core_sync_client._async_client.load_toolset.assert_called_with( + toolbox_client._ToolboxClient__core_client._async_client.load_toolset.assert_called_with( name="my_toolset", auth_token_getters=auth_token_getters, bound_params=bound_params, diff --git a/packages/toolbox-langchain/tests/test_e2e.py b/packages/toolbox-langchain/tests/test_e2e.py index 7c9b417f..12002717 100644 --- a/packages/toolbox-langchain/tests/test_e2e.py +++ b/packages/toolbox-langchain/tests/test_e2e.py @@ -53,7 +53,7 @@ def toolbox(self): @pytest_asyncio.fixture(scope="function") async def get_n_rows_tool(self, toolbox): tool = await toolbox.aload_tool("get-n-rows") - assert tool._ToolboxTool__core_sync_tool.__name__ == "get-n-rows" + assert tool._ToolboxTool__core_tool.__name__ == "get-n-rows" return tool #### Basic e2e tests @@ -70,7 +70,7 @@ async def test_aload_toolset_specific( toolset = await toolbox.aload_toolset(toolset_name) assert len(toolset) == expected_length for tool in toolset: - name = tool._ToolboxTool__core_sync_tool.__name__ + name = tool._ToolboxTool__core_tool.__name__ assert name in expected_tools async def test_aload_toolset_all(self, toolbox): @@ -84,7 +84,7 @@ async def test_aload_toolset_all(self, toolbox): "get-row-by-content-auth", ] for tool in toolset: - name = tool._ToolboxTool__core_sync_tool.__name__ + name = tool._ToolboxTool__core_tool.__name__ assert name in tool_names async def test_run_tool_async(self, get_n_rows_tool): @@ -198,7 +198,7 @@ def toolbox(self): @pytest.fixture(scope="function") def get_n_rows_tool(self, toolbox): tool = toolbox.load_tool("get-n-rows") - assert tool._ToolboxTool__core_sync_tool.__name__ == "get-n-rows" + assert tool._ToolboxTool__core_tool.__name__ == "get-n-rows" return tool #### Basic e2e tests @@ -215,7 +215,7 @@ def test_load_toolset_specific( toolset = toolbox.load_toolset(toolset_name) assert len(toolset) == expected_length for tool in toolset: - name = tool._ToolboxTool__core_sync_tool.__name__ + name = tool._ToolboxTool__core_tool.__name__ assert name in expected_tools def test_aload_toolset_all(self, toolbox): @@ -229,7 +229,7 @@ def test_aload_toolset_all(self, toolbox): "get-row-by-content-auth", ] for tool in toolset: - name = tool._ToolboxTool__core_sync_tool.__name__ + name = tool._ToolboxTool__core_tool.__name__ assert name in tool_names @pytest.mark.asyncio diff --git a/packages/toolbox-langchain/tests/test_tools.py b/packages/toolbox-langchain/tests/test_tools.py index 090f0f55..1d0b3a14 100644 --- a/packages/toolbox-langchain/tests/test_tools.py +++ b/packages/toolbox-langchain/tests/test_tools.py @@ -66,7 +66,7 @@ def mock_core_async_auth_tool(self, auth_tool_schema_dict): return mock @pytest.fixture - def mock_core_sync_tool(self, mock_core_async_tool): + def mock_core_tool(self, mock_core_async_tool): sync_mock = Mock(spec=ToolboxCoreSyncTool) sync_mock.__name__ = mock_core_async_tool.__name__ sync_mock.__doc__ = mock_core_async_tool.__doc__ @@ -94,19 +94,19 @@ def mock_core_sync_auth_tool(self, mock_core_async_auth_tool): return sync_mock @pytest.fixture - def toolbox_tool(self, mock_core_sync_tool): - return ToolboxTool(core_sync_tool=mock_core_sync_tool) + def toolbox_tool(self, mock_core_tool): + return ToolboxTool(core_tool=mock_core_tool) @pytest.fixture def auth_toolbox_tool(self, mock_core_sync_auth_tool): - return ToolboxTool(core_sync_tool=mock_core_sync_auth_tool) + return ToolboxTool(core_tool=mock_core_sync_auth_tool) - def test_toolbox_tool_init(self, mock_core_sync_tool): - tool = ToolboxTool(core_sync_tool=mock_core_sync_tool) - core_sync_tool_in_tool = tool._ToolboxTool__core_sync_tool - assert core_sync_tool_in_tool.__name__ == mock_core_sync_tool.__name__ - assert core_sync_tool_in_tool.__doc__ == mock_core_sync_tool.__doc__ - assert tool.args_schema == mock_core_sync_tool._async_tool._pydantic_model + def test_toolbox_tool_init(self, mock_core_tool): + tool = ToolboxTool(core_tool=mock_core_tool) + core_tool_in_tool = tool._ToolboxTool__core_tool + assert core_tool_in_tool.__name__ == mock_core_tool.__name__ + assert core_tool_in_tool.__doc__ == mock_core_tool.__doc__ + assert tool.args_schema == mock_core_tool._async_tool._pydantic_model @pytest.mark.parametrize( "params, expected_bound_params_on_core", @@ -124,29 +124,29 @@ def test_toolbox_tool_bind_params( params, expected_bound_params_on_core, toolbox_tool, - mock_core_sync_tool, + mock_core_tool, ): - mock_core_sync_tool.bind_params.return_value = mock_core_sync_tool + mock_core_tool.bind_params.return_value = mock_core_tool new_langchain_tool = toolbox_tool.bind_params(params) - mock_core_sync_tool.bind_params.assert_called_once_with(params) + mock_core_tool.bind_params.assert_called_once_with(params) assert isinstance(new_langchain_tool, ToolboxTool) assert ( - new_langchain_tool._ToolboxTool__core_sync_tool - == mock_core_sync_tool.bind_params.return_value + new_langchain_tool._ToolboxTool__core_tool + == mock_core_tool.bind_params.return_value ) - def test_toolbox_tool_bind_param(self, toolbox_tool, mock_core_sync_tool): - # ToolboxTool.bind_param calls core_sync_tool.bind_params - mock_core_sync_tool.bind_params.return_value = mock_core_sync_tool + def test_toolbox_tool_bind_param(self, toolbox_tool, mock_core_tool): + # ToolboxTool.bind_param calls core_tool.bind_params + mock_core_tool.bind_params.return_value = mock_core_tool new_langchain_tool = toolbox_tool.bind_param("param1", "bound-value") # *** Fix: Assert that bind_params is called on the core tool *** - mock_core_sync_tool.bind_params.assert_called_once_with( + mock_core_tool.bind_params.assert_called_once_with( {"param1": "bound-value"} ) assert isinstance(new_langchain_tool, ToolboxTool) assert ( - new_langchain_tool._ToolboxTool__core_sync_tool - == mock_core_sync_tool.bind_params.return_value + new_langchain_tool._ToolboxTool__core_tool + == mock_core_tool.bind_params.return_value ) @pytest.mark.parametrize( @@ -186,7 +186,7 @@ def test_toolbox_tool_add_auth_token_getters( ) assert isinstance(new_langchain_tool, ToolboxTool) assert ( - new_langchain_tool._ToolboxTool__core_sync_tool + new_langchain_tool._ToolboxTool__core_tool == mock_core_sync_auth_tool.add_auth_token_getters.return_value ) @@ -194,7 +194,7 @@ def test_toolbox_tool_add_auth_token_getter( self, auth_toolbox_tool, mock_core_sync_auth_tool ): get_id_token = lambda: "test-token" - # ToolboxTool.add_auth_token_getter calls core_sync_tool.add_auth_token_getters + # ToolboxTool.add_auth_token_getter calls core_tool.add_auth_token_getters mock_core_sync_auth_tool.add_auth_token_getters.return_value = ( mock_core_sync_auth_tool ) @@ -209,6 +209,6 @@ def test_toolbox_tool_add_auth_token_getter( ) assert isinstance(new_langchain_tool, ToolboxTool) assert ( - new_langchain_tool._ToolboxTool__core_sync_tool + new_langchain_tool._ToolboxTool__core_tool == mock_core_sync_auth_tool.add_auth_token_getters.return_value ) From a0f3020afb325570d2df501df769ce3255938498 Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Wed, 14 May 2025 18:54:04 +0530 Subject: [PATCH 40/53] chore: Delint --- packages/toolbox-langchain/src/toolbox_langchain/tools.py | 4 +--- packages/toolbox-langchain/tests/test_tools.py | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/packages/toolbox-langchain/src/toolbox_langchain/tools.py b/packages/toolbox-langchain/src/toolbox_langchain/tools.py index 659c5985..4602903d 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/tools.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/tools.py @@ -79,9 +79,7 @@ def add_auth_token_getters( ValueError: If any of the provided auth parameters is already registered. """ - new_core_tool = self.__core_tool.add_auth_token_getters( - auth_token_getters - ) + new_core_tool = self.__core_tool.add_auth_token_getters(auth_token_getters) return ToolboxTool(core_tool=new_core_tool) def add_auth_token_getter( diff --git a/packages/toolbox-langchain/tests/test_tools.py b/packages/toolbox-langchain/tests/test_tools.py index 1d0b3a14..5560cf99 100644 --- a/packages/toolbox-langchain/tests/test_tools.py +++ b/packages/toolbox-langchain/tests/test_tools.py @@ -140,9 +140,7 @@ def test_toolbox_tool_bind_param(self, toolbox_tool, mock_core_tool): mock_core_tool.bind_params.return_value = mock_core_tool new_langchain_tool = toolbox_tool.bind_param("param1", "bound-value") # *** Fix: Assert that bind_params is called on the core tool *** - mock_core_tool.bind_params.assert_called_once_with( - {"param1": "bound-value"} - ) + mock_core_tool.bind_params.assert_called_once_with({"param1": "bound-value"}) assert isinstance(new_langchain_tool, ToolboxTool) assert ( new_langchain_tool._ToolboxTool__core_tool From abd6dd834e65d6f51cdfdf4785bdcf51e9db11bc Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Thu, 15 May 2025 13:56:43 +0530 Subject: [PATCH 41/53] chore: Add toolbox actual package version in toml and local path in requirements.txt --- packages/toolbox-langchain/pyproject.toml | 3 +-- packages/toolbox-langchain/requirements.txt | 1 + 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/toolbox-langchain/pyproject.toml b/packages/toolbox-langchain/pyproject.toml index 629745c9..f0cb6276 100644 --- a/packages/toolbox-langchain/pyproject.toml +++ b/packages/toolbox-langchain/pyproject.toml @@ -9,8 +9,7 @@ authors = [ {name = "Google LLC", email = "googleapis-packages@google.com"} ] dependencies = [ - # TODO: Replace with actual package dependency (eg. "toolbox-core>=0.2.0,<1.0.0") - "toolbox-core @ git+https://github.com/googleapis/mcp-toolbox-sdk-python.git@anubhav-lc-wraps-core#subdirectory=packages/toolbox-core", + "toolbox-core>=0.2.0,<1.0.0", "langchain-core>=0.2.23,<1.0.0", "PyYAML>=6.0.1,<7.0.0", "pydantic>=2.7.0,<3.0.0", diff --git a/packages/toolbox-langchain/requirements.txt b/packages/toolbox-langchain/requirements.txt index 5fd65843..1da4fefa 100644 --- a/packages/toolbox-langchain/requirements.txt +++ b/packages/toolbox-langchain/requirements.txt @@ -1,3 +1,4 @@ +-e packages/toolbox-core langchain-core==0.3.56 PyYAML==6.0.2 pydantic==2.11.4 From f3817b38c2d2b4daf7df76554c77c7efdd2adbc8 Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Thu, 15 May 2025 13:58:28 +0530 Subject: [PATCH 42/53] fix: Fix editable toolbox-core package path in requirements.txt --- packages/toolbox-langchain/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/toolbox-langchain/requirements.txt b/packages/toolbox-langchain/requirements.txt index 1da4fefa..3ada831d 100644 --- a/packages/toolbox-langchain/requirements.txt +++ b/packages/toolbox-langchain/requirements.txt @@ -1,4 +1,4 @@ --e packages/toolbox-core +-e ../toolbox-core langchain-core==0.3.56 PyYAML==6.0.2 pydantic==2.11.4 From 3562bd7c7e79892c93173134cdbb2fef7164bc70 Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Thu, 15 May 2025 14:05:15 +0530 Subject: [PATCH 43/53] fix: Fix lowest supported version until released --- packages/toolbox-langchain/pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/packages/toolbox-langchain/pyproject.toml b/packages/toolbox-langchain/pyproject.toml index f0cb6276..9aaa254a 100644 --- a/packages/toolbox-langchain/pyproject.toml +++ b/packages/toolbox-langchain/pyproject.toml @@ -9,7 +9,8 @@ authors = [ {name = "Google LLC", email = "googleapis-packages@google.com"} ] dependencies = [ - "toolbox-core>=0.2.0,<1.0.0", + # TODO: Bump lowest supported version to 0.2.0 + "toolbox-core>=0.1.0,<1.0.0", "langchain-core>=0.2.23,<1.0.0", "PyYAML>=6.0.1,<7.0.0", "pydantic>=2.7.0,<3.0.0", From 23618c4a2def9bde72d8bbda31b8b25c569e76f7 Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Thu, 15 May 2025 14:10:06 +0530 Subject: [PATCH 44/53] fix: Fix issue causing relative path in requirements.txt to cause issues --- packages/toolbox-langchain/integration.cloudbuild.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/packages/toolbox-langchain/integration.cloudbuild.yaml b/packages/toolbox-langchain/integration.cloudbuild.yaml index 644794fb..51f0ce81 100644 --- a/packages/toolbox-langchain/integration.cloudbuild.yaml +++ b/packages/toolbox-langchain/integration.cloudbuild.yaml @@ -15,10 +15,11 @@ steps: - id: Install library requirements name: 'python:${_VERSION}' + dir: 'packages/toolbox-langchain' args: - install - '-r' - - 'packages/toolbox-langchain/requirements.txt' + - 'requirements.txt' - '--user' entrypoint: pip - id: Install test requirements From ac535a3801094df8b35c57ec7fb553e0e259a49c Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Thu, 15 May 2025 17:02:58 +0530 Subject: [PATCH 45/53] docs: Fix issue README --- packages/toolbox-langchain/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/toolbox-langchain/README.md b/packages/toolbox-langchain/README.md index 9f698694..fca7736b 100644 --- a/packages/toolbox-langchain/README.md +++ b/packages/toolbox-langchain/README.md @@ -227,7 +227,7 @@ tools = toolbox.load_toolset() auth_tool = tools[0].add_auth_token_getter("my_auth", get_auth_token) # Single token -multi_auth_tool = tools[0].add_auth_token_getters({"my_auth", get_auth_token}) # Multiple tokens +multi_auth_tool = tools[0].add_auth_token_getters({"auth_1": get_auth_1}, {"auth_2": get_auth_2}) # Multiple tokens # OR From 9d3552f200c1f25aca2426c1a1b66fabb0ff19b5 Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Fri, 16 May 2025 21:13:42 +0530 Subject: [PATCH 46/53] fix: Use correct core client interfaces in langchain client --- .../src/toolbox_langchain/client.py | 66 +++++-------------- 1 file changed, 16 insertions(+), 50 deletions(-) diff --git a/packages/toolbox-langchain/src/toolbox_langchain/client.py b/packages/toolbox-langchain/src/toolbox_langchain/client.py index d26eede8..a69f5f2c 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/client.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/client.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import asyncio +from asyncio import wrap_future from typing import Any, Callable, Optional, Union from warnings import warn @@ -85,30 +85,14 @@ async def aload_tool( ) auth_token_getters = auth_headers - coro = self.__core_client._async_client.load_tool( - name=tool_name, - auth_token_getters=auth_token_getters, - bound_params=bound_params, - ) - - if not self.__core_client._loop: - # If a loop has not been provided, attempt to run in current thread. - core_tool = await coro - else: - # Otherwise, run in the background thread. - core_tool = await asyncio.wrap_future( - asyncio.run_coroutine_threadsafe(coro, self.__core_client._loop) + core_tool = await wrap_future( + self.__core_client._load_tool_future( + name=tool_name, + auth_token_getters=auth_token_getters, + bound_params=bound_params, ) - - if not self.__core_client._loop or not self.__core_client._thread: - raise ValueError("Background loop or thread cannot be None.") - - core_sync_tool = ToolboxSyncTool( - core_tool, - self.__core_client._loop, - self.__core_client._thread, ) - return ToolboxTool(core_tool=core_sync_tool) + return ToolboxTool(core_tool=core_tool) async def aload_toolset( self, @@ -167,36 +151,18 @@ async def aload_toolset( ) auth_token_getters = auth_headers - coro = self.__core_client._async_client.load_toolset( - name=toolset_name, - auth_token_getters=auth_token_getters, - bound_params=bound_params, - strict=strict, - ) - - if not self.__core_client._loop: - # If a loop has not been provided, attempt to run in current thread. - core_tools = await coro - else: - # Otherwise, run in the background thread. - core_tools = await asyncio.wrap_future( - asyncio.run_coroutine_threadsafe(coro, self.__core_client._loop) + core_tools = await wrap_future( + self.__core_client._load_toolset_future( + name=toolset_name, + auth_token_getters=auth_token_getters, + bound_params=bound_params, + strict=strict, ) + ) - if not self.__core_client._loop or not self.__core_client._thread: - raise ValueError("Background loop or thread cannot be None.") - - core_sync_tools = [ - ToolboxSyncTool( - core_tool, - self.__core_client._loop, - self.__core_client._thread, - ) - for core_tool in core_tools - ] tools = [] - for core_sync_tool in core_sync_tools: - tools.append(ToolboxTool(core_tool=core_sync_tool)) + for core_tool in core_tools: + tools.append(ToolboxTool(core_tool=core_tool)) return tools def load_tool( From 551b113c969907b4b8d9305d2686f9f92d1c2c36 Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Fri, 16 May 2025 21:13:58 +0530 Subject: [PATCH 47/53] fix: Use correct core tool interfaces in langchain tool --- .../src/toolbox_langchain/tools.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/packages/toolbox-langchain/src/toolbox_langchain/tools.py b/packages/toolbox-langchain/src/toolbox_langchain/tools.py index 4602903d..17e60a7d 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/tools.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/tools.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import asyncio +from asyncio import wrap_future from typing import Any, Callable, Union from langchain_core.tools import BaseTool from toolbox_core.sync_tool import ToolboxSyncTool as ToolboxCoreSyncTool +from toolbox_core.utils import params_to_pydantic_model class ToolboxTool(BaseTool): @@ -41,7 +42,7 @@ def __init__( super().__init__( name=core_tool.__name__, description=core_tool.__doc__, - args_schema=core_tool._async_tool._pydantic_model, + args_schema=params_to_pydantic_model(core_tool._name, core_tool._params), ) self.__core_tool = core_tool @@ -49,16 +50,7 @@ def _run(self, **kwargs: Any) -> str: return self.__core_tool(**kwargs) async def _arun(self, **kwargs: Any) -> str: - coro = self.__core_tool._async_tool(**kwargs) - - # If a loop has not been provided, attempt to run in current thread. - if not self.__core_tool._loop: - return await coro - - # Otherwise, run in the background thread. - return await asyncio.wrap_future( - asyncio.run_coroutine_threadsafe(coro, self.__core_tool._loop) - ) + return await wrap_future(self.__core_tool._call_future(**kwargs)) def add_auth_token_getters( self, auth_token_getters: dict[str, Callable[[], str]] From 690639349a9dd49a59fc3e0eb9693f2a0460e16e Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Fri, 16 May 2025 22:14:03 +0530 Subject: [PATCH 48/53] fix: Use correct interface from core tool --- .../toolbox-langchain/src/toolbox_langchain/async_tools.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py b/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py index 8bbcf500..627b18e1 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py @@ -16,6 +16,7 @@ from langchain_core.tools import BaseTool from toolbox_core.tool import ToolboxTool as ToolboxCoreTool +from toolbox_core.utils import params_to_pydantic_model # This class is an internal implementation detail and is not exposed to the @@ -43,7 +44,7 @@ def __init__( super().__init__( name=core_tool.__name__, description=core_tool.__doc__, - args_schema=core_tool._pydantic_model, + args_schema=params_to_pydantic_model(core_tool._name, core_tool._params), ) self.__core_tool = core_tool From f61e81ebea9a0358f4dcebf6c002a7a2f2fcf8be Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Sat, 17 May 2025 01:38:49 +0530 Subject: [PATCH 49/53] fix: Use correct interfaces of toolbox-core --- .../src/toolbox_langchain/client.py | 26 +++++++++---------- .../src/toolbox_langchain/tools.py | 4 +-- 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/packages/toolbox-langchain/src/toolbox_langchain/client.py b/packages/toolbox-langchain/src/toolbox_langchain/client.py index a69f5f2c..1d395585 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/client.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/client.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from asyncio import wrap_future +from asyncio import to_thread from typing import Any, Callable, Optional, Union from warnings import warn @@ -85,12 +85,11 @@ async def aload_tool( ) auth_token_getters = auth_headers - core_tool = await wrap_future( - self.__core_client._load_tool_future( - name=tool_name, - auth_token_getters=auth_token_getters, - bound_params=bound_params, - ) + core_tool = await to_thread( + self.__core_client.load_tool, + name=tool_name, + auth_token_getters=auth_token_getters, + bound_params=bound_params, ) return ToolboxTool(core_tool=core_tool) @@ -151,13 +150,12 @@ async def aload_toolset( ) auth_token_getters = auth_headers - core_tools = await wrap_future( - self.__core_client._load_toolset_future( - name=toolset_name, - auth_token_getters=auth_token_getters, - bound_params=bound_params, - strict=strict, - ) + core_tools = await to_thread( + self.__core_client.load_toolset, + name=toolset_name, + auth_token_getters=auth_token_getters, + bound_params=bound_params, + strict=strict, ) tools = [] diff --git a/packages/toolbox-langchain/src/toolbox_langchain/tools.py b/packages/toolbox-langchain/src/toolbox_langchain/tools.py index 17e60a7d..fd7ab197 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/tools.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/tools.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from asyncio import wrap_future +from asyncio import to_thread from typing import Any, Callable, Union from langchain_core.tools import BaseTool @@ -50,7 +50,7 @@ def _run(self, **kwargs: Any) -> str: return self.__core_tool(**kwargs) async def _arun(self, **kwargs: Any) -> str: - return await wrap_future(self.__core_tool._call_future(**kwargs)) + return await to_thread(self.__core_tool, **kwargs) def add_auth_token_getters( self, auth_token_getters: dict[str, Callable[[], str]] From 9bfb8d1f0b35a38e4af8f990e53c88c04a11118d Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Sat, 17 May 2025 02:54:04 +0530 Subject: [PATCH 50/53] chore: Update async client unit tests --- .../tests/test_async_client.py | 181 ++++++++++++------ 1 file changed, 124 insertions(+), 57 deletions(-) diff --git a/packages/toolbox-langchain/tests/test_async_client.py b/packages/toolbox-langchain/tests/test_async_client.py index 988d3974..6b398560 100644 --- a/packages/toolbox-langchain/tests/test_async_client.py +++ b/packages/toolbox-langchain/tests/test_async_client.py @@ -18,10 +18,8 @@ import pytest from aiohttp import ClientSession from toolbox_core.client import ToolboxClient as ToolboxCoreClient -from toolbox_core.protocol import ManifestSchema from toolbox_core.protocol import ParameterSchema as CoreParameterSchema from toolbox_core.tool import ToolboxTool as ToolboxCoreTool -from toolbox_core.utils import params_to_pydantic_model from toolbox_langchain.async_client import AsyncToolboxClient from toolbox_langchain.async_tools import AsyncToolboxTool @@ -56,16 +54,12 @@ @pytest.mark.asyncio class TestAsyncToolboxClient: - @pytest.fixture() - def manifest_schema(self): - return ManifestSchema(**MANIFEST_JSON) - @pytest.fixture() def mock_session(self): return AsyncMock(spec=ClientSession) @pytest.fixture - def mock_core_client_instance(self, manifest_schema, mock_session): + def mock_core_client_instance(self, mock_session): mock = AsyncMock(spec=ToolboxCoreClient) async def mock_load_tool_impl(name, auth_token_getters, bound_params): @@ -80,7 +74,8 @@ async def mock_load_tool_impl(name, auth_token_getters, bound_params): core_tool_mock = AsyncMock(spec=ToolboxCoreTool) core_tool_mock.__name__ = name core_tool_mock.__doc__ = tool_schema_dict["description"] - core_tool_mock._pydantic_model = params_to_pydantic_model(name, core_params) + core_tool_mock._name = name + core_tool_mock._params = core_params # Add other necessary attributes or method mocks if AsyncToolboxTool uses them return core_tool_mock @@ -97,9 +92,8 @@ async def mock_load_toolset_impl( core_tool_mock = AsyncMock(spec=ToolboxCoreTool) core_tool_mock.__name__ = tool_name_iter core_tool_mock.__doc__ = tool_schema_dict["description"] - core_tool_mock._pydantic_model = params_to_pydantic_model( - tool_name_iter, core_params - ) + core_tool_mock._name = tool_name_iter + core_tool_mock._params = core_params core_tools_list.append(core_tool_mock) return core_tools_list @@ -130,36 +124,34 @@ async def test_create_with_existing_session(self, mock_client, mock_session): async def test_aload_tool( self, mock_client, - manifest_schema, # mock_session removed as it's part of mock_core_client_instance ): tool_name = "test_tool_1" - # manifest_schema is used by mock_core_client_instance fixture to provide tool details + test_bound_params = {"bp1": "value1"} - tool = await mock_client.aload_tool(tool_name) + tool = await mock_client.aload_tool(tool_name, bound_params=test_bound_params) # Assert that the core client's load_tool was called correctly mock_client._AsyncToolboxClient__core_client.load_tool.assert_called_once_with( - name=tool_name, auth_token_getters={}, bound_params={} + name=tool_name, auth_token_getters={}, bound_params=test_bound_params ) assert isinstance(tool, AsyncToolboxTool) assert ( tool.name == tool_name ) # AsyncToolboxTool gets its name from the core_tool - async def test_aload_tool_auth_headers_deprecated( - self, mock_client, manifest_schema - ): + async def test_aload_tool_auth_headers_deprecated(self, mock_client): tool_name = "test_tool_1" - auth_lambda = lambda: "Bearer token" # Define lambda once + auth_lambda = lambda: "Bearer token" with catch_warnings(record=True) as w: simplefilter("always") await mock_client.aload_tool( tool_name, - auth_headers={"Authorization": auth_lambda}, # Use the defined lambda + auth_headers={"Authorization": auth_lambda}, ) assert len(w) == 1 assert issubclass(w[-1].category, DeprecationWarning) assert "auth_headers" in str(w[-1].message) + assert "Use `auth_token_getters` instead" in str(w[-1].message) mock_client._AsyncToolboxClient__core_client.load_tool.assert_called_once_with( name=tool_name, @@ -167,49 +159,87 @@ async def test_aload_tool_auth_headers_deprecated( bound_params={}, ) - async def test_aload_tool_auth_headers_and_tokens( - self, mock_client, manifest_schema - ): + async def test_aload_tool_auth_headers_and_getters_precedence(self, mock_client): tool_name = "test_tool_1" - auth_getters = {"test": lambda: "token"} - auth_headers_lambda = lambda: "Bearer token" # Define lambda once + auth_getters = {"test_source": lambda: "id_token_from_getters"} + auth_headers_lambda = lambda: "Bearer token_from_headers" with catch_warnings(record=True) as w: simplefilter("always") await mock_client.aload_tool( tool_name, - auth_headers={ - "Authorization": auth_headers_lambda - }, # Use defined lambda + auth_headers={"Authorization": auth_headers_lambda}, auth_token_getters=auth_getters, ) - assert ( - len(w) == 1 - ) # Only one warning because auth_token_getters takes precedence + assert len(w) == 1 assert issubclass(w[-1].category, DeprecationWarning) - assert "auth_headers" in str(w[-1].message) # Warning for auth_headers + assert "auth_headers" in str(w[-1].message) + assert "`auth_token_getters` will be used" in str(w[-1].message) mock_client._AsyncToolboxClient__core_client.load_tool.assert_called_once_with( name=tool_name, auth_token_getters=auth_getters, bound_params={} ) - async def test_aload_toolset( - self, mock_client, manifest_schema # mock_session removed - ): - tools = await mock_client.aload_toolset() + async def test_aload_tool_auth_tokens_deprecated(self, mock_client): + tool_name = "test_tool_1" + token_lambda = lambda: "id_token" + with catch_warnings(record=True) as w: + simplefilter("always") + await mock_client.aload_tool( + tool_name, + auth_tokens={"some_token_key": token_lambda}, + ) + assert len(w) == 1 + assert issubclass(w[-1].category, DeprecationWarning) + assert "auth_tokens" in str(w[-1].message) + assert "Use `auth_token_getters` instead" in str(w[-1].message) + + mock_client._AsyncToolboxClient__core_client.load_tool.assert_called_once_with( + name=tool_name, + auth_token_getters={"some_token_key": token_lambda}, + bound_params={}, + ) + + async def test_aload_tool_auth_tokens_and_getters_precedence(self, mock_client): + tool_name = "test_tool_1" + auth_getters = {"real_source": lambda: "token_from_getters"} + token_lambda = lambda: "token_from_auth_tokens" + + with catch_warnings(record=True) as w: + simplefilter("always") + await mock_client.aload_tool( + tool_name, + auth_tokens={"deprecated_source": token_lambda}, + auth_token_getters=auth_getters, + ) + assert len(w) == 1 + assert issubclass(w[-1].category, DeprecationWarning) + assert "auth_tokens" in str(w[-1].message) + assert "`auth_token_getters` will be used" in str(w[-1].message) + + mock_client._AsyncToolboxClient__core_client.load_tool.assert_called_once_with( + name=tool_name, auth_token_getters=auth_getters, bound_params={} + ) + + async def test_aload_toolset(self, mock_client): + test_bound_params = {"bp_set": "value_set"} + tools = await mock_client.aload_toolset( + bound_params=test_bound_params, strict=True + ) mock_client._AsyncToolboxClient__core_client.load_toolset.assert_called_once_with( - name=None, auth_token_getters={}, bound_params={}, strict=False + name=None, + auth_token_getters={}, + bound_params=test_bound_params, + strict=True, ) - assert len(tools) == 2 # Based on MANIFEST_JSON + assert len(tools) == 2 for tool in tools: assert isinstance(tool, AsyncToolboxTool) assert tool.name in ["test_tool_1", "test_tool_2"] - async def test_aload_toolset_with_toolset_name( - self, mock_client, manifest_schema # mock_session removed - ): - toolset_name = "test_toolset_1" # This name isn't in MANIFEST_JSON, but load_toolset mock doesn't filter by it + async def test_aload_toolset_with_toolset_name(self, mock_client): + toolset_name = "test_toolset_1" tools = await mock_client.aload_toolset(toolset_name=toolset_name) mock_client._AsyncToolboxClient__core_client.load_toolset.assert_called_once_with( @@ -218,20 +248,17 @@ async def test_aload_toolset_with_toolset_name( assert len(tools) == 2 for tool in tools: assert isinstance(tool, AsyncToolboxTool) - assert tool.name in ["test_tool_1", "test_tool_2"] - async def test_aload_toolset_auth_headers_deprecated( - self, mock_client, manifest_schema - ): - auth_lambda = lambda: "Bearer token" # Define lambda once + async def test_aload_toolset_auth_headers_deprecated(self, mock_client): + auth_lambda = lambda: "Bearer token" with catch_warnings(record=True) as w: simplefilter("always") - await mock_client.aload_toolset( - auth_headers={"Authorization": auth_lambda} # Use defined lambda - ) + await mock_client.aload_toolset(auth_headers={"Authorization": auth_lambda}) assert len(w) == 1 assert issubclass(w[-1].category, DeprecationWarning) assert "auth_headers" in str(w[-1].message) + assert "Use `auth_token_getters` instead" in str(w[-1].message) + mock_client._AsyncToolboxClient__core_client.load_toolset.assert_called_once_with( name=None, auth_token_getters={"Authorization": auth_lambda}, @@ -239,22 +266,62 @@ async def test_aload_toolset_auth_headers_deprecated( strict=False, ) - async def test_aload_toolset_auth_headers_and_tokens( - self, mock_client, manifest_schema + async def test_aload_toolset_auth_headers_and_getters_precedence( # Renamed for clarity + self, mock_client ): - auth_getters = {"test": lambda: "token"} - auth_headers_lambda = lambda: "Bearer token" # Define lambda once + auth_getters = {"test_source": lambda: "id_token_from_getters"} + auth_headers_lambda = lambda: "Bearer token_from_headers" with catch_warnings(record=True) as w: simplefilter("always") await mock_client.aload_toolset( - auth_headers={ - "Authorization": auth_headers_lambda - }, # Use defined lambda + auth_headers={"Authorization": auth_headers_lambda}, auth_token_getters=auth_getters, ) assert len(w) == 1 assert issubclass(w[-1].category, DeprecationWarning) assert "auth_headers" in str(w[-1].message) + assert "`auth_token_getters` will be used" in str(w[-1].message) + + mock_client._AsyncToolboxClient__core_client.load_toolset.assert_called_once_with( + name=None, + auth_token_getters=auth_getters, + bound_params={}, + strict=False, # auth_getters takes precedence + ) + + async def test_aload_toolset_auth_tokens_deprecated(self, mock_client): + token_lambda = lambda: "id_token" + with catch_warnings(record=True) as w: + simplefilter("always") + await mock_client.aload_toolset( + auth_tokens={"some_token_key": token_lambda} + ) + assert len(w) == 1 + assert issubclass(w[-1].category, DeprecationWarning) + assert "auth_tokens" in str(w[-1].message) + assert "Use `auth_token_getters` instead" in str(w[-1].message) + + mock_client._AsyncToolboxClient__core_client.load_toolset.assert_called_once_with( + name=None, + auth_token_getters={"some_token_key": token_lambda}, + bound_params={}, + strict=False, + ) + + async def test_aload_toolset_auth_tokens_and_getters_precedence(self, mock_client): + auth_getters = {"real_source": lambda: "token_from_getters"} + token_lambda = lambda: "token_from_auth_tokens" + with catch_warnings(record=True) as w: + simplefilter("always") + await mock_client.aload_toolset( + auth_tokens={"deprecated_source": token_lambda}, + auth_token_getters=auth_getters, + ) + assert len(w) == 1 + assert issubclass(w[-1].category, DeprecationWarning) + assert "auth_tokens" in str(w[-1].message) + assert "`auth_token_getters` will be used" in str(w[-1].message) + mock_client._AsyncToolboxClient__core_client.load_toolset.assert_called_once_with( name=None, auth_token_getters=auth_getters, bound_params={}, strict=False ) From e6cfed28dbe20703a96a77236e3cc3a1b87286a3 Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Sat, 17 May 2025 03:05:38 +0530 Subject: [PATCH 51/53] chore: Fix client unit tests --- .../toolbox-langchain/tests/test_client.py | 321 ++++++++++-------- 1 file changed, 183 insertions(+), 138 deletions(-) diff --git a/packages/toolbox-langchain/tests/test_client.py b/packages/toolbox-langchain/tests/test_client.py index d7eb62a8..ff097198 100644 --- a/packages/toolbox-langchain/tests/test_client.py +++ b/packages/toolbox-langchain/tests/test_client.py @@ -12,18 +12,57 @@ # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import AsyncMock, Mock, patch +from unittest.mock import Mock, patch import pytest from pydantic import BaseModel -from toolbox_core.sync_tool import ToolboxSyncTool as ToolboxCoreSyncTool # For spec -from toolbox_core.tool import ToolboxTool as ToolboxCoreTool # For spec + +from toolbox_core.sync_tool import ToolboxSyncTool as ToolboxCoreSyncTool +from toolbox_core.protocol import ParameterSchema as CoreParameterSchema +from toolbox_core.utils import params_to_pydantic_model from toolbox_langchain.client import ToolboxClient from toolbox_langchain.tools import ToolboxTool URL = "http://test_url" +def create_mock_core_sync_tool(name="mock-sync-tool", doc="Mock sync description.", model_name="MockSyncModel", params=None): + mock_tool = Mock(spec=ToolboxCoreSyncTool) + mock_tool.__name__ = name + mock_tool.__doc__ = doc + mock_tool._name = model_name + if params is None: + mock_tool._params = [CoreParameterSchema(name="param1", type="string", description="Param 1")] + else: + mock_tool._params = params + return mock_tool + +def assert_pydantic_models_equivalent(model_cls1: type[BaseModel], model_cls2: type[BaseModel], expected_model_name: str): + assert issubclass(model_cls1, BaseModel), "model_cls1 is not a Pydantic BaseModel" + assert issubclass(model_cls2, BaseModel), "model_cls2 is not a Pydantic BaseModel" + + assert model_cls1.__name__ == expected_model_name, f"model_cls1 name mismatch: expected {expected_model_name}, got {model_cls1.__name__}" + assert model_cls2.__name__ == expected_model_name, f"model_cls2 name mismatch: expected {expected_model_name}, got {model_cls2.__name__}" + + fields1 = model_cls1.model_fields + fields2 = model_cls2.model_fields + + assert fields1.keys() == fields2.keys(), \ + f"Field names mismatch: {fields1.keys()} != {fields2.keys()}" + + for field_name in fields1.keys(): + field_info1 = fields1[field_name] + field_info2 = fields2[field_name] + + assert field_info1.annotation == field_info2.annotation, \ + f"Field '{field_name}': Annotation mismatch ({field_info1.annotation} != {field_info2.annotation})" + assert field_info1.description == field_info2.description, \ + f"Field '{field_name}': Description mismatch ('{field_info1.description}' != '{field_info2.description}')" + is_required1 = field_info1.is_required() if hasattr(field_info1, 'is_required') else not field_info1.is_nullable() + is_required2 = field_info2.is_required() if hasattr(field_info2, 'is_required') else not field_info2.is_nullable() + assert is_required1 == is_required2, \ + f"Field '{field_name}': Required status mismatch ({is_required1} != {is_required2})" + class TestToolboxClient: @pytest.fixture() @@ -31,35 +70,35 @@ def toolbox_client(self): client = ToolboxClient(URL) assert isinstance(client, ToolboxClient) assert client._ToolboxClient__core_client is not None - assert client._ToolboxClient__core_client._async_client is not None - assert client._ToolboxClient__core_client._loop is not None - assert client._ToolboxClient__core_client._loop.is_running() - assert client._ToolboxClient__core_client._thread is not None - assert client._ToolboxClient__core_client._thread.is_alive() return client @patch("toolbox_core.sync_client.ToolboxSyncClient.load_tool") def test_load_tool(self, mock_core_load_tool, toolbox_client): - mock_core_tool_instance = Mock( - spec=ToolboxCoreSyncTool - ) # Spec with Core Sync Tool - mock_core_tool_instance.__name__ = "mock-core-sync-tool" - mock_core_tool_instance.__doc__ = "mock core sync description" - - mock_underlying_async_tool = Mock( - spec=ToolboxCoreTool - ) # Core Async Tool for pydantic model - mock_underlying_async_tool._pydantic_model = BaseModel - mock_core_tool_instance._async_tool = mock_underlying_async_tool - + mock_core_tool_instance = create_mock_core_sync_tool( + name="test_tool_sync", + doc="Sync tool description.", + model_name="TestToolSyncModel", + params=[CoreParameterSchema(name="sp1", type="integer", description="Sync Param 1")] + ) mock_core_load_tool.return_value = mock_core_tool_instance - + langchain_tool = toolbox_client.load_tool("test_tool") - + assert isinstance(langchain_tool, ToolboxTool) assert langchain_tool.name == mock_core_tool_instance.__name__ assert langchain_tool.description == mock_core_tool_instance.__doc__ - assert langchain_tool.args_schema == mock_underlying_async_tool._pydantic_model + + # Generate the expected schema once for comparison + expected_args_schema = params_to_pydantic_model( + mock_core_tool_instance._name, + mock_core_tool_instance._params + ) + + assert_pydantic_models_equivalent( + langchain_tool.args_schema, + expected_args_schema, + mock_core_tool_instance._name + ) mock_core_load_tool.assert_called_once_with( name="test_tool", auth_token_getters={}, bound_params={} @@ -67,19 +106,8 @@ def test_load_tool(self, mock_core_load_tool, toolbox_client): @patch("toolbox_core.sync_client.ToolboxSyncClient.load_toolset") def test_load_toolset(self, mock_core_load_toolset, toolbox_client): - mock_core_tool_instance1 = Mock(spec=ToolboxCoreSyncTool) - mock_core_tool_instance1.__name__ = "mock-core-sync-tool-0" - mock_core_tool_instance1.__doc__ = "desc 0" - mock_async_tool0 = Mock(spec=ToolboxCoreTool) - mock_async_tool0._pydantic_model = BaseModel - mock_core_tool_instance1._async_tool = mock_async_tool0 - - mock_core_tool_instance2 = Mock(spec=ToolboxCoreSyncTool) - mock_core_tool_instance2.__name__ = "mock-core-sync-tool-1" - mock_core_tool_instance2.__doc__ = "desc 1" - mock_async_tool1 = Mock(spec=ToolboxCoreTool) - mock_async_tool1._pydantic_model = BaseModel - mock_core_tool_instance2._async_tool = mock_async_tool1 + mock_core_tool_instance1 = create_mock_core_sync_tool(name="tool-0", doc="desc 0", model_name="T0Model") + mock_core_tool_instance2 = create_mock_core_sync_tool(name="tool-1", doc="desc 1", model_name="T1Model", params=[]) mock_core_load_toolset.return_value = [ mock_core_tool_instance1, @@ -88,82 +116,102 @@ def test_load_toolset(self, mock_core_load_toolset, toolbox_client): langchain_tools = toolbox_client.load_toolset() assert len(langchain_tools) == 2 - assert isinstance(langchain_tools[0], ToolboxTool) - assert isinstance(langchain_tools[1], ToolboxTool) - assert langchain_tools[0].name == "mock-core-sync-tool-0" - assert langchain_tools[1].name == "mock-core-sync-tool-1" + + tool_instances_mocks = [mock_core_tool_instance1, mock_core_tool_instance2] + for i, tool_instance_mock in enumerate(tool_instances_mocks): + langchain_tool = langchain_tools[i] + assert isinstance(langchain_tool, ToolboxTool) + assert langchain_tool.name == tool_instance_mock.__name__ + assert langchain_tool.description == tool_instance_mock.__doc__ + + expected_args_schema = params_to_pydantic_model( + tool_instance_mock._name, + tool_instance_mock._params + ) + assert_pydantic_models_equivalent( + langchain_tool.args_schema, + expected_args_schema, + tool_instance_mock._name + ) mock_core_load_toolset.assert_called_once_with( name=None, auth_token_getters={}, bound_params={}, strict=False ) @pytest.mark.asyncio - @patch("toolbox_core.client.ToolboxClient.load_tool") - async def test_aload_tool(self, mock_core_aload_tool, toolbox_client): - mock_core_tool_instance = AsyncMock( - spec=ToolboxCoreTool - ) # *** Use AsyncMock for async method return *** - mock_core_tool_instance.__name__ = "mock-core-async-tool" - mock_core_tool_instance.__doc__ = "mock core async description" - mock_core_tool_instance._pydantic_model = BaseModel - mock_core_aload_tool.return_value = mock_core_tool_instance + @patch("toolbox_core.sync_client.ToolboxSyncClient.load_tool") + async def test_aload_tool(self, mock_sync_core_load_tool, toolbox_client): + mock_core_sync_tool_instance = create_mock_core_sync_tool( + name="test_async_loaded_tool", + doc="Async loaded sync tool description.", + model_name="AsyncTestToolModel" + ) + mock_sync_core_load_tool.return_value = mock_core_sync_tool_instance langchain_tool = await toolbox_client.aload_tool("test_tool") assert isinstance(langchain_tool, ToolboxTool) - assert langchain_tool.name == mock_core_tool_instance.__name__ - assert langchain_tool.description == mock_core_tool_instance.__doc__ - - toolbox_client._ToolboxClient__core_client._async_client.load_tool.assert_called_once_with( + assert langchain_tool.name == mock_core_sync_tool_instance.__name__ + assert langchain_tool.description == mock_core_sync_tool_instance.__doc__ + + expected_args_schema = params_to_pydantic_model( + mock_core_sync_tool_instance._name, + mock_core_sync_tool_instance._params + ) + assert_pydantic_models_equivalent( + langchain_tool.args_schema, + expected_args_schema, + mock_core_sync_tool_instance._name + ) + + mock_sync_core_load_tool.assert_called_once_with( name="test_tool", auth_token_getters={}, bound_params={} ) @pytest.mark.asyncio - @patch("toolbox_core.client.ToolboxClient.load_toolset") - async def test_aload_toolset(self, mock_core_aload_toolset, toolbox_client): - mock_core_tool_instance1 = AsyncMock( - spec=ToolboxCoreTool - ) # *** Use AsyncMock *** - mock_core_tool_instance1.__name__ = "mock-core-async-tool-0" - mock_core_tool_instance1.__doc__ = "desc 0" - mock_core_tool_instance1._pydantic_model = BaseModel - - mock_core_tool_instance2 = AsyncMock( - spec=ToolboxCoreTool - ) # *** Use AsyncMock *** - mock_core_tool_instance2.__name__ = "mock-core-async-tool-1" - mock_core_tool_instance2.__doc__ = "desc 1" - mock_core_tool_instance2._pydantic_model = BaseModel - - mock_core_aload_toolset.return_value = [ - mock_core_tool_instance1, - mock_core_tool_instance2, + @patch("toolbox_core.sync_client.ToolboxSyncClient.load_toolset") + async def test_aload_toolset(self, mock_sync_core_load_toolset, toolbox_client): + mock_core_sync_tool1 = create_mock_core_sync_tool(name="async-tool-0", doc="async desc 0", model_name="AT0Model") + mock_core_sync_tool2 = create_mock_core_sync_tool(name="async-tool-1", doc="async desc 1", model_name="AT1Model", params=[CoreParameterSchema(name="p1", type="string", description="P1")]) + + mock_sync_core_load_toolset.return_value = [ + mock_core_sync_tool1, + mock_core_sync_tool2, ] langchain_tools = await toolbox_client.aload_toolset() assert len(langchain_tools) == 2 - assert isinstance(langchain_tools[0], ToolboxTool) - assert isinstance(langchain_tools[1], ToolboxTool) + + tool_instances_mocks = [mock_core_sync_tool1, mock_core_sync_tool2] + for i, tool_instance_mock in enumerate(tool_instances_mocks): + langchain_tool = langchain_tools[i] + assert isinstance(langchain_tool, ToolboxTool) + assert langchain_tool.name == tool_instance_mock.__name__ + + expected_args_schema = params_to_pydantic_model( + tool_instance_mock._name, + tool_instance_mock._params + ) + assert_pydantic_models_equivalent( + langchain_tool.args_schema, + expected_args_schema, + tool_instance_mock._name + ) - toolbox_client._ToolboxClient__core_client._async_client.load_toolset.assert_called_once_with( + mock_sync_core_load_toolset.assert_called_once_with( name=None, auth_token_getters={}, bound_params={}, strict=False ) @patch("toolbox_core.sync_client.ToolboxSyncClient.load_tool") def test_load_tool_with_args(self, mock_core_load_tool, toolbox_client): - mock_core_tool_instance = Mock(spec=ToolboxCoreSyncTool) - mock_core_tool_instance.__name__ = "mock-tool" - mock_async_tool = Mock(spec=ToolboxCoreTool) - mock_async_tool._pydantic_model = BaseModel - mock_core_tool_instance._async_tool = mock_async_tool + mock_core_tool_instance = create_mock_core_sync_tool() mock_core_load_tool.return_value = mock_core_tool_instance auth_token_getters = {"token_getter1": lambda: "value1"} auth_tokens_deprecated = {"token_deprecated": lambda: "value_dep"} auth_headers_deprecated = {"header_deprecated": lambda: "value_head_dep"} bound_params = {"param1": "value4"} - - # Test case where auth_token_getters takes precedence + # Scenario 1: auth_token_getters takes precedence with pytest.warns(DeprecationWarning) as record: tool = toolbox_client.load_tool( "test_tool_name", @@ -172,98 +220,101 @@ def test_load_tool_with_args(self, mock_core_load_tool, toolbox_client): auth_headers=auth_headers_deprecated, bound_params=bound_params, ) - # Expect two warnings: one for auth_tokens, one for auth_headers assert len(record) == 2 - messages = [str(r.message) for r in record] - assert any("auth_tokens` is deprecated" in m for m in messages) - assert any("auth_headers` is deprecated" in m for m in messages) - + messages = sorted([str(r.message) for r in record]) + # Warning for auth_headers when auth_token_getters is also present + assert "Both `auth_token_getters` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_token_getters` will be used." in messages + # Warning for auth_tokens when auth_token_getters is also present + assert "Both `auth_token_getters` and `auth_tokens` are provided. `auth_tokens` is deprecated, and `auth_token_getters` will be used." in messages + assert isinstance(tool, ToolboxTool) - mock_core_load_tool.assert_called_with( # Use called_with for flexibility if called multiple times in setup + mock_core_load_tool.assert_called_with( name="test_tool_name", auth_token_getters=auth_token_getters, bound_params=bound_params, ) - mock_core_load_tool.reset_mock() # Reset for next test case - - # Test case where auth_tokens is used (auth_token_getters is None) - with pytest.warns(DeprecationWarning, match="auth_tokens` is deprecated"): + mock_core_load_tool.reset_mock() + + # Scenario 2: auth_tokens and auth_headers provided, auth_token_getters is default (empty initially) + with pytest.warns(DeprecationWarning) as record: toolbox_client.load_tool( "test_tool_name_2", - auth_tokens=auth_tokens_deprecated, - auth_headers=auth_headers_deprecated, # This will also warn + auth_tokens=auth_tokens_deprecated, # This will be used for auth_token_getters + auth_headers=auth_headers_deprecated, # This will warn as auth_token_getters is now populated bound_params=bound_params, ) + assert len(record) == 2 + messages = sorted([str(r.message) for r in record]) + + assert messages[0] == "Argument `auth_tokens` is deprecated. Use `auth_token_getters` instead." + assert messages[1] == "Both `auth_token_getters` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_token_getters` will be used." + + expected_getters_for_call = auth_tokens_deprecated + mock_core_load_tool.assert_called_with( name="test_tool_name_2", - auth_token_getters=auth_tokens_deprecated, # auth_tokens becomes auth_token_getters + auth_token_getters=expected_getters_for_call, bound_params=bound_params, ) mock_core_load_tool.reset_mock() - - # Test case where auth_headers is used (auth_token_getters and auth_tokens are None) - with pytest.warns(DeprecationWarning, match="auth_headers` is deprecated"): + + with pytest.warns(DeprecationWarning, match="Argument `auth_headers` is deprecated. Use `auth_token_getters` instead.") as record: toolbox_client.load_tool( "test_tool_name_3", auth_headers=auth_headers_deprecated, bound_params=bound_params, ) + assert len(record) == 1 + mock_core_load_tool.assert_called_with( name="test_tool_name_3", - auth_token_getters=auth_headers_deprecated, # auth_headers becomes auth_token_getters + auth_token_getters=auth_headers_deprecated, bound_params=bound_params, ) @patch("toolbox_core.sync_client.ToolboxSyncClient.load_toolset") def test_load_toolset_with_args(self, mock_core_load_toolset, toolbox_client): - mock_core_tool_instance = Mock(spec=ToolboxCoreSyncTool) - mock_core_tool_instance.__name__ = "mock-tool-0" - mock_async_tool = Mock(spec=ToolboxCoreTool) - mock_async_tool._pydantic_model = BaseModel - mock_core_tool_instance._async_tool = mock_async_tool + mock_core_tool_instance = create_mock_core_sync_tool(model_name="MySetModel") mock_core_load_toolset.return_value = [mock_core_tool_instance] auth_token_getters = {"token_getter1": lambda: "value1"} auth_tokens_deprecated = {"token_deprecated": lambda: "value_dep"} auth_headers_deprecated = {"header_deprecated": lambda: "value_head_dep"} bound_params = {"param1": "value4"} + toolset_name = "my_toolset" - with pytest.warns(DeprecationWarning) as record: # Expect 2 warnings + with pytest.warns(DeprecationWarning) as record: tools = toolbox_client.load_toolset( - toolset_name="my_toolset", + toolset_name=toolset_name, auth_token_getters=auth_token_getters, auth_tokens=auth_tokens_deprecated, auth_headers=auth_headers_deprecated, bound_params=bound_params, - strict=False, + strict=True, ) assert len(record) == 2 - messages = [str(r.message) for r in record] - assert any("auth_tokens` is deprecated" in m for m in messages) - assert any("auth_headers` is deprecated" in m for m in messages) assert len(tools) == 1 + assert isinstance(tools[0], ToolboxTool) mock_core_load_toolset.assert_called_with( - name="my_toolset", + name=toolset_name, auth_token_getters=auth_token_getters, bound_params=bound_params, - strict=False, + strict=True, ) @pytest.mark.asyncio - @patch("toolbox_core.client.ToolboxClient.load_tool") - async def test_aload_tool_with_args(self, mock_core_aload_tool, toolbox_client): - mock_core_tool_instance = AsyncMock(spec=ToolboxCoreTool) - mock_core_tool_instance.__name__ = "mock-tool" - mock_core_tool_instance._pydantic_model = BaseModel - mock_core_aload_tool.return_value = mock_core_tool_instance + @patch("toolbox_core.sync_client.ToolboxSyncClient.load_tool") + async def test_aload_tool_with_args(self, mock_sync_core_load_tool, toolbox_client): + mock_core_tool_instance = create_mock_core_sync_tool(model_name="MyAsyncToolModel") + mock_sync_core_load_tool.return_value = mock_core_tool_instance auth_token_getters = {"token_getter1": lambda: "value1"} auth_tokens_deprecated = {"token_deprecated": lambda: "value_dep"} auth_headers_deprecated = {"header_deprecated": lambda: "value_head_dep"} bound_params = {"param1": "value4"} - with pytest.warns(DeprecationWarning) as record: # Expect 2 warnings + with pytest.warns(DeprecationWarning) as record: tool = await toolbox_client.aload_tool( "test_tool", auth_token_getters=auth_token_getters, @@ -272,50 +323,44 @@ async def test_aload_tool_with_args(self, mock_core_aload_tool, toolbox_client): bound_params=bound_params, ) assert len(record) == 2 - messages = [str(r.message) for r in record] - assert any("auth_tokens` is deprecated" in m for m in messages) - assert any("auth_headers` is deprecated" in m for m in messages) assert isinstance(tool, ToolboxTool) - toolbox_client._ToolboxClient__core_client._async_client.load_tool.assert_called_with( + mock_sync_core_load_tool.assert_called_with( name="test_tool", auth_token_getters=auth_token_getters, bound_params=bound_params, ) @pytest.mark.asyncio - @patch("toolbox_core.client.ToolboxClient.load_toolset") + @patch("toolbox_core.sync_client.ToolboxSyncClient.load_toolset") async def test_aload_toolset_with_args( - self, mock_core_aload_toolset, toolbox_client + self, mock_sync_core_load_toolset, toolbox_client ): - mock_core_tool_instance = AsyncMock(spec=ToolboxCoreTool) - mock_core_tool_instance.__name__ = "mock-tool-0" - mock_core_tool_instance._pydantic_model = BaseModel - mock_core_aload_toolset.return_value = [mock_core_tool_instance] + mock_core_tool_instance = create_mock_core_sync_tool(model_name="MyAsyncSetModel") + mock_sync_core_load_toolset.return_value = [mock_core_tool_instance] auth_token_getters = {"token_getter1": lambda: "value1"} auth_tokens_deprecated = {"token_deprecated": lambda: "value_dep"} auth_headers_deprecated = {"header_deprecated": lambda: "value_head_dep"} bound_params = {"param1": "value4"} + toolset_name = "my_async_toolset" - with pytest.warns(DeprecationWarning) as record: # Expect 2 warnings + with pytest.warns(DeprecationWarning) as record: tools = await toolbox_client.aload_toolset( - "my_toolset", + toolset_name, auth_token_getters=auth_token_getters, auth_tokens=auth_tokens_deprecated, auth_headers=auth_headers_deprecated, bound_params=bound_params, - strict=False, + strict=True, ) assert len(record) == 2 - messages = [str(r.message) for r in record] - assert any("auth_tokens` is deprecated" in m for m in messages) - assert any("auth_headers` is deprecated" in m for m in messages) assert len(tools) == 1 - toolbox_client._ToolboxClient__core_client._async_client.load_toolset.assert_called_with( - name="my_toolset", + assert isinstance(tools[0], ToolboxTool) + mock_sync_core_load_toolset.assert_called_with( + name=toolset_name, auth_token_getters=auth_token_getters, bound_params=bound_params, - strict=False, + strict=True, ) From 3637bba30fc726354bee0316447de168936f0d90 Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Sat, 17 May 2025 03:07:20 +0530 Subject: [PATCH 52/53] chore: Delint --- .../toolbox-langchain/tests/test_client.py | 190 ++++++++++++------ 1 file changed, 123 insertions(+), 67 deletions(-) diff --git a/packages/toolbox-langchain/tests/test_client.py b/packages/toolbox-langchain/tests/test_client.py index ff097198..98f29e53 100644 --- a/packages/toolbox-langchain/tests/test_client.py +++ b/packages/toolbox-langchain/tests/test_client.py @@ -16,9 +16,8 @@ import pytest from pydantic import BaseModel - -from toolbox_core.sync_tool import ToolboxSyncTool as ToolboxCoreSyncTool from toolbox_core.protocol import ParameterSchema as CoreParameterSchema +from toolbox_core.sync_tool import ToolboxSyncTool as ToolboxCoreSyncTool from toolbox_core.utils import params_to_pydantic_model from toolbox_langchain.client import ToolboxClient @@ -26,42 +25,69 @@ URL = "http://test_url" -def create_mock_core_sync_tool(name="mock-sync-tool", doc="Mock sync description.", model_name="MockSyncModel", params=None): + +def create_mock_core_sync_tool( + name="mock-sync-tool", + doc="Mock sync description.", + model_name="MockSyncModel", + params=None, +): mock_tool = Mock(spec=ToolboxCoreSyncTool) mock_tool.__name__ = name mock_tool.__doc__ = doc mock_tool._name = model_name if params is None: - mock_tool._params = [CoreParameterSchema(name="param1", type="string", description="Param 1")] + mock_tool._params = [ + CoreParameterSchema(name="param1", type="string", description="Param 1") + ] else: mock_tool._params = params return mock_tool -def assert_pydantic_models_equivalent(model_cls1: type[BaseModel], model_cls2: type[BaseModel], expected_model_name: str): + +def assert_pydantic_models_equivalent( + model_cls1: type[BaseModel], model_cls2: type[BaseModel], expected_model_name: str +): assert issubclass(model_cls1, BaseModel), "model_cls1 is not a Pydantic BaseModel" assert issubclass(model_cls2, BaseModel), "model_cls2 is not a Pydantic BaseModel" - - assert model_cls1.__name__ == expected_model_name, f"model_cls1 name mismatch: expected {expected_model_name}, got {model_cls1.__name__}" - assert model_cls2.__name__ == expected_model_name, f"model_cls2 name mismatch: expected {expected_model_name}, got {model_cls2.__name__}" + + assert ( + model_cls1.__name__ == expected_model_name + ), f"model_cls1 name mismatch: expected {expected_model_name}, got {model_cls1.__name__}" + assert ( + model_cls2.__name__ == expected_model_name + ), f"model_cls2 name mismatch: expected {expected_model_name}, got {model_cls2.__name__}" fields1 = model_cls1.model_fields fields2 = model_cls2.model_fields - assert fields1.keys() == fields2.keys(), \ - f"Field names mismatch: {fields1.keys()} != {fields2.keys()}" + assert ( + fields1.keys() == fields2.keys() + ), f"Field names mismatch: {fields1.keys()} != {fields2.keys()}" for field_name in fields1.keys(): field_info1 = fields1[field_name] field_info2 = fields2[field_name] - assert field_info1.annotation == field_info2.annotation, \ - f"Field '{field_name}': Annotation mismatch ({field_info1.annotation} != {field_info2.annotation})" - assert field_info1.description == field_info2.description, \ - f"Field '{field_name}': Description mismatch ('{field_info1.description}' != '{field_info2.description}')" - is_required1 = field_info1.is_required() if hasattr(field_info1, 'is_required') else not field_info1.is_nullable() - is_required2 = field_info2.is_required() if hasattr(field_info2, 'is_required') else not field_info2.is_nullable() - assert is_required1 == is_required2, \ - f"Field '{field_name}': Required status mismatch ({is_required1} != {is_required2})" + assert ( + field_info1.annotation == field_info2.annotation + ), f"Field '{field_name}': Annotation mismatch ({field_info1.annotation} != {field_info2.annotation})" + assert ( + field_info1.description == field_info2.description + ), f"Field '{field_name}': Description mismatch ('{field_info1.description}' != '{field_info2.description}')" + is_required1 = ( + field_info1.is_required() + if hasattr(field_info1, "is_required") + else not field_info1.is_nullable() + ) + is_required2 = ( + field_info2.is_required() + if hasattr(field_info2, "is_required") + else not field_info2.is_nullable() + ) + assert ( + is_required1 == is_required2 + ), f"Field '{field_name}': Required status mismatch ({is_required1} != {is_required2})" class TestToolboxClient: @@ -78,26 +104,29 @@ def test_load_tool(self, mock_core_load_tool, toolbox_client): name="test_tool_sync", doc="Sync tool description.", model_name="TestToolSyncModel", - params=[CoreParameterSchema(name="sp1", type="integer", description="Sync Param 1")] + params=[ + CoreParameterSchema( + name="sp1", type="integer", description="Sync Param 1" + ) + ], ) mock_core_load_tool.return_value = mock_core_tool_instance - + langchain_tool = toolbox_client.load_tool("test_tool") - + assert isinstance(langchain_tool, ToolboxTool) assert langchain_tool.name == mock_core_tool_instance.__name__ assert langchain_tool.description == mock_core_tool_instance.__doc__ - + # Generate the expected schema once for comparison expected_args_schema = params_to_pydantic_model( - mock_core_tool_instance._name, - mock_core_tool_instance._params + mock_core_tool_instance._name, mock_core_tool_instance._params ) - + assert_pydantic_models_equivalent( - langchain_tool.args_schema, + langchain_tool.args_schema, expected_args_schema, - mock_core_tool_instance._name + mock_core_tool_instance._name, ) mock_core_load_tool.assert_called_once_with( @@ -106,8 +135,12 @@ def test_load_tool(self, mock_core_load_tool, toolbox_client): @patch("toolbox_core.sync_client.ToolboxSyncClient.load_toolset") def test_load_toolset(self, mock_core_load_toolset, toolbox_client): - mock_core_tool_instance1 = create_mock_core_sync_tool(name="tool-0", doc="desc 0", model_name="T0Model") - mock_core_tool_instance2 = create_mock_core_sync_tool(name="tool-1", doc="desc 1", model_name="T1Model", params=[]) + mock_core_tool_instance1 = create_mock_core_sync_tool( + name="tool-0", doc="desc 0", model_name="T0Model" + ) + mock_core_tool_instance2 = create_mock_core_sync_tool( + name="tool-1", doc="desc 1", model_name="T1Model", params=[] + ) mock_core_load_toolset.return_value = [ mock_core_tool_instance1, @@ -116,22 +149,21 @@ def test_load_toolset(self, mock_core_load_toolset, toolbox_client): langchain_tools = toolbox_client.load_toolset() assert len(langchain_tools) == 2 - + tool_instances_mocks = [mock_core_tool_instance1, mock_core_tool_instance2] for i, tool_instance_mock in enumerate(tool_instances_mocks): langchain_tool = langchain_tools[i] assert isinstance(langchain_tool, ToolboxTool) assert langchain_tool.name == tool_instance_mock.__name__ assert langchain_tool.description == tool_instance_mock.__doc__ - + expected_args_schema = params_to_pydantic_model( - tool_instance_mock._name, - tool_instance_mock._params + tool_instance_mock._name, tool_instance_mock._params ) assert_pydantic_models_equivalent( - langchain_tool.args_schema, + langchain_tool.args_schema, expected_args_schema, - tool_instance_mock._name + tool_instance_mock._name, ) mock_core_load_toolset.assert_called_once_with( @@ -144,7 +176,7 @@ async def test_aload_tool(self, mock_sync_core_load_tool, toolbox_client): mock_core_sync_tool_instance = create_mock_core_sync_tool( name="test_async_loaded_tool", doc="Async loaded sync tool description.", - model_name="AsyncTestToolModel" + model_name="AsyncTestToolModel", ) mock_sync_core_load_tool.return_value = mock_core_sync_tool_instance @@ -153,17 +185,16 @@ async def test_aload_tool(self, mock_sync_core_load_tool, toolbox_client): assert isinstance(langchain_tool, ToolboxTool) assert langchain_tool.name == mock_core_sync_tool_instance.__name__ assert langchain_tool.description == mock_core_sync_tool_instance.__doc__ - + expected_args_schema = params_to_pydantic_model( - mock_core_sync_tool_instance._name, - mock_core_sync_tool_instance._params + mock_core_sync_tool_instance._name, mock_core_sync_tool_instance._params ) assert_pydantic_models_equivalent( - langchain_tool.args_schema, + langchain_tool.args_schema, expected_args_schema, - mock_core_sync_tool_instance._name + mock_core_sync_tool_instance._name, ) - + mock_sync_core_load_tool.assert_called_once_with( name="test_tool", auth_token_getters={}, bound_params={} ) @@ -171,8 +202,15 @@ async def test_aload_tool(self, mock_sync_core_load_tool, toolbox_client): @pytest.mark.asyncio @patch("toolbox_core.sync_client.ToolboxSyncClient.load_toolset") async def test_aload_toolset(self, mock_sync_core_load_toolset, toolbox_client): - mock_core_sync_tool1 = create_mock_core_sync_tool(name="async-tool-0", doc="async desc 0", model_name="AT0Model") - mock_core_sync_tool2 = create_mock_core_sync_tool(name="async-tool-1", doc="async desc 1", model_name="AT1Model", params=[CoreParameterSchema(name="p1", type="string", description="P1")]) + mock_core_sync_tool1 = create_mock_core_sync_tool( + name="async-tool-0", doc="async desc 0", model_name="AT0Model" + ) + mock_core_sync_tool2 = create_mock_core_sync_tool( + name="async-tool-1", + doc="async desc 1", + model_name="AT1Model", + params=[CoreParameterSchema(name="p1", type="string", description="P1")], + ) mock_sync_core_load_toolset.return_value = [ mock_core_sync_tool1, @@ -181,21 +219,20 @@ async def test_aload_toolset(self, mock_sync_core_load_toolset, toolbox_client): langchain_tools = await toolbox_client.aload_toolset() assert len(langchain_tools) == 2 - + tool_instances_mocks = [mock_core_sync_tool1, mock_core_sync_tool2] for i, tool_instance_mock in enumerate(tool_instances_mocks): langchain_tool = langchain_tools[i] assert isinstance(langchain_tool, ToolboxTool) assert langchain_tool.name == tool_instance_mock.__name__ - + expected_args_schema = params_to_pydantic_model( - tool_instance_mock._name, - tool_instance_mock._params + tool_instance_mock._name, tool_instance_mock._params ) assert_pydantic_models_equivalent( - langchain_tool.args_schema, + langchain_tool.args_schema, expected_args_schema, - tool_instance_mock._name + tool_instance_mock._name, ) mock_sync_core_load_toolset.assert_called_once_with( @@ -223,10 +260,16 @@ def test_load_tool_with_args(self, mock_core_load_tool, toolbox_client): assert len(record) == 2 messages = sorted([str(r.message) for r in record]) # Warning for auth_headers when auth_token_getters is also present - assert "Both `auth_token_getters` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_token_getters` will be used." in messages + assert ( + "Both `auth_token_getters` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_token_getters` will be used." + in messages + ) # Warning for auth_tokens when auth_token_getters is also present - assert "Both `auth_token_getters` and `auth_tokens` are provided. `auth_tokens` is deprecated, and `auth_token_getters` will be used." in messages - + assert ( + "Both `auth_token_getters` and `auth_tokens` are provided. `auth_tokens` is deprecated, and `auth_token_getters` will be used." + in messages + ) + assert isinstance(tool, ToolboxTool) mock_core_load_tool.assert_called_with( name="test_tool_name", @@ -234,38 +277,47 @@ def test_load_tool_with_args(self, mock_core_load_tool, toolbox_client): bound_params=bound_params, ) mock_core_load_tool.reset_mock() - + # Scenario 2: auth_tokens and auth_headers provided, auth_token_getters is default (empty initially) with pytest.warns(DeprecationWarning) as record: toolbox_client.load_tool( "test_tool_name_2", - auth_tokens=auth_tokens_deprecated, # This will be used for auth_token_getters - auth_headers=auth_headers_deprecated, # This will warn as auth_token_getters is now populated + auth_tokens=auth_tokens_deprecated, # This will be used for auth_token_getters + auth_headers=auth_headers_deprecated, # This will warn as auth_token_getters is now populated bound_params=bound_params, ) assert len(record) == 2 messages = sorted([str(r.message) for r in record]) - - assert messages[0] == "Argument `auth_tokens` is deprecated. Use `auth_token_getters` instead." - assert messages[1] == "Both `auth_token_getters` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_token_getters` will be used." - - expected_getters_for_call = auth_tokens_deprecated - + + assert ( + messages[0] + == "Argument `auth_tokens` is deprecated. Use `auth_token_getters` instead." + ) + assert ( + messages[1] + == "Both `auth_token_getters` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_token_getters` will be used." + ) + + expected_getters_for_call = auth_tokens_deprecated + mock_core_load_tool.assert_called_with( name="test_tool_name_2", auth_token_getters=expected_getters_for_call, bound_params=bound_params, ) mock_core_load_tool.reset_mock() - - with pytest.warns(DeprecationWarning, match="Argument `auth_headers` is deprecated. Use `auth_token_getters` instead.") as record: + + with pytest.warns( + DeprecationWarning, + match="Argument `auth_headers` is deprecated. Use `auth_token_getters` instead.", + ) as record: toolbox_client.load_tool( "test_tool_name_3", auth_headers=auth_headers_deprecated, bound_params=bound_params, ) assert len(record) == 1 - + mock_core_load_tool.assert_called_with( name="test_tool_name_3", auth_token_getters=auth_headers_deprecated, @@ -306,7 +358,9 @@ def test_load_toolset_with_args(self, mock_core_load_toolset, toolbox_client): @pytest.mark.asyncio @patch("toolbox_core.sync_client.ToolboxSyncClient.load_tool") async def test_aload_tool_with_args(self, mock_sync_core_load_tool, toolbox_client): - mock_core_tool_instance = create_mock_core_sync_tool(model_name="MyAsyncToolModel") + mock_core_tool_instance = create_mock_core_sync_tool( + model_name="MyAsyncToolModel" + ) mock_sync_core_load_tool.return_value = mock_core_tool_instance auth_token_getters = {"token_getter1": lambda: "value1"} @@ -336,7 +390,9 @@ async def test_aload_tool_with_args(self, mock_sync_core_load_tool, toolbox_clie async def test_aload_toolset_with_args( self, mock_sync_core_load_toolset, toolbox_client ): - mock_core_tool_instance = create_mock_core_sync_tool(model_name="MyAsyncSetModel") + mock_core_tool_instance = create_mock_core_sync_tool( + model_name="MyAsyncSetModel" + ) mock_sync_core_load_toolset.return_value = [mock_core_tool_instance] auth_token_getters = {"token_getter1": lambda: "value1"} From 9a34f6f3377e22a7d7d51d2410f6e8986211fbf5 Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Sat, 17 May 2025 03:34:16 +0530 Subject: [PATCH 53/53] chore: Fix tools unit tests --- .../toolbox-langchain/tests/test_tools.py | 265 ++++++++++++------ 1 file changed, 179 insertions(+), 86 deletions(-) diff --git a/packages/toolbox-langchain/tests/test_tools.py b/packages/toolbox-langchain/tests/test_tools.py index 5560cf99..90fddf4b 100644 --- a/packages/toolbox-langchain/tests/test_tools.py +++ b/packages/toolbox-langchain/tests/test_tools.py @@ -12,16 +12,63 @@ # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import Mock +import asyncio +from unittest.mock import AsyncMock, Mock, call, patch import pytest from pydantic import BaseModel +from toolbox_core.protocol import ParameterSchema as CoreParameterSchema from toolbox_core.sync_tool import ToolboxSyncTool as ToolboxCoreSyncTool -from toolbox_core.tool import ToolboxTool as ToolboxCoreTool +from toolbox_core.tool import ToolboxTool as CoreAsyncTool +from toolbox_core.utils import params_to_pydantic_model from toolbox_langchain.tools import ToolboxTool +def assert_pydantic_models_equivalent( + model_cls1: type[BaseModel], model_cls2: type[BaseModel], expected_model_name: str +): + assert issubclass(model_cls1, BaseModel), "model_cls1 is not a Pydantic BaseModel" + assert issubclass(model_cls2, BaseModel), "model_cls2 is not a Pydantic BaseModel" + assert ( + model_cls1.__name__ == expected_model_name + ), f"model_cls1 name mismatch: expected {expected_model_name}, got {model_cls1.__name__}" + assert ( + model_cls2.__name__ == expected_model_name + ), f"model_cls2 name mismatch: expected {expected_model_name}, got {model_cls2.__name__}" + + fields1 = model_cls1.model_fields + fields2 = model_cls2.model_fields + + assert ( + fields1.keys() == fields2.keys() + ), f"Field names mismatch: {fields1.keys()} != {fields2.keys()}" + + for field_name in fields1.keys(): + field_info1 = fields1[field_name] + field_info2 = fields2[field_name] + + assert ( + field_info1.annotation == field_info2.annotation + ), f"Field '{field_name}': Annotation mismatch ({field_info1.annotation} != {field_info2.annotation})" + assert ( + field_info1.description == field_info2.description + ), f"Field '{field_name}': Description mismatch ('{field_info1.description}' != '{field_info2.description}')" + is_required1 = ( + field_info1.is_required() + if hasattr(field_info1, "is_required") + else not field_info1.is_nullable() + ) + is_required2 = ( + field_info2.is_required() + if hasattr(field_info2, "is_required") + else not field_info2.is_nullable() + ) + assert ( + is_required1 == is_required2 + ), f"Field '{field_name}': Required status mismatch ({is_required1} != {is_required2})" + + class TestToolboxTool: @pytest.fixture def tool_schema_dict(self): @@ -49,48 +96,76 @@ def auth_tool_schema_dict(self): ], } - @pytest.fixture(scope="function") - def mock_core_async_tool(self, tool_schema_dict): - mock = Mock(spec=ToolboxCoreTool) - mock.__name__ = "test_tool" - mock.__doc__ = tool_schema_dict["description"] - mock._pydantic_model = BaseModel - return mock - - @pytest.fixture(scope="function") - def mock_core_async_auth_tool(self, auth_tool_schema_dict): - mock = Mock(spec=ToolboxCoreTool) - mock.__name__ = "test_auth_tool" - mock.__doc__ = auth_tool_schema_dict["description"] - mock._pydantic_model = BaseModel - return mock - @pytest.fixture - def mock_core_tool(self, mock_core_async_tool): + def mock_core_tool(self, tool_schema_dict): sync_mock = Mock(spec=ToolboxCoreSyncTool) - sync_mock.__name__ = mock_core_async_tool.__name__ - sync_mock.__doc__ = mock_core_async_tool.__doc__ - sync_mock._async_tool = mock_core_async_tool - sync_mock.add_auth_token_getters = Mock(return_value=sync_mock) - sync_mock.bind_params = Mock(return_value=sync_mock) - sync_mock.bind_param = Mock( - return_value=sync_mock - ) # Keep this if bind_param exists on core, otherwise remove - sync_mock.__call__ = Mock(return_value="mocked_sync_call_result") + + sync_mock.__name__ = "test_tool_name_for_langchain" + sync_mock.__doc__ = tool_schema_dict["description"] + sync_mock._name = "TestToolPydanticModel" + sync_mock._params = [ + CoreParameterSchema(**p) for p in tool_schema_dict["parameters"] + ] + + mock_async_tool_attr = AsyncMock(spec=CoreAsyncTool) + mock_async_tool_attr.return_value = "dummy_internal_async_tool_result" + sync_mock._ToolboxSyncTool__async_tool = mock_async_tool_attr + sync_mock._ToolboxSyncTool__loop = Mock(spec=asyncio.AbstractEventLoop) + sync_mock._ToolboxSyncTool__thread = Mock() + + new_mock_instance_for_methods = Mock(spec=ToolboxCoreSyncTool) + new_mock_instance_for_methods.__name__ = sync_mock.__name__ + new_mock_instance_for_methods.__doc__ = sync_mock.__doc__ + new_mock_instance_for_methods._name = sync_mock._name + new_mock_instance_for_methods._params = sync_mock._params + new_mock_instance_for_methods._ToolboxSyncTool__async_tool = AsyncMock( + spec=CoreAsyncTool + ) + new_mock_instance_for_methods._ToolboxSyncTool__loop = Mock( + spec=asyncio.AbstractEventLoop + ) + new_mock_instance_for_methods._ToolboxSyncTool__thread = Mock() + + sync_mock.add_auth_token_getters = Mock( + return_value=new_mock_instance_for_methods + ) + sync_mock.bind_params = Mock(return_value=new_mock_instance_for_methods) + return sync_mock @pytest.fixture - def mock_core_sync_auth_tool(self, mock_core_async_auth_tool): + def mock_core_sync_auth_tool(self, auth_tool_schema_dict): sync_mock = Mock(spec=ToolboxCoreSyncTool) - sync_mock.__name__ = mock_core_async_auth_tool.__name__ - sync_mock.__doc__ = mock_core_async_auth_tool.__doc__ - sync_mock._async_tool = mock_core_async_auth_tool - sync_mock.add_auth_token_getters = Mock(return_value=sync_mock) - sync_mock.bind_params = Mock(return_value=sync_mock) - sync_mock.bind_param = Mock( - return_value=sync_mock - ) # Keep this if bind_param exists on core - sync_mock.__call__ = Mock(return_value="mocked_auth_sync_call_result") + sync_mock.__name__ = "test_auth_tool_lc_name" + sync_mock.__doc__ = auth_tool_schema_dict["description"] + sync_mock._name = "TestAuthToolPydanticModel" + sync_mock._params = [ + CoreParameterSchema(**p) for p in auth_tool_schema_dict["parameters"] + ] + + mock_async_tool_attr = AsyncMock(spec=CoreAsyncTool) + mock_async_tool_attr.return_value = "dummy_internal_async_auth_tool_result" + sync_mock._ToolboxSyncTool__async_tool = mock_async_tool_attr + sync_mock._ToolboxSyncTool__loop = Mock(spec=asyncio.AbstractEventLoop) + sync_mock._ToolboxSyncTool__thread = Mock() + + new_mock_instance_for_methods = Mock(spec=ToolboxCoreSyncTool) + new_mock_instance_for_methods.__name__ = sync_mock.__name__ + new_mock_instance_for_methods.__doc__ = sync_mock.__doc__ + new_mock_instance_for_methods._name = sync_mock._name + new_mock_instance_for_methods._params = sync_mock._params + new_mock_instance_for_methods._ToolboxSyncTool__async_tool = AsyncMock( + spec=CoreAsyncTool + ) + new_mock_instance_for_methods._ToolboxSyncTool__loop = Mock( + spec=asyncio.AbstractEventLoop + ) + new_mock_instance_for_methods._ToolboxSyncTool__thread = Mock() + + sync_mock.add_auth_token_getters = Mock( + return_value=new_mock_instance_for_methods + ) + sync_mock.bind_params = Mock(return_value=new_mock_instance_for_methods) return sync_mock @pytest.fixture @@ -103,110 +178,128 @@ def auth_toolbox_tool(self, mock_core_sync_auth_tool): def test_toolbox_tool_init(self, mock_core_tool): tool = ToolboxTool(core_tool=mock_core_tool) - core_tool_in_tool = tool._ToolboxTool__core_tool - assert core_tool_in_tool.__name__ == mock_core_tool.__name__ - assert core_tool_in_tool.__doc__ == mock_core_tool.__doc__ - assert tool.args_schema == mock_core_tool._async_tool._pydantic_model + + assert tool.name == mock_core_tool.__name__ + assert tool.description == mock_core_tool.__doc__ + assert tool._ToolboxTool__core_tool == mock_core_tool + + expected_args_schema = params_to_pydantic_model( + mock_core_tool._name, mock_core_tool._params + ) + assert_pydantic_models_equivalent( + tool.args_schema, expected_args_schema, mock_core_tool._name + ) @pytest.mark.parametrize( - "params, expected_bound_params_on_core", + "params", [ - ({"param1": "bound-value"}, {"param1": "bound-value"}), - ({"param1": lambda: "bound-value"}, {"param1": lambda: "bound-value"}), - ( - {"param1": "bound-value", "param2": 123}, - {"param1": "bound-value", "param2": 123}, - ), + ({"param1": "bound-value"}), + ({"param1": lambda: "bound-value"}), + ({"param1": "bound-value", "param2": 123}), ], ) def test_toolbox_tool_bind_params( self, params, - expected_bound_params_on_core, toolbox_tool, mock_core_tool, ): - mock_core_tool.bind_params.return_value = mock_core_tool + returned_core_tool_mock = mock_core_tool.bind_params.return_value new_langchain_tool = toolbox_tool.bind_params(params) + mock_core_tool.bind_params.assert_called_once_with(params) assert isinstance(new_langchain_tool, ToolboxTool) - assert ( - new_langchain_tool._ToolboxTool__core_tool - == mock_core_tool.bind_params.return_value - ) + assert new_langchain_tool._ToolboxTool__core_tool == returned_core_tool_mock def test_toolbox_tool_bind_param(self, toolbox_tool, mock_core_tool): - # ToolboxTool.bind_param calls core_tool.bind_params - mock_core_tool.bind_params.return_value = mock_core_tool + returned_core_tool_mock = mock_core_tool.bind_params.return_value new_langchain_tool = toolbox_tool.bind_param("param1", "bound-value") - # *** Fix: Assert that bind_params is called on the core tool *** + mock_core_tool.bind_params.assert_called_once_with({"param1": "bound-value"}) assert isinstance(new_langchain_tool, ToolboxTool) - assert ( - new_langchain_tool._ToolboxTool__core_tool - == mock_core_tool.bind_params.return_value - ) + assert new_langchain_tool._ToolboxTool__core_tool == returned_core_tool_mock @pytest.mark.parametrize( - "auth_token_getters, expected_auth_getters_on_core", + "auth_token_getters", [ - ( - {"test-auth-source": lambda: "test-token"}, - {"test-auth-source": lambda: "test-token"}, - ), + ({"test-auth-source": lambda: "test-token"}), ( { "test-auth-source": lambda: "test-token", "another-auth-source": lambda: "another-token", - }, - { - "test-auth-source": lambda: "test-token", - "another-auth-source": lambda: "another-token", - }, + } ), ], ) def test_toolbox_tool_add_auth_token_getters( self, auth_token_getters, - expected_auth_getters_on_core, auth_toolbox_tool, mock_core_sync_auth_tool, ): - mock_core_sync_auth_tool.add_auth_token_getters.return_value = ( - mock_core_sync_auth_tool + returned_core_tool_mock = ( + mock_core_sync_auth_tool.add_auth_token_getters.return_value ) new_langchain_tool = auth_toolbox_tool.add_auth_token_getters( auth_token_getters ) + mock_core_sync_auth_tool.add_auth_token_getters.assert_called_once_with( auth_token_getters ) assert isinstance(new_langchain_tool, ToolboxTool) - assert ( - new_langchain_tool._ToolboxTool__core_tool - == mock_core_sync_auth_tool.add_auth_token_getters.return_value - ) + assert new_langchain_tool._ToolboxTool__core_tool == returned_core_tool_mock def test_toolbox_tool_add_auth_token_getter( self, auth_toolbox_tool, mock_core_sync_auth_tool ): get_id_token = lambda: "test-token" - # ToolboxTool.add_auth_token_getter calls core_tool.add_auth_token_getters - mock_core_sync_auth_tool.add_auth_token_getters.return_value = ( - mock_core_sync_auth_tool + returned_core_tool_mock = ( + mock_core_sync_auth_tool.add_auth_token_getters.return_value ) new_langchain_tool = auth_toolbox_tool.add_auth_token_getter( "test-auth-source", get_id_token ) - # *** Fix: Assert that add_auth_token_getters is called on the core tool *** mock_core_sync_auth_tool.add_auth_token_getters.assert_called_once_with( {"test-auth-source": get_id_token} ) assert isinstance(new_langchain_tool, ToolboxTool) - assert ( - new_langchain_tool._ToolboxTool__core_tool - == mock_core_sync_auth_tool.add_auth_token_getters.return_value + assert new_langchain_tool._ToolboxTool__core_tool == returned_core_tool_mock + + def test_toolbox_tool_run(self, toolbox_tool, mock_core_tool): + kwargs_to_run = {"param1": "run_value1", "param2": 100} + expected_result = "sync_run_output" + mock_core_tool.return_value = expected_result + + result = toolbox_tool._run(**kwargs_to_run) + + assert result == expected_result + assert mock_core_tool.call_count == 1 + assert mock_core_tool.call_args == call(**kwargs_to_run) + + @pytest.mark.asyncio + @patch("toolbox_langchain.tools.to_thread", new_callable=AsyncMock) + async def test_toolbox_tool_arun( + self, mock_to_thread_in_tools, toolbox_tool, mock_core_tool + ): + kwargs_to_run = {"param1": "arun_value1", "param2": 200} + expected_result = "async_run_output" + + mock_core_tool.return_value = expected_result + + async def to_thread_side_effect(func, *args, **kwargs_for_func): + return func(**kwargs_for_func) + + mock_to_thread_in_tools.side_effect = to_thread_side_effect + + result = await toolbox_tool._arun(**kwargs_to_run) + + assert result == expected_result + mock_to_thread_in_tools.assert_awaited_once_with( + mock_core_tool, **kwargs_to_run ) + + assert mock_core_tool.call_count == 1 + assert mock_core_tool.call_args == call(**kwargs_to_run)