From d0d61f5df5f8563a14e55f73f0065ded3fe42999 Mon Sep 17 00:00:00 2001 From: Anubhav Dhawan Date: Mon, 21 Apr 2025 03:13:35 +0530 Subject: [PATCH] chore: Update auth_token(s) as auth_token_getter(s) and add_auth_token(s) as add_auth_token_getter(s) This is to align with the toolbox-core APIs and also because the suffix of _getter(s) make these APIs more descriptive and accurate. It actively encourages best practices regarding token security and lifecycle management (like refresh logic), which is crucial for production systems. --- packages/toolbox-langchain/README.md | 12 +-- .../src/toolbox_langchain/async_client.py | 72 ++++++++++++----- .../src/toolbox_langchain/async_tools.py | 40 +++++----- .../src/toolbox_langchain/client.py | 68 +++++++++++----- .../src/toolbox_langchain/tools.py | 14 ++-- .../tests/test_async_client.py | 4 +- .../tests/test_async_tools.py | 26 +++--- .../toolbox-langchain/tests/test_client.py | 80 ++++++++++++++----- packages/toolbox-langchain/tests/test_e2e.py | 24 +++--- .../toolbox-langchain/tests/test_tools.py | 48 ++++++----- 10 files changed, 249 insertions(+), 139 deletions(-) diff --git a/packages/toolbox-langchain/README.md b/packages/toolbox-langchain/README.md index 268f3c66..9f698694 100644 --- a/packages/toolbox-langchain/README.md +++ b/packages/toolbox-langchain/README.md @@ -225,21 +225,21 @@ async def get_auth_token(): toolbox = ToolboxClient("http://127.0.0.1:5000") tools = toolbox.load_toolset() -auth_tool = tools[0].add_auth_token("my_auth", get_auth_token) # Single token +auth_tool = tools[0].add_auth_token_getter("my_auth", get_auth_token) # Single token -multi_auth_tool = tools[0].add_auth_tokens({"my_auth", get_auth_token}) # Multiple tokens +multi_auth_tool = tools[0].add_auth_token_getters({"my_auth", get_auth_token}) # Multiple tokens # OR -auth_tools = [tool.add_auth_token("my_auth", get_auth_token) for tool in tools] +auth_tools = [tool.add_auth_token_getter("my_auth", get_auth_token) for tool in tools] ``` #### Add Authentication While Loading ```py -auth_tool = toolbox.load_tool(auth_tokens={"my_auth": get_auth_token}) +auth_tool = toolbox.load_tool(auth_token_getters={"my_auth": get_auth_token}) -auth_tools = toolbox.load_toolset(auth_tokens={"my_auth": get_auth_token}) +auth_tools = toolbox.load_toolset(auth_token_getters={"my_auth": get_auth_token}) ``` > [!NOTE] @@ -260,7 +260,7 @@ async def get_auth_token(): toolbox = ToolboxClient("http://127.0.0.1:5000") tool = toolbox.load_tool("my-tool") -auth_tool = tool.add_auth_token("my_auth", get_auth_token) +auth_tool = tool.add_auth_token_getter("my_auth", get_auth_token) result = auth_tool.invoke({"input": "some input"}) print(result) ``` diff --git a/packages/toolbox-langchain/src/toolbox_langchain/async_client.py b/packages/toolbox-langchain/src/toolbox_langchain/async_client.py index b65c8ccf..aacbc5af 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/async_client.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/async_client.py @@ -44,7 +44,8 @@ def __init__( async def aload_tool( self, tool_name: str, - auth_tokens: dict[str, Callable[[], str]] = {}, + auth_token_getters: dict[str, Callable[[], str]] = {}, + auth_tokens: Optional[dict[str, Callable[[], str]]] = None, auth_headers: Optional[dict[str, Callable[[], str]]] = None, bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, strict: bool = True, @@ -54,9 +55,10 @@ async def aload_tool( Args: tool_name: The name of the tool to load. - auth_tokens: An optional mapping of authentication source names to - functions that retrieve ID tokens. - auth_headers: Deprecated. Use `auth_tokens` instead. + auth_token_getters: An optional mapping of authentication source + names to functions that retrieve ID tokens. + auth_tokens: Deprecated. Use `auth_token_getters` instead. + auth_headers: Deprecated. Use `auth_token_getters` instead. bound_params: An optional mapping of parameter names to their bound values. strict: If True, raises a ValueError if any of the given bound @@ -67,17 +69,30 @@ async def aload_tool( A tool loaded from the Toolbox. """ if auth_headers: - if auth_tokens: + if auth_token_getters: warn( - "Both `auth_tokens` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_tokens` will be used.", + "Both `auth_token_getters` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_token_getters` will be used.", DeprecationWarning, ) else: warn( - "Argument `auth_headers` is deprecated. Use `auth_tokens` instead.", + "Argument `auth_headers` is deprecated. Use `auth_token_getters` instead.", DeprecationWarning, ) - auth_tokens = auth_headers + auth_token_getters = auth_headers + + if auth_tokens: + if auth_token_getters: + warn( + "Both `auth_token_getters` and `auth_tokens` are provided. `auth_tokens` is deprecated, and `auth_token_getters` will be used.", + DeprecationWarning, + ) + else: + warn( + "Argument `auth_tokens` is deprecated. Use `auth_token_getters` instead.", + DeprecationWarning, + ) + auth_token_getters = auth_tokens url = f"{self.__url}/api/tool/{tool_name}" manifest: ManifestSchema = await _load_manifest(url, self.__session) @@ -87,7 +102,7 @@ async def aload_tool( manifest.tools[tool_name], self.__url, self.__session, - auth_tokens, + auth_token_getters, bound_params, strict, ) @@ -95,7 +110,8 @@ async def aload_tool( async def aload_toolset( self, toolset_name: Optional[str] = None, - auth_tokens: dict[str, Callable[[], str]] = {}, + auth_token_getters: dict[str, Callable[[], str]] = {}, + auth_tokens: Optional[dict[str, Callable[[], str]]] = None, auth_headers: Optional[dict[str, Callable[[], str]]] = None, bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, strict: bool = True, @@ -107,9 +123,10 @@ async def aload_toolset( Args: toolset_name: The name of the toolset to load. If not provided, all tools are loaded. - auth_tokens: An optional mapping of authentication source names to - functions that retrieve ID tokens. - auth_headers: Deprecated. Use `auth_tokens` instead. + auth_token_getters: An optional mapping of authentication source + names to functions that retrieve ID tokens. + auth_tokens: Deprecated. Use `auth_token_getters` instead. + auth_headers: Deprecated. Use `auth_token_getters` instead. bound_params: An optional mapping of parameter names to their bound values. strict: If True, raises a ValueError if any of the given bound @@ -120,17 +137,30 @@ async def aload_toolset( A list of all tools loaded from the Toolbox. """ if auth_headers: - if auth_tokens: + if auth_token_getters: + warn( + "Both `auth_token_getters` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_token_getters` will be used.", + DeprecationWarning, + ) + else: + warn( + "Argument `auth_headers` is deprecated. Use `auth_token_getters` instead.", + DeprecationWarning, + ) + auth_token_getters = auth_headers + + if auth_tokens: + if auth_token_getters: warn( - "Both `auth_tokens` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_tokens` will be used.", + "Both `auth_token_getters` and `auth_tokens` are provided. `auth_tokens` is deprecated, and `auth_token_getters` will be used.", DeprecationWarning, ) else: warn( - "Argument `auth_headers` is deprecated. Use `auth_tokens` instead.", + "Argument `auth_tokens` is deprecated. Use `auth_token_getters` instead.", DeprecationWarning, ) - auth_tokens = auth_headers + auth_token_getters = auth_tokens url = f"{self.__url}/api/toolset/{toolset_name or ''}" manifest: ManifestSchema = await _load_manifest(url, self.__session) @@ -143,7 +173,7 @@ async def aload_toolset( tool_schema, self.__url, self.__session, - auth_tokens, + auth_token_getters, bound_params, strict, ) @@ -153,7 +183,8 @@ async def aload_toolset( def load_tool( self, tool_name: str, - auth_tokens: dict[str, Callable[[], str]] = {}, + auth_token_getters: dict[str, Callable[[], str]] = {}, + auth_tokens: Optional[dict[str, Callable[[], str]]] = None, auth_headers: Optional[dict[str, Callable[[], str]]] = None, bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, strict: bool = True, @@ -163,7 +194,8 @@ def load_tool( def load_toolset( self, toolset_name: Optional[str] = None, - auth_tokens: dict[str, Callable[[], str]] = {}, + auth_token_getters: dict[str, Callable[[], str]] = {}, + auth_tokens: Optional[dict[str, Callable[[], str]]] = None, auth_headers: Optional[dict[str, Callable[[], str]]] = None, bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, strict: bool = True, diff --git a/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py b/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py index c7aafc12..40e21ee6 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py @@ -45,7 +45,7 @@ def __init__( schema: ToolSchema, url: str, session: ClientSession, - auth_tokens: dict[str, Callable[[], str]] = {}, + auth_token_getters: dict[str, Callable[[], str]] = {}, bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, strict: bool = True, ) -> None: @@ -57,8 +57,8 @@ def __init__( schema: The tool schema. url: The base URL of the Toolbox service. session: The HTTP client session. - auth_tokens: A mapping of authentication source names to functions - that retrieve ID tokens. + auth_token_getters: A mapping of authentication source names to + functions that retrieve ID tokens. bound_params: A mapping of parameter names to their bound values. strict: If True, raises a ValueError if any of the given bound @@ -132,7 +132,7 @@ def __init__( self.__schema = schema self.__url = url self.__session = session - self.__auth_tokens = auth_tokens + self.__auth_token_getters = auth_token_getters self.__auth_params = auth_params self.__bound_params = bound_params @@ -172,7 +172,7 @@ async def _arun(self, **kwargs: Any) -> dict[str, Any]: kwargs.update(evaluated_params) return await _invoke_tool( - self.__url, self.__session, self.__name, kwargs, self.__auth_tokens + self.__url, self.__session, self.__name, kwargs, self.__auth_token_getters ) def __validate_auth(self, strict: bool = True) -> None: @@ -199,7 +199,7 @@ def __validate_auth(self, strict: bool = True) -> None: # Check tool for at least 1 required auth source for src in self.__schema.authRequired: - if src in self.__auth_tokens: + if src in self.__auth_token_getters: is_authenticated = True break @@ -211,7 +211,7 @@ def __validate_auth(self, strict: bool = True) -> None: for src in param.authSources: # Find first auth source that is specified - if src in self.__auth_tokens: + if src in self.__auth_token_getters: has_auth = True break if not has_auth: @@ -238,7 +238,7 @@ def __validate_auth(self, strict: bool = True) -> None: def __create_copy( self, *, - auth_tokens: dict[str, Callable[[], str]] = {}, + auth_token_getters: dict[str, Callable[[], str]] = {}, bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, strict: bool, ) -> "AsyncToolboxTool": @@ -253,8 +253,8 @@ def __create_copy( original instance, ensuring immutability. Args: - auth_tokens: A dictionary of auth source names to functions that - retrieve ID tokens. These tokens will be merged with the + auth_token_getters: A dictionary of auth source names to functions + that retrieve ID tokens. These tokens will be merged with the existing auth tokens. bound_params: A dictionary of parameter names to their bound values or functions to retrieve the values. These params @@ -281,21 +281,21 @@ def __create_copy( schema=new_schema, url=self.__url, session=self.__session, - auth_tokens={**self.__auth_tokens, **auth_tokens}, + auth_token_getters={**self.__auth_token_getters, **auth_token_getters}, bound_params={**self.__bound_params, **bound_params}, strict=strict, ) - def add_auth_tokens( - self, auth_tokens: dict[str, Callable[[], str]], strict: bool = True + def add_auth_token_getters( + self, auth_token_getters: dict[str, Callable[[], str]], strict: bool = True ) -> "AsyncToolboxTool": """ Registers functions to retrieve ID tokens for the corresponding authentication sources. Args: - auth_tokens: A dictionary of authentication source names to the - functions that return corresponding ID token. + auth_token_getters: A dictionary of authentication source names to + the functions that return corresponding ID token getters. strict: If True, a ValueError is raised if any of the provided auth parameters is already bound. If False, only a warning is issued. @@ -313,8 +313,8 @@ def add_auth_tokens( # Check if the authentication source is already registered. dupe_tokens: list[str] = [] - for auth_token, _ in auth_tokens.items(): - if auth_token in self.__auth_tokens: + for auth_token, _ in auth_token_getters.items(): + if auth_token in self.__auth_token_getters: dupe_tokens.append(auth_token) if dupe_tokens: @@ -322,9 +322,9 @@ def add_auth_tokens( f"Authentication source(s) `{', '.join(dupe_tokens)}` already registered in tool `{self.__name}`." ) - return self.__create_copy(auth_tokens=auth_tokens, strict=strict) + return self.__create_copy(auth_token_getters=auth_token_getters, strict=strict) - def add_auth_token( + def add_auth_token_getter( self, auth_source: str, get_id_token: Callable[[], str], strict: bool = True ) -> "AsyncToolboxTool": """ @@ -346,7 +346,7 @@ def add_auth_token( ValueError: If the provided auth parameter is already bound and strict is True. """ - return self.add_auth_tokens({auth_source: get_id_token}, strict=strict) + return self.add_auth_token_getters({auth_source: get_id_token}, strict=strict) def bind_params( self, diff --git a/packages/toolbox-langchain/src/toolbox_langchain/client.py b/packages/toolbox-langchain/src/toolbox_langchain/client.py index f30d5766..3c75779c 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/client.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/client.py @@ -88,7 +88,8 @@ async def __run_as_async(self, coro: Awaitable[T]) -> T: async def aload_tool( self, tool_name: str, - auth_tokens: dict[str, Callable[[], str]] = {}, + auth_token_getters: dict[str, Callable[[], str]] = {}, + auth_tokens: Optional[dict[str, Callable[[], str]]] = None, auth_headers: Optional[dict[str, Callable[[], str]]] = None, bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, strict: bool = True, @@ -98,9 +99,10 @@ async def aload_tool( Args: tool_name: The name of the tool to load. - auth_tokens: An optional mapping of authentication source names to - functions that retrieve ID tokens. - auth_headers: Deprecated. Use `auth_tokens` instead. + auth_token_getters: An optional mapping of authentication source + names to functions that retrieve ID tokens. + auth_tokens: Deprecated. Use `auth_token_getters` instead. + auth_headers: Deprecated. Use `auth_token_getters` instead. bound_params: An optional mapping of parameter names to their bound values. strict: If True, raises a ValueError if any of the given bound @@ -112,7 +114,12 @@ async def aload_tool( """ async_tool = await self.__run_as_async( self.__async_client.aload_tool( - tool_name, auth_tokens, auth_headers, bound_params, strict + tool_name, + auth_token_getters, + auth_tokens, + auth_headers, + bound_params, + strict, ) ) @@ -123,7 +130,8 @@ async def aload_tool( async def aload_toolset( self, toolset_name: Optional[str] = None, - auth_tokens: dict[str, Callable[[], str]] = {}, + auth_token_getters: dict[str, Callable[[], str]] = {}, + auth_tokens: Optional[dict[str, Callable[[], str]]] = None, auth_headers: Optional[dict[str, Callable[[], str]]] = None, bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, strict: bool = True, @@ -135,9 +143,10 @@ async def aload_toolset( Args: toolset_name: The name of the toolset to load. If not provided, all tools are loaded. - auth_tokens: An optional mapping of authentication source names to - functions that retrieve ID tokens. - auth_headers: Deprecated. Use `auth_tokens` instead. + auth_token_getters: An optional mapping of authentication source + names to functions that retrieve ID tokens. + auth_tokens: Deprecated. Use `auth_token_getters` instead. + auth_headers: Deprecated. Use `auth_token_getters` instead. bound_params: An optional mapping of parameter names to their bound values. strict: If True, raises a ValueError if any of the given bound @@ -149,7 +158,12 @@ async def aload_toolset( """ async_tools = await self.__run_as_async( self.__async_client.aload_toolset( - toolset_name, auth_tokens, auth_headers, bound_params, strict + toolset_name, + auth_token_getters, + auth_tokens, + auth_headers, + bound_params, + strict, ) ) @@ -164,7 +178,8 @@ async def aload_toolset( def load_tool( self, tool_name: str, - auth_tokens: dict[str, Callable[[], str]] = {}, + auth_token_getters: dict[str, Callable[[], str]] = {}, + auth_tokens: Optional[dict[str, Callable[[], str]]] = None, auth_headers: Optional[dict[str, Callable[[], str]]] = None, bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, strict: bool = True, @@ -174,9 +189,10 @@ def load_tool( Args: tool_name: The name of the tool to load. - auth_tokens: An optional mapping of authentication source names to - functions that retrieve ID tokens. - auth_headers: Deprecated. Use `auth_tokens` instead. + auth_token_getters: An optional mapping of authentication source + names to functions that retrieve ID tokens. + auth_tokens: Deprecated. Use `auth_token_getters` instead. + auth_headers: Deprecated. Use `auth_token_getters` instead. bound_params: An optional mapping of parameter names to their bound values. strict: If True, raises a ValueError if any of the given bound @@ -188,7 +204,12 @@ def load_tool( """ async_tool = self.__run_as_sync( self.__async_client.aload_tool( - tool_name, auth_tokens, auth_headers, bound_params, strict + tool_name, + auth_token_getters, + auth_tokens, + auth_headers, + bound_params, + strict, ) ) @@ -199,7 +220,8 @@ def load_tool( def load_toolset( self, toolset_name: Optional[str] = None, - auth_tokens: dict[str, Callable[[], str]] = {}, + auth_token_getters: dict[str, Callable[[], str]] = {}, + auth_tokens: Optional[dict[str, Callable[[], str]]] = None, auth_headers: Optional[dict[str, Callable[[], str]]] = None, bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, strict: bool = True, @@ -211,9 +233,10 @@ def load_toolset( Args: toolset_name: The name of the toolset to load. If not provided, all tools are loaded. - auth_tokens: An optional mapping of authentication source names to - functions that retrieve ID tokens. - auth_headers: Deprecated. Use `auth_tokens` instead. + auth_token_getters: An optional mapping of authentication source + names to functions that retrieve ID tokens. + auth_tokens: Deprecated. Use `auth_token_getters` instead. + auth_headers: Deprecated. Use `auth_token_getters` instead. bound_params: An optional mapping of parameter names to their bound values. strict: If True, raises a ValueError if any of the given bound @@ -225,7 +248,12 @@ def load_toolset( """ async_tools = self.__run_as_sync( self.__async_client.aload_toolset( - toolset_name, auth_tokens, auth_headers, bound_params, strict + toolset_name, + auth_token_getters, + auth_tokens, + auth_headers, + bound_params, + strict, ) ) diff --git a/packages/toolbox-langchain/src/toolbox_langchain/tools.py b/packages/toolbox-langchain/src/toolbox_langchain/tools.py index b7793b96..feb2a597 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/tools.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/tools.py @@ -83,16 +83,16 @@ def _run(self, **kwargs: Any) -> dict[str, Any]: async def _arun(self, **kwargs: Any) -> dict[str, Any]: return await self.__run_as_async(self.__async_tool._arun(**kwargs)) - def add_auth_tokens( - self, auth_tokens: dict[str, Callable[[], str]], strict: bool = True + def add_auth_token_getters( + self, auth_token_getters: dict[str, Callable[[], str]], strict: bool = True ) -> "ToolboxTool": """ Registers functions to retrieve ID tokens for the corresponding authentication sources. Args: - auth_tokens: A dictionary of authentication source names to the - functions that return corresponding ID token. + auth_token_getters: A dictionary of authentication source names to + the functions that return corresponding ID token. strict: If True, a ValueError is raised if any of the provided auth parameters is already bound. If False, only a warning is issued. @@ -107,12 +107,12 @@ def add_auth_tokens( and strict is True. """ return ToolboxTool( - self.__async_tool.add_auth_tokens(auth_tokens, strict), + self.__async_tool.add_auth_token_getters(auth_token_getters, strict), self.__loop, self.__thread, ) - def add_auth_token( + def add_auth_token_getter( self, auth_source: str, get_id_token: Callable[[], str], strict: bool = True ) -> "ToolboxTool": """ @@ -135,7 +135,7 @@ def add_auth_token( strict is True. """ return ToolboxTool( - self.__async_tool.add_auth_token(auth_source, get_id_token, strict), + self.__async_tool.add_auth_token_getter(auth_source, get_id_token, strict), self.__loop, self.__thread, ) diff --git a/packages/toolbox-langchain/tests/test_async_client.py b/packages/toolbox-langchain/tests/test_async_client.py index 2f31c7dc..25ad78eb 100644 --- a/packages/toolbox-langchain/tests/test_async_client.py +++ b/packages/toolbox-langchain/tests/test_async_client.py @@ -110,7 +110,7 @@ async def test_aload_tool_auth_headers_and_tokens( await mock_client.aload_tool( tool_name, auth_headers={"Authorization": lambda: "Bearer token"}, - auth_tokens={"test": lambda: "token"}, + auth_token_getters={"test": lambda: "token"}, ) assert len(w) == 1 assert issubclass(w[-1].category, DeprecationWarning) @@ -172,7 +172,7 @@ async def test_aload_toolset_auth_headers_and_tokens( simplefilter("always") await mock_client.aload_toolset( auth_headers={"Authorization": lambda: "Bearer token"}, - auth_tokens={"test": lambda: "token"}, + auth_token_getters={"test": lambda: "token"}, ) assert len(w) == 1 assert issubclass(w[-1].category, DeprecationWarning) diff --git a/packages/toolbox-langchain/tests/test_async_tools.py b/packages/toolbox-langchain/tests/test_async_tools.py index e7fa1bd0..e23aee85 100644 --- a/packages/toolbox-langchain/tests/test_async_tools.py +++ b/packages/toolbox-langchain/tests/test_async_tools.py @@ -151,7 +151,7 @@ async def test_toolbox_tool_bind_params_invalid_params(self, auth_toolbox_tool): ) @pytest.mark.parametrize( - "auth_tokens, expected_auth_tokens", + "auth_token_getters, expected_auth_token_getters", [ ( {"test-auth-source": lambda: "test-token"}, @@ -169,19 +169,23 @@ async def test_toolbox_tool_bind_params_invalid_params(self, auth_toolbox_tool): ), ], ) - async def test_toolbox_tool_add_auth_tokens( - self, auth_toolbox_tool, auth_tokens, expected_auth_tokens + async def test_toolbox_tool_add_auth_token_getters( + self, auth_toolbox_tool, auth_token_getters, expected_auth_token_getters ): - tool = auth_toolbox_tool.add_auth_tokens(auth_tokens) - for source, getter in expected_auth_tokens.items(): - assert tool._AsyncToolboxTool__auth_tokens[source]() == getter() + tool = auth_toolbox_tool.add_auth_token_getters(auth_token_getters) + for source, getter in expected_auth_token_getters.items(): + assert tool._AsyncToolboxTool__auth_token_getters[source]() == getter() - async def test_toolbox_tool_add_auth_tokens_duplicate(self, auth_toolbox_tool): - tool = auth_toolbox_tool.add_auth_tokens( + async def test_toolbox_tool_add_auth_token_getters_duplicate( + self, auth_toolbox_tool + ): + tool = auth_toolbox_tool.add_auth_token_getters( {"test-auth-source": lambda: "test-token"} ) with pytest.raises(ValueError) as e: - tool = tool.add_auth_tokens({"test-auth-source": lambda: "test-token"}) + tool = tool.add_auth_token_getters( + {"test-auth-source": lambda: "test-token"} + ) assert ( "Authentication source(s) `test-auth-source` already registered in tool `test_tool`." in str(e.value) @@ -223,7 +227,7 @@ async def test_toolbox_tool_call_with_bound_params( ) async def test_toolbox_tool_call_with_auth_tokens(self, auth_toolbox_tool): - tool = auth_toolbox_tool.add_auth_tokens( + tool = auth_toolbox_tool.add_auth_token_getters( {"test-auth-source": lambda: "test-token"} ) result = await tool.ainvoke({"param2": 123}) @@ -240,7 +244,7 @@ async def test_toolbox_tool_call_with_auth_tokens_insecure(self, auth_toolbox_to match="Sending ID token over HTTP. User data may be exposed. Use HTTPS for secure communication.", ): auth_toolbox_tool._AsyncToolboxTool__url = "http://test-url" - tool = auth_toolbox_tool.add_auth_tokens( + tool = auth_toolbox_tool.add_auth_token_getters( {"test-auth-source": lambda: "test-token"} ) result = await tool.ainvoke({"param2": 123}) diff --git a/packages/toolbox-langchain/tests/test_client.py b/packages/toolbox-langchain/tests/test_client.py index c9cd262a..62999019 100644 --- a/packages/toolbox-langchain/tests/test_client.py +++ b/packages/toolbox-langchain/tests/test_client.py @@ -49,7 +49,7 @@ def test_load_tool(self, mock_aload_tool, toolbox_client): assert tool.name == mock_tool.name assert tool.description == mock_tool.description assert tool.args_schema == mock_tool.args_schema - mock_aload_tool.assert_called_once_with("test_tool", {}, None, {}, True) + mock_aload_tool.assert_called_once_with("test_tool", {}, None, None, {}, True) @patch("toolbox_langchain.client.AsyncToolboxClient.aload_toolset") def test_load_toolset(self, mock_aload_toolset, toolbox_client): @@ -70,7 +70,7 @@ def test_load_toolset(self, mock_aload_toolset, toolbox_client): and a.args_schema == b.args_schema for a, b in zip(tools, mock_tools) ) - mock_aload_toolset.assert_called_once_with(None, {}, None, {}, True) + mock_aload_toolset.assert_called_once_with(None, {}, None, None, {}, True) @pytest.mark.asyncio @patch("toolbox_langchain.client.AsyncToolboxClient.aload_tool") @@ -85,7 +85,7 @@ async def test_aload_tool(self, mock_aload_tool, toolbox_client): assert tool.name == mock_tool.name assert tool.description == mock_tool.description assert tool.args_schema == mock_tool.args_schema - mock_aload_tool.assert_called_once_with("test_tool", {}, None, {}, True) + mock_aload_tool.assert_called_once_with("test_tool", {}, None, None, {}, True) @pytest.mark.asyncio @patch("toolbox_langchain.client.AsyncToolboxClient.aload_toolset") @@ -107,7 +107,7 @@ async def test_aload_toolset(self, mock_aload_toolset, toolbox_client): and a.args_schema == b.args_schema for a, b in zip(tools, mock_tools) ) - mock_aload_toolset.assert_called_once_with(None, {}, None, {}, True) + mock_aload_toolset.assert_called_once_with(None, {}, None, None, {}, True) @patch("toolbox_langchain.client.AsyncToolboxClient.aload_tool") def test_load_tool_with_args(self, mock_aload_tool, toolbox_client): @@ -116,12 +116,14 @@ def test_load_tool_with_args(self, mock_aload_tool, toolbox_client): mock_tool.description = "mock description" mock_tool.args_schema = BaseModel mock_aload_tool.return_value = mock_tool - auth_tokens = {"token1": lambda: "value1"} - auth_headers = {"header1": lambda: "value2"} - bound_params = {"param1": "value3"} + auth_token_getters = {"token_getter1": lambda: "value1"} + auth_tokens = {"token1": lambda: "value2"} + auth_headers = {"header1": lambda: "value3"} + bound_params = {"param1": "value4"} tool = toolbox_client.load_tool( "test_tool_name", + auth_token_getters=auth_token_getters, auth_tokens=auth_tokens, auth_headers=auth_headers, bound_params=bound_params, @@ -132,7 +134,12 @@ def test_load_tool_with_args(self, mock_aload_tool, toolbox_client): assert tool.description == mock_tool.description assert tool.args_schema == mock_tool.args_schema mock_aload_tool.assert_called_once_with( - "test_tool_name", auth_tokens, auth_headers, bound_params, False + "test_tool_name", + auth_token_getters, + auth_tokens, + auth_headers, + bound_params, + False, ) @patch("toolbox_langchain.client.AsyncToolboxClient.aload_toolset") @@ -146,12 +153,14 @@ def test_load_toolset_with_args(self, mock_aload_toolset, toolbox_client): mock_tools[1].args_schema = BaseModel mock_aload_toolset.return_value = mock_tools - auth_tokens = {"token1": lambda: "value1"} - auth_headers = {"header1": lambda: "value2"} - bound_params = {"param1": "value3"} + auth_token_getters = {"token_getter1": lambda: "value1"} + auth_tokens = {"token1": lambda: "value2"} + auth_headers = {"header1": lambda: "value3"} + bound_params = {"param1": "value4"} tools = toolbox_client.load_toolset( toolset_name="my_toolset", + auth_token_getters=auth_token_getters, auth_tokens=auth_tokens, auth_headers=auth_headers, bound_params=bound_params, @@ -166,7 +175,12 @@ def test_load_toolset_with_args(self, mock_aload_toolset, toolbox_client): for a, b in zip(tools, mock_tools) ) mock_aload_toolset.assert_called_once_with( - "my_toolset", auth_tokens, auth_headers, bound_params, False + "my_toolset", + auth_token_getters, + auth_tokens, + auth_headers, + bound_params, + False, ) @pytest.mark.asyncio @@ -178,18 +192,29 @@ async def test_aload_tool_with_args(self, mock_aload_tool, toolbox_client): mock_tool.args_schema = BaseModel mock_aload_tool.return_value = mock_tool - auth_tokens = {"token1": lambda: "value1"} - auth_headers = {"header1": lambda: "value2"} - bound_params = {"param1": "value3"} + auth_token_getters = {"token_getter1": lambda: "value1"} + auth_tokens = {"token1": lambda: "value2"} + auth_headers = {"header1": lambda: "value3"} + bound_params = {"param1": "value4"} tool = await toolbox_client.aload_tool( - "test_tool", auth_tokens, auth_headers, bound_params, False + "test_tool", + auth_token_getters, + auth_tokens, + auth_headers, + bound_params, + False, ) assert tool.name == mock_tool.name assert tool.description == mock_tool.description assert tool.args_schema == mock_tool.args_schema mock_aload_tool.assert_called_once_with( - "test_tool", auth_tokens, auth_headers, bound_params, False + "test_tool", + auth_token_getters, + auth_tokens, + auth_headers, + bound_params, + False, ) @pytest.mark.asyncio @@ -204,12 +229,18 @@ async def test_aload_toolset_with_args(self, mock_aload_toolset, toolbox_client) mock_tools[1].args_schema = BaseModel mock_aload_toolset.return_value = mock_tools - auth_tokens = {"token1": lambda: "value1"} - auth_headers = {"header1": lambda: "value2"} - bound_params = {"param1": "value3"} + auth_token_getters = {"token_getter1": lambda: "value1"} + auth_tokens = {"token1": lambda: "value2"} + auth_headers = {"header1": lambda: "value3"} + bound_params = {"param1": "value4"} tools = await toolbox_client.aload_toolset( - "my_toolset", auth_tokens, auth_headers, bound_params, False + "my_toolset", + auth_token_getters, + auth_tokens, + auth_headers, + bound_params, + False, ) assert len(tools) == len(mock_tools) assert all( @@ -219,5 +250,10 @@ async def test_aload_toolset_with_args(self, mock_aload_toolset, toolbox_client) for a, b in zip(tools, mock_tools) ) mock_aload_toolset.assert_called_once_with( - "my_toolset", auth_tokens, auth_headers, bound_params, False + "my_toolset", + auth_token_getters, + auth_tokens, + auth_headers, + bound_params, + False, ) diff --git a/packages/toolbox-langchain/tests/test_e2e.py b/packages/toolbox-langchain/tests/test_e2e.py index 8945657a..214ea305 100644 --- a/packages/toolbox-langchain/tests/test_e2e.py +++ b/packages/toolbox-langchain/tests/test_e2e.py @@ -115,7 +115,7 @@ async def test_run_tool_wrong_param_type(self, get_n_rows_tool): async def test_run_tool_unauth_with_auth(self, toolbox, auth_token2): """Tests running a tool that doesn't require auth, with auth provided.""" tool = await toolbox.aload_tool( - "get-row-by-id", auth_tokens={"my-test-auth": lambda: auth_token2} + "get-row-by-id", auth_token_getters={"my-test-auth": lambda: auth_token2} ) response = await tool.ainvoke({"id": "2"}) assert "row2" in response @@ -136,7 +136,7 @@ async def test_run_tool_wrong_auth(self, toolbox, auth_token2): tool = await toolbox.aload_tool( "get-row-by-id-auth", ) - auth_tool = tool.add_auth_token("my-test-auth", lambda: auth_token2) + auth_tool = tool.add_auth_token_getter("my-test-auth", lambda: auth_token2) with pytest.raises( ToolException, match="{'status': 'Unauthorized', 'error': 'tool invocation not authorized. Please make sure your specify correct auth headers'}", @@ -148,7 +148,7 @@ async def test_run_tool_auth(self, toolbox, auth_token1): tool = await toolbox.aload_tool( "get-row-by-id-auth", ) - auth_tool = tool.add_auth_token("my-test-auth", lambda: auth_token1) + auth_tool = tool.add_auth_token_getter("my-test-auth", lambda: auth_token1) response = await auth_tool.ainvoke({"id": "2"}) assert "row2" in response @@ -164,7 +164,8 @@ async def test_run_tool_param_auth_no_auth(self, toolbox): async def test_run_tool_param_auth(self, toolbox, auth_token1): """Tests running a tool with a param requiring auth, with correct auth.""" tool = await toolbox.aload_tool( - "get-row-by-email-auth", auth_tokens={"my-test-auth": lambda: auth_token1} + "get-row-by-email-auth", + auth_token_getters={"my-test-auth": lambda: auth_token1}, ) response = await tool.ainvoke({}) assert "row4" in response @@ -174,7 +175,8 @@ async def test_run_tool_param_auth(self, toolbox, auth_token1): async def test_run_tool_param_auth_no_field(self, toolbox, auth_token1): """Tests running a tool with a param requiring auth, with insufficient auth.""" tool = await toolbox.aload_tool( - "get-row-by-content-auth", auth_tokens={"my-test-auth": lambda: auth_token1} + "get-row-by-content-auth", + auth_token_getters={"my-test-auth": lambda: auth_token1}, ) with pytest.raises( ToolException, @@ -255,7 +257,7 @@ def test_run_tool_wrong_param_type(self, get_n_rows_tool): def test_run_tool_unauth_with_auth(self, toolbox, auth_token2): """Tests running a tool that doesn't require auth, with auth provided.""" tool = toolbox.load_tool( - "get-row-by-id", auth_tokens={"my-test-auth": lambda: auth_token2} + "get-row-by-id", auth_token_getters={"my-test-auth": lambda: auth_token2} ) response = tool.invoke({"id": "2"}) assert "row2" in response @@ -276,7 +278,7 @@ def test_run_tool_wrong_auth(self, toolbox, auth_token2): tool = toolbox.load_tool( "get-row-by-id-auth", ) - auth_tool = tool.add_auth_token("my-test-auth", lambda: auth_token2) + auth_tool = tool.add_auth_token_getter("my-test-auth", lambda: auth_token2) with pytest.raises( ToolException, match="{'status': 'Unauthorized', 'error': 'tool invocation not authorized. Please make sure your specify correct auth headers'}", @@ -288,7 +290,7 @@ def test_run_tool_auth(self, toolbox, auth_token1): tool = toolbox.load_tool( "get-row-by-id-auth", ) - auth_tool = tool.add_auth_token("my-test-auth", lambda: auth_token1) + auth_tool = tool.add_auth_token_getter("my-test-auth", lambda: auth_token1) response = auth_tool.invoke({"id": "2"}) assert "row2" in response @@ -304,7 +306,8 @@ def test_run_tool_param_auth_no_auth(self, toolbox): def test_run_tool_param_auth(self, toolbox, auth_token1): """Tests running a tool with a param requiring auth, with correct auth.""" tool = toolbox.load_tool( - "get-row-by-email-auth", auth_tokens={"my-test-auth": lambda: auth_token1} + "get-row-by-email-auth", + auth_token_getters={"my-test-auth": lambda: auth_token1}, ) response = tool.invoke({}) assert "row4" in response @@ -314,7 +317,8 @@ def test_run_tool_param_auth(self, toolbox, auth_token1): def test_run_tool_param_auth_no_field(self, toolbox, auth_token1): """Tests running a tool with a param requiring auth, with insufficient auth.""" tool = toolbox.load_tool( - "get-row-by-content-auth", auth_tokens={"my-test-auth": lambda: auth_token1} + "get-row-by-content-auth", + auth_token_getters={"my-test-auth": lambda: auth_token1}, ) with pytest.raises( ToolException, diff --git a/packages/toolbox-langchain/tests/test_tools.py b/packages/toolbox-langchain/tests/test_tools.py index 56a3714f..751005af 100644 --- a/packages/toolbox-langchain/tests/test_tools.py +++ b/packages/toolbox-langchain/tests/test_tools.py @@ -59,7 +59,7 @@ def mock_async_tool(self, tool_schema): mock_async_tool._AsyncToolboxTool__schema = tool_schema mock_async_tool._AsyncToolboxTool__url = "http://test_url" mock_async_tool._AsyncToolboxTool__session = Mock() - mock_async_tool._AsyncToolboxTool__auth_tokens = {} + mock_async_tool._AsyncToolboxTool__auth_token_getters = {} mock_async_tool._AsyncToolboxTool__bound_params = {} return mock_async_tool @@ -73,7 +73,7 @@ def mock_async_auth_tool(self, auth_tool_schema): mock_async_tool._AsyncToolboxTool__schema = auth_tool_schema mock_async_tool._AsyncToolboxTool__url = "http://test_url" mock_async_tool._AsyncToolboxTool__session = Mock() - mock_async_tool._AsyncToolboxTool__auth_tokens = {} + mock_async_tool._AsyncToolboxTool__auth_token_getters = {} mock_async_tool._AsyncToolboxTool__bound_params = {} return mock_async_tool @@ -155,7 +155,7 @@ def test_toolbox_tool_bind_param(self, mock_async_tool, toolbox_tool): assert isinstance(tool, ToolboxTool) @pytest.mark.parametrize( - "auth_tokens, expected_auth_tokens", + "auth_token_getters, expected_auth_token_getters", [ ( {"test-auth-source": lambda: "test-token"}, @@ -173,46 +173,52 @@ def test_toolbox_tool_bind_param(self, mock_async_tool, toolbox_tool): ), ], ) - def test_toolbox_tool_add_auth_tokens( + def test_toolbox_tool_add_auth_token_getters( self, - auth_tokens, - expected_auth_tokens, + auth_token_getters, + expected_auth_token_getters, mock_async_auth_tool, auth_toolbox_tool, ): - auth_toolbox_tool._ToolboxTool__async_tool._AsyncToolboxTool__auth_tokens = ( - expected_auth_tokens + auth_toolbox_tool._ToolboxTool__async_tool._AsyncToolboxTool__auth_token_getters = ( + expected_auth_token_getters ) - auth_toolbox_tool._ToolboxTool__async_tool.add_auth_tokens.return_value = ( + auth_toolbox_tool._ToolboxTool__async_tool.add_auth_token_getters.return_value = ( mock_async_auth_tool ) - tool = auth_toolbox_tool.add_auth_tokens(auth_tokens) - mock_async_auth_tool.add_auth_tokens.assert_called_once_with(auth_tokens, True) - for source, getter in expected_auth_tokens.items(): + tool = auth_toolbox_tool.add_auth_token_getters(auth_token_getters) + mock_async_auth_tool.add_auth_token_getters.assert_called_once_with( + auth_token_getters, True + ) + for source, getter in expected_auth_token_getters.items(): assert ( - tool._ToolboxTool__async_tool._AsyncToolboxTool__auth_tokens[source]() + tool._ToolboxTool__async_tool._AsyncToolboxTool__auth_token_getters[ + source + ]() == getter() ) assert isinstance(tool, ToolboxTool) - def test_toolbox_tool_add_auth_token(self, mock_async_auth_tool, auth_toolbox_tool): + def test_toolbox_tool_add_auth_token_getter( + self, mock_async_auth_tool, auth_toolbox_tool + ): get_id_token = lambda: "test-token" - expected_auth_tokens = {"test-auth-source": get_id_token} - auth_toolbox_tool._ToolboxTool__async_tool._AsyncToolboxTool__auth_tokens = ( - expected_auth_tokens + expected_auth_token_getters = {"test-auth-source": get_id_token} + auth_toolbox_tool._ToolboxTool__async_tool._AsyncToolboxTool__auth_token_getters = ( + expected_auth_token_getters ) - auth_toolbox_tool._ToolboxTool__async_tool.add_auth_token.return_value = ( + auth_toolbox_tool._ToolboxTool__async_tool.add_auth_token_getter.return_value = ( mock_async_auth_tool ) - tool = auth_toolbox_tool.add_auth_token("test-auth-source", get_id_token) - mock_async_auth_tool.add_auth_token.assert_called_once_with( + tool = auth_toolbox_tool.add_auth_token_getter("test-auth-source", get_id_token) + mock_async_auth_tool.add_auth_token_getter.assert_called_once_with( "test-auth-source", get_id_token, True ) assert ( - tool._ToolboxTool__async_tool._AsyncToolboxTool__auth_tokens[ + tool._ToolboxTool__async_tool._AsyncToolboxTool__auth_token_getters[ "test-auth-source" ]() == "test-token"