From 1c0dfca316a9a492f9a1cdd85499ae6fe1b3e778 Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Thu, 19 Jun 2025 19:16:02 +0530 Subject: [PATCH 1/4] feat(toolbox-langchain): Implement self-authenticated tools --- .../src/toolbox_langchain/async_tools.py | 35 +++++++++- .../src/toolbox_langchain/tools.py | 69 +++++++++++++++++-- 2 files changed, 96 insertions(+), 8 deletions(-) diff --git a/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py b/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py index fee763c3..971ebd24 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py @@ -15,6 +15,7 @@ from typing import Any, Callable, Union from deprecated import deprecated +from langchain_core.runnables import RunnableConfig from langchain_core.tools import BaseTool from toolbox_core.tool import ToolboxTool as ToolboxCoreTool from toolbox_core.utils import params_to_pydantic_model @@ -52,7 +53,11 @@ def __init__( def _run(self, **kwargs: Any) -> str: raise NotImplementedError("Synchronous methods not supported by async tools.") - async def _arun(self, **kwargs: Any) -> str: + async def _arun( + self, + config: RunnableConfig, + **kwargs: Any, + ) -> str: """ The coroutine that invokes the tool with the given arguments. @@ -63,7 +68,33 @@ async def _arun(self, **kwargs: Any) -> str: A dictionary containing the parsed JSON response from the tool invocation. """ - return await self.__core_tool(**kwargs) + tool_to_run = self.__core_tool + if ( + config + and "configurable" in config + and "auth_token_getters" in config["configurable"] + ): + auth_token_getters = config["configurable"]["auth_token_getters"] + if auth_token_getters: + + # The `add_auth_token_getters` method requires that all provided + # getters are used by the tool. To prevent validation errors, + # filter the incoming getters to include only those that this + # specific tool requires. + required_auth_keys = set(self.__core_tool._required_authz_tokens) + for auth_list in self.__core_tool._required_authn_params.values(): + required_auth_keys.update(auth_list) + filtered_getters = { + k: v + for k, v in auth_token_getters.items() + if k in required_auth_keys + } + if filtered_getters: + tool_to_run = self.__core_tool.add_auth_token_getters( + filtered_getters + ) + + return await tool_to_run(**kwargs) def add_auth_token_getters( self, auth_token_getters: dict[str, Callable[[], str]] diff --git a/packages/toolbox-langchain/src/toolbox_langchain/tools.py b/packages/toolbox-langchain/src/toolbox_langchain/tools.py index e03b37f8..34654882 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/tools.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/tools.py @@ -13,9 +13,10 @@ # limitations under the License. from asyncio import to_thread -from typing import Any, Awaitable, Callable, Mapping, Sequence, Union +from typing import Any, Awaitable, Callable, Mapping, Optional, Sequence, Union from deprecated import deprecated +from langchain_core.runnables import RunnableConfig from langchain_core.tools import BaseTool from toolbox_core.sync_tool import ToolboxSyncTool as ToolboxCoreSyncTool from toolbox_core.utils import params_to_pydantic_model @@ -73,11 +74,67 @@ def _client_headers( ) -> Mapping[str, Union[Callable[[], str], Callable[[], Awaitable[str]], str]]: return self.__core_tool._client_headers - def _run(self, **kwargs: Any) -> str: - return self.__core_tool(**kwargs) - - async def _arun(self, **kwargs: Any) -> str: - return await to_thread(self.__core_tool, **kwargs) + def _run( + self, + config: RunnableConfig, + **kwargs: Any, + ) -> str: + tool_to_run = self.__core_tool + if ( + config + and "configurable" in config + and "auth_token_getters" in config["configurable"] + ): + auth_token_getters = config["configurable"]["auth_token_getters"] + if auth_token_getters: + + # The `add_auth_token_getters` method requires that all provided + # getters are used by the tool. To prevent validation errors, + # filter the incoming getters to include only those that this + # specific tool requires. + required_auth_keys = set(self.__core_tool._required_authz_tokens) + for auth_list in self.__core_tool._required_authn_params.values(): + required_auth_keys.update(auth_list) + filtered_getters = { + k: v + for k, v in auth_token_getters.items() + if k in required_auth_keys + } + if filtered_getters: + tool_to_run = self.__core_tool.add_auth_token_getters( + filtered_getters + ) + + return tool_to_run(**kwargs) + + async def _arun(self, config: RunnableConfig, **kwargs: Any) -> str: + tool_to_run = self.__core_tool + if ( + config + and "configurable" in config + and "auth_token_getters" in config["configurable"] + ): + auth_token_getters = config["configurable"]["auth_token_getters"] + if auth_token_getters: + + # The `add_auth_token_getters` method requires that all provided + # getters are used by the tool. To prevent validation errors, + # filter the incoming getters to include only those that this + # specific tool requires. + required_auth_keys = set(self.__core_tool._required_authz_tokens) + for auth_list in self.__core_tool._required_authn_params.values(): + required_auth_keys.update(auth_list) + filtered_getters = { + k: v + for k, v in auth_token_getters.items() + if k in required_auth_keys + } + if filtered_getters: + tool_to_run = self.__core_tool.add_auth_token_getters( + filtered_getters + ) + + return await to_thread(tool_to_run, **kwargs) def add_auth_token_getters( self, auth_token_getters: dict[str, Callable[[], str]] From 980a0d2b64c223f23601c1d11eb87b86f84b84d3 Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Thu, 19 Jun 2025 20:06:34 +0530 Subject: [PATCH 2/4] chore: Fix unit tests --- packages/toolbox-langchain/tests/test_tools.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/toolbox-langchain/tests/test_tools.py b/packages/toolbox-langchain/tests/test_tools.py index 56d574da..3f890598 100644 --- a/packages/toolbox-langchain/tests/test_tools.py +++ b/packages/toolbox-langchain/tests/test_tools.py @@ -286,7 +286,7 @@ def test_toolbox_tool_run(self, toolbox_tool, mock_core_tool): expected_result = "sync_run_output" mock_core_tool.return_value = expected_result - result = toolbox_tool._run(**kwargs_to_run) + result = toolbox_tool._run(**kwargs_to_run, config={}) assert result == expected_result assert mock_core_tool.call_count == 1 @@ -307,7 +307,7 @@ async def to_thread_side_effect(func, *args, **kwargs_for_func): mock_to_thread_in_tools.side_effect = to_thread_side_effect - result = await toolbox_tool._arun(**kwargs_to_run) + result = await toolbox_tool._arun(**kwargs_to_run, config={}) assert result == expected_result mock_to_thread_in_tools.assert_awaited_once_with( From 9ba40e5e88fa9ab287ab128a8c6f2651b5040ca8 Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Mon, 23 Jun 2025 23:13:37 +0530 Subject: [PATCH 3/4] chore: Refactor getting tool to run into a reusable helper --- .../src/toolbox_langchain/tools.py | 42 +++++-------------- 1 file changed, 10 insertions(+), 32 deletions(-) diff --git a/packages/toolbox-langchain/src/toolbox_langchain/tools.py b/packages/toolbox-langchain/src/toolbox_langchain/tools.py index 34654882..c41dcca0 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/tools.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/tools.py @@ -13,7 +13,7 @@ # limitations under the License. from asyncio import to_thread -from typing import Any, Awaitable, Callable, Mapping, Optional, Sequence, Union +from typing import Any, Awaitable, Callable, Mapping, Sequence, Union from deprecated import deprecated from langchain_core.runnables import RunnableConfig @@ -74,11 +74,7 @@ def _client_headers( ) -> Mapping[str, Union[Callable[[], str], Callable[[], Awaitable[str]], str]]: return self.__core_tool._client_headers - def _run( - self, - config: RunnableConfig, - **kwargs: Any, - ) -> str: + def __get_tool_to_run(self, config: RunnableConfig) -> ToolboxCoreSyncTool: tool_to_run = self.__core_tool if ( config @@ -104,36 +100,18 @@ def _run( tool_to_run = self.__core_tool.add_auth_token_getters( filtered_getters ) + return tool_to_run + def _run( + self, + config: RunnableConfig, + **kwargs: Any, + ) -> str: + tool_to_run = self.__get_tool_to_run(config) return tool_to_run(**kwargs) async def _arun(self, config: RunnableConfig, **kwargs: Any) -> str: - tool_to_run = self.__core_tool - if ( - config - and "configurable" in config - and "auth_token_getters" in config["configurable"] - ): - auth_token_getters = config["configurable"]["auth_token_getters"] - if auth_token_getters: - - # The `add_auth_token_getters` method requires that all provided - # getters are used by the tool. To prevent validation errors, - # filter the incoming getters to include only those that this - # specific tool requires. - required_auth_keys = set(self.__core_tool._required_authz_tokens) - for auth_list in self.__core_tool._required_authn_params.values(): - required_auth_keys.update(auth_list) - filtered_getters = { - k: v - for k, v in auth_token_getters.items() - if k in required_auth_keys - } - if filtered_getters: - tool_to_run = self.__core_tool.add_auth_token_getters( - filtered_getters - ) - + tool_to_run = self.__get_tool_to_run(config) return await to_thread(tool_to_run, **kwargs) def add_auth_token_getters( From f41566a5cac43cd632564f2ed5ba5b3c0a8eca5f Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Mon, 23 Jun 2025 23:23:36 +0530 Subject: [PATCH 4/4] chore: Make variable names more intuitive --- .../toolbox-langchain/src/toolbox_langchain/async_tools.py | 6 +++--- packages/toolbox-langchain/src/toolbox_langchain/tools.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py b/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py index 971ebd24..530f992a 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py @@ -81,13 +81,13 @@ async def _arun( # getters are used by the tool. To prevent validation errors, # filter the incoming getters to include only those that this # specific tool requires. - required_auth_keys = set(self.__core_tool._required_authz_tokens) + req_auth_services = set(self.__core_tool._required_authz_tokens) for auth_list in self.__core_tool._required_authn_params.values(): - required_auth_keys.update(auth_list) + req_auth_services.update(auth_list) filtered_getters = { k: v for k, v in auth_token_getters.items() - if k in required_auth_keys + if k in req_auth_services } if filtered_getters: tool_to_run = self.__core_tool.add_auth_token_getters( diff --git a/packages/toolbox-langchain/src/toolbox_langchain/tools.py b/packages/toolbox-langchain/src/toolbox_langchain/tools.py index c41dcca0..e300bd3c 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/tools.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/tools.py @@ -88,13 +88,13 @@ def __get_tool_to_run(self, config: RunnableConfig) -> ToolboxCoreSyncTool: # getters are used by the tool. To prevent validation errors, # filter the incoming getters to include only those that this # specific tool requires. - required_auth_keys = set(self.__core_tool._required_authz_tokens) + req_auth_services = set(self.__core_tool._required_authz_tokens) for auth_list in self.__core_tool._required_authn_params.values(): - required_auth_keys.update(auth_list) + req_auth_services.update(auth_list) filtered_getters = { k: v for k, v in auth_token_getters.items() - if k in required_auth_keys + if k in req_auth_services } if filtered_getters: tool_to_run = self.__core_tool.add_auth_token_getters(