diff --git a/packages/toolbox-langchain/src/toolbox_langchain/tools.py b/packages/toolbox-langchain/src/toolbox_langchain/tools.py index 500f03ae..e03b37f8 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/tools.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/tools.py @@ -13,7 +13,7 @@ # limitations under the License. from asyncio import to_thread -from typing import Any, Callable, Union +from typing import Any, Awaitable, Callable, Mapping, Sequence, Union from deprecated import deprecated from langchain_core.tools import BaseTool @@ -47,6 +47,32 @@ def __init__( ) self.__core_tool = core_tool + @property + def _bound_params( + self, + ) -> Mapping[str, Union[Callable[[], Any], Callable[[], Awaitable[Any]], Any]]: + return self.__core_tool._bound_params + + @property + def _required_authn_params(self) -> Mapping[str, list[str]]: + return self.__core_tool._required_authn_params + + @property + def _required_authz_tokens(self) -> Sequence[str]: + return self.__core_tool._required_authz_tokens + + @property + def _auth_service_token_getters( + self, + ) -> Mapping[str, Union[Callable[[], str], Callable[[], Awaitable[str]]]]: + return self.__core_tool._auth_service_token_getters + + @property + def _client_headers( + self, + ) -> Mapping[str, Union[Callable[[], str], Callable[[], Awaitable[str]], str]]: + return self.__core_tool._client_headers + def _run(self, **kwargs: Any) -> str: return self.__core_tool(**kwargs) diff --git a/packages/toolbox-langchain/tests/test_tools.py b/packages/toolbox-langchain/tests/test_tools.py index 90fddf4b..56d574da 100644 --- a/packages/toolbox-langchain/tests/test_tools.py +++ b/packages/toolbox-langchain/tests/test_tools.py @@ -131,6 +131,12 @@ def mock_core_tool(self, tool_schema_dict): ) sync_mock.bind_params = Mock(return_value=new_mock_instance_for_methods) + sync_mock._bound_params = {"mock_bound_param": "mock_bound_value"} + sync_mock._required_authn_params = {"mock_auth_source": ["mock_param"]} + sync_mock._required_authz_tokens = ["mock_authz_token"] + sync_mock._auth_service_token_getters = {"mock_service": lambda: "mock_token"} + sync_mock._client_headers = {"mock_header": "mock_header_value"} + return sync_mock @pytest.fixture @@ -166,6 +172,13 @@ def mock_core_sync_auth_tool(self, auth_tool_schema_dict): return_value=new_mock_instance_for_methods ) sync_mock.bind_params = Mock(return_value=new_mock_instance_for_methods) + + sync_mock._bound_params = {"mock_bound_param": "mock_bound_value"} + sync_mock._required_authn_params = {"mock_auth_source": ["mock_param"]} + sync_mock._required_authz_tokens = ["mock_authz_token"] + sync_mock._auth_service_token_getters = {"mock_service": lambda: "mock_token"} + sync_mock._client_headers = {"mock_header": "mock_header_value"} + return sync_mock @pytest.fixture @@ -303,3 +316,48 @@ async def to_thread_side_effect(func, *args, **kwargs_for_func): assert mock_core_tool.call_count == 1 assert mock_core_tool.call_args == call(**kwargs_to_run) + + def test_toolbox_tool_properties(self, toolbox_tool, mock_core_tool): + """Tests that the properties correctly proxy to the core tool.""" + assert toolbox_tool._bound_params == mock_core_tool._bound_params + assert ( + toolbox_tool._required_authn_params == mock_core_tool._required_authn_params + ) + assert ( + toolbox_tool._required_authz_tokens == mock_core_tool._required_authz_tokens + ) + assert ( + toolbox_tool._auth_service_token_getters + == mock_core_tool._auth_service_token_getters + ) + assert toolbox_tool._client_headers == mock_core_tool._client_headers + + def test_toolbox_tool_add_auth_tokens_deprecated( + self, auth_toolbox_tool, mock_core_sync_auth_tool + ): + """Tests the deprecated add_auth_tokens method.""" + auth_tokens = {"test-auth-source": lambda: "test-token"} + with pytest.warns(DeprecationWarning): + new_tool = auth_toolbox_tool.add_auth_tokens(auth_tokens) + + # Check that the call was correctly forwarded to the new method on the core tool + mock_core_sync_auth_tool.add_auth_token_getters.assert_called_once_with( + auth_tokens + ) + assert isinstance(new_tool, ToolboxTool) + + def test_toolbox_tool_add_auth_token_deprecated( + self, auth_toolbox_tool, mock_core_sync_auth_tool + ): + """Tests the deprecated add_auth_token method.""" + get_id_token = lambda: "test-token" + with pytest.warns(DeprecationWarning): + new_tool = auth_toolbox_tool.add_auth_token( + "test-auth-source", get_id_token + ) + + # Check that the call was correctly forwarded to the new method on the core tool + mock_core_sync_auth_tool.add_auth_token_getters.assert_called_once_with( + {"test-auth-source": get_id_token} + ) + assert isinstance(new_tool, ToolboxTool)