From 37e0a9e1b3bfb53a438a8d708f3bb20200337cb2 Mon Sep 17 00:00:00 2001 From: Yuan Teoh Date: Wed, 19 Feb 2025 13:38:46 -0800 Subject: [PATCH] feat!: manifest updates to remove authSource in favor of authService --- src/toolbox_langchain/async_client.py | 4 +-- src/toolbox_langchain/async_tools.py | 40 +++++++++++++-------------- src/toolbox_langchain/client.py | 8 +++--- src/toolbox_langchain/tools.py | 12 ++++---- src/toolbox_langchain/utils.py | 14 +++++----- tests/test_async_tools.py | 32 ++++++++++----------- tests/test_e2e.py | 4 +-- tests/test_tools.py | 26 ++++++++--------- tests/test_utils.py | 2 +- 9 files changed, 71 insertions(+), 71 deletions(-) diff --git a/src/toolbox_langchain/async_client.py b/src/toolbox_langchain/async_client.py index b65c8ccf..4febae79 100644 --- a/src/toolbox_langchain/async_client.py +++ b/src/toolbox_langchain/async_client.py @@ -54,7 +54,7 @@ async def aload_tool( Args: tool_name: The name of the tool to load. - auth_tokens: An optional mapping of authentication source names to + auth_tokens: An optional mapping of authentication service names to functions that retrieve ID tokens. auth_headers: Deprecated. Use `auth_tokens` instead. bound_params: An optional mapping of parameter names to their @@ -107,7 +107,7 @@ async def aload_toolset( Args: toolset_name: The name of the toolset to load. If not provided, all tools are loaded. - auth_tokens: An optional mapping of authentication source names to + auth_tokens: An optional mapping of authentication service names to functions that retrieve ID tokens. auth_headers: Deprecated. Use `auth_tokens` instead. bound_params: An optional mapping of parameter names to their diff --git a/src/toolbox_langchain/async_tools.py b/src/toolbox_langchain/async_tools.py index 998580d4..f4f2fb31 100644 --- a/src/toolbox_langchain/async_tools.py +++ b/src/toolbox_langchain/async_tools.py @@ -57,7 +57,7 @@ def __init__( schema: The tool schema. url: The base URL of the Toolbox service. session: The HTTP client session. - auth_tokens: A mapping of authentication source names to functions + auth_tokens: A mapping of authentication service names to functions that retrieve ID tokens. bound_params: A mapping of parameter names to their bound values. @@ -157,7 +157,7 @@ async def _arun(self, **kwargs: Any) -> dict[str, Any]: # If the tool had parameters that require authentication, then right # before invoking that tool, we check whether all these required - # authentication sources have been registered or not. + # authentication services have been registered or not. self.__validate_auth() # Evaluate dynamic parameter values if any @@ -182,28 +182,28 @@ def __validate_auth(self, strict: bool = True) -> None: A tool is considered authenticated if all of its parameters meet at least one of the following conditions: - * The parameter has at least one registered authentication source. + * The parameter has at least one registered authentication service. * The parameter requires no authentication. Args: strict: If True, raises a PermissionError if any required - authentication sources are not registered. If False, only issues + authentication services are not registered. If False, only issues a warning. Raises: PermissionError: If strict is True and any required authentication - sources are not registered. + services are not registered. """ params_missing_auth: list[str] = [] - # Check each parameter for at least 1 required auth source + # Check each parameter for at least 1 required auth service for param in self.__auth_params: - if not param.authSources: - raise ValueError("Auth sources cannot be None.") + if not param.authServices: + raise ValueError("Auth services cannot be None.") has_auth = False - for src in param.authSources: + for src in param.authServices: - # Find first auth source that is specified + # Find first auth service that is specified if src in self.__auth_tokens: has_auth = True break @@ -211,7 +211,7 @@ def __validate_auth(self, strict: bool = True) -> None: params_missing_auth.append(param.name) if params_missing_auth: - message = f"Parameter(s) `{', '.join(params_missing_auth)}` of tool {self.__name} require authentication, but no valid authentication sources are registered. Please register the required sources before use." + message = f"Parameter(s) `{', '.join(params_missing_auth)}` of tool {self.__name} require authentication, but no valid authentication services are registered. Please register the required services before use." if strict: raise PermissionError(message) @@ -235,7 +235,7 @@ def __create_copy( original instance, ensuring immutability. Args: - auth_tokens: A dictionary of auth source names to functions that + auth_tokens: A dictionary of auth service names to functions that retrieve ID tokens. These tokens will be merged with the existing auth tokens. bound_params: A dictionary of parameter names to their @@ -273,10 +273,10 @@ def add_auth_tokens( ) -> "AsyncToolboxTool": """ Registers functions to retrieve ID tokens for the corresponding - authentication sources. + authentication services. Args: - auth_tokens: A dictionary of authentication source names to the + auth_tokens: A dictionary of authentication service names to the functions that return corresponding ID token. strict: If True, a ValueError is raised if any of the provided auth tokens are already bound. If False, only a warning is issued. @@ -291,7 +291,7 @@ def add_auth_tokens( is True. """ - # Check if the authentication source is already registered. + # Check if the authentication service is already registered. dupe_tokens: list[str] = [] for auth_token, _ in auth_tokens.items(): if auth_token in self.__auth_tokens: @@ -299,20 +299,20 @@ def add_auth_tokens( if dupe_tokens: raise ValueError( - f"Authentication source(s) `{', '.join(dupe_tokens)}` already registered in tool `{self.__name}`." + f"Authentication service(s) `{', '.join(dupe_tokens)}` already registered in tool `{self.__name}`." ) return self.__create_copy(auth_tokens=auth_tokens, strict=strict) def add_auth_token( - self, auth_source: str, get_id_token: Callable[[], str], strict: bool = True + self, auth_service: str, get_id_token: Callable[[], str], strict: bool = True ) -> "AsyncToolboxTool": """ Registers a function to retrieve an ID token for a given authentication - source. + service. Args: - auth_source: The name of the authentication source. + auth_service: The name of the authentication service. get_id_token: A function that returns the ID token. strict: If True, a ValueError is raised if any of the provided auth token is already bound. If False, only a warning is issued. @@ -326,7 +326,7 @@ def add_auth_token( ValueError: If the provided auth token is already bound and strict is True. """ - return self.add_auth_tokens({auth_source: get_id_token}, strict=strict) + return self.add_auth_tokens({auth_service: get_id_token}, strict=strict) def bind_params( self, diff --git a/src/toolbox_langchain/client.py b/src/toolbox_langchain/client.py index f30d5766..ce0b87fa 100644 --- a/src/toolbox_langchain/client.py +++ b/src/toolbox_langchain/client.py @@ -98,7 +98,7 @@ async def aload_tool( Args: tool_name: The name of the tool to load. - auth_tokens: An optional mapping of authentication source names to + auth_tokens: An optional mapping of authentication service names to functions that retrieve ID tokens. auth_headers: Deprecated. Use `auth_tokens` instead. bound_params: An optional mapping of parameter names to their @@ -135,7 +135,7 @@ async def aload_toolset( Args: toolset_name: The name of the toolset to load. If not provided, all tools are loaded. - auth_tokens: An optional mapping of authentication source names to + auth_tokens: An optional mapping of authentication service names to functions that retrieve ID tokens. auth_headers: Deprecated. Use `auth_tokens` instead. bound_params: An optional mapping of parameter names to their @@ -174,7 +174,7 @@ def load_tool( Args: tool_name: The name of the tool to load. - auth_tokens: An optional mapping of authentication source names to + auth_tokens: An optional mapping of authentication service names to functions that retrieve ID tokens. auth_headers: Deprecated. Use `auth_tokens` instead. bound_params: An optional mapping of parameter names to their @@ -211,7 +211,7 @@ def load_toolset( Args: toolset_name: The name of the toolset to load. If not provided, all tools are loaded. - auth_tokens: An optional mapping of authentication source names to + auth_tokens: An optional mapping of authentication service names to functions that retrieve ID tokens. auth_headers: Deprecated. Use `auth_tokens` instead. bound_params: An optional mapping of parameter names to their diff --git a/src/toolbox_langchain/tools.py b/src/toolbox_langchain/tools.py index f19b3d61..a57bdd53 100644 --- a/src/toolbox_langchain/tools.py +++ b/src/toolbox_langchain/tools.py @@ -88,10 +88,10 @@ def add_auth_tokens( ) -> "ToolboxTool": """ Registers functions to retrieve ID tokens for the corresponding - authentication sources. + authentication services. Args: - auth_tokens: A dictionary of authentication source names to the + auth_tokens: A dictionary of authentication service names to the functions that return corresponding ID token. strict: If True, a ValueError is raised if any of the provided auth tokens are already bound. If False, only a warning is issued. @@ -112,14 +112,14 @@ def add_auth_tokens( ) def add_auth_token( - self, auth_source: str, get_id_token: Callable[[], str], strict: bool = True + self, auth_service: str, get_id_token: Callable[[], str], strict: bool = True ) -> "ToolboxTool": """ Registers a function to retrieve an ID token for a given authentication - source. + service. Args: - auth_source: The name of the authentication source. + auth_service: The name of the authentication service. get_id_token: A function that returns the ID token. strict: If True, a ValueError is raised if any of the provided auth token is already bound. If False, only a warning is issued. @@ -134,7 +134,7 @@ def add_auth_token( is True. """ return ToolboxTool( - self.__async_tool.add_auth_token(auth_source, get_id_token, strict), + self.__async_tool.add_auth_token(auth_service, get_id_token, strict), self.__loop, self.__thread, ) diff --git a/src/toolbox_langchain/utils.py b/src/toolbox_langchain/utils.py index c63332f7..be7b7941 100644 --- a/src/toolbox_langchain/utils.py +++ b/src/toolbox_langchain/utils.py @@ -29,7 +29,7 @@ class ParameterSchema(BaseModel): name: str type: str description: str - authSources: Optional[list[str]] = None + authServices: Optional[list[str]] = None items: Optional["ParameterSchema"] = None @@ -149,19 +149,19 @@ def _get_auth_headers(id_token_getters: dict[str, Callable[[], str]]) -> dict[st def _get_auth_tokens(id_token_getters: dict[str, Callable[[], str]]) -> dict[str, str]: """ - Gets ID tokens for the given auth sources in the getters map and returns + Gets ID tokens for the given auth services in the getters map and returns tokens to be included in tool invocation. Args: - id_token_getters: A dict that maps auth source names to the functions + id_token_getters: A dict that maps auth service names to the functions that return its ID token. Returns: A dictionary of tokens to be included in the tool invocation. """ auth_tokens = {} - for auth_source, get_id_token in id_token_getters.items(): - auth_tokens[f"{auth_source}_token"] = get_id_token() + for auth_service, get_id_token in id_token_getters.items(): + auth_tokens[f"{auth_service}_token"] = get_id_token() return auth_tokens @@ -180,7 +180,7 @@ async def _invoke_tool( session: The HTTP client session. tool_name: The name of the tool to invoke. data: The input data for the tool. - id_token_getters: A dict that maps auth source names to the functions + id_token_getters: A dict that maps auth service names to the functions that return its ID token. Returns: @@ -226,7 +226,7 @@ def _find_auth_params( _non_auth_params: list[ParameterSchema] = [] for param in params: - if param.authSources: + if param.authServices: _auth_params.append(param) else: _non_auth_params.append(param) diff --git a/tests/test_async_tools.py b/tests/test_async_tools.py index 13fca7ee..0adc130d 100644 --- a/tests/test_async_tools.py +++ b/tests/test_async_tools.py @@ -42,7 +42,7 @@ def auth_tool_schema(self): "name": "param1", "type": "string", "description": "Param 1", - "authSources": ["test-auth-source"], + "authServices": ["test-auth-service"], }, {"name": "param2", "type": "integer", "description": "Param 2"}, ], @@ -154,17 +154,17 @@ async def test_toolbox_tool_bind_params_invalid_params(self, auth_toolbox_tool): "auth_tokens, expected_auth_tokens", [ ( - {"test-auth-source": lambda: "test-token"}, - {"test-auth-source": lambda: "test-token"}, + {"test-auth-service": lambda: "test-token"}, + {"test-auth-service": lambda: "test-token"}, ), ( { - "test-auth-source": lambda: "test-token", - "another-auth-source": lambda: "another-token", + "test-auth-service": lambda: "test-token", + "another-auth-service": lambda: "another-token", }, { - "test-auth-source": lambda: "test-token", - "another-auth-source": lambda: "another-token", + "test-auth-service": lambda: "test-token", + "another-auth-service": lambda: "another-token", }, ), ], @@ -173,17 +173,17 @@ async def test_toolbox_tool_add_auth_tokens( self, auth_toolbox_tool, auth_tokens, expected_auth_tokens ): tool = auth_toolbox_tool.add_auth_tokens(auth_tokens) - for source, getter in expected_auth_tokens.items(): - assert tool._AsyncToolboxTool__auth_tokens[source]() == getter() + for service, getter in expected_auth_tokens.items(): + assert tool._AsyncToolboxTool__auth_tokens[service]() == getter() async def test_toolbox_tool_add_auth_tokens_duplicate(self, auth_toolbox_tool): tool = auth_toolbox_tool.add_auth_tokens( - {"test-auth-source": lambda: "test-token"} + {"test-auth-service": lambda: "test-token"} ) with pytest.raises(ValueError) as e: - tool = tool.add_auth_tokens({"test-auth-source": lambda: "test-token"}) + tool = tool.add_auth_tokens({"test-auth-service": lambda: "test-token"}) assert ( - "Authentication source(s) `test-auth-source` already registered in tool `test_tool`." + "Authentication service(s) `test-auth-service` already registered in tool `test_tool`." in str(e.value) ) @@ -224,14 +224,14 @@ async def test_toolbox_tool_call_with_bound_params( async def test_toolbox_tool_call_with_auth_tokens(self, auth_toolbox_tool): tool = auth_toolbox_tool.add_auth_tokens( - {"test-auth-source": lambda: "test-token"} + {"test-auth-service": lambda: "test-token"} ) result = await tool.ainvoke({"param2": 123}) assert result == {"result": "test-result"} auth_toolbox_tool._AsyncToolboxTool__session.post.assert_called_once_with( "https://test-url/api/tool/test_tool/invoke", json={"param2": 123}, - headers={"test-auth-source_token": "test-token"}, + headers={"test-auth-service_token": "test-token"}, ) async def test_toolbox_tool_call_with_auth_tokens_insecure(self, auth_toolbox_tool): @@ -241,14 +241,14 @@ async def test_toolbox_tool_call_with_auth_tokens_insecure(self, auth_toolbox_to ): auth_toolbox_tool._AsyncToolboxTool__url = "http://test-url" tool = auth_toolbox_tool.add_auth_tokens( - {"test-auth-source": lambda: "test-token"} + {"test-auth-service": lambda: "test-token"} ) result = await tool.ainvoke({"param2": 123}) assert result == {"result": "test-result"} auth_toolbox_tool._AsyncToolboxTool__session.post.assert_called_once_with( "http://test-url/api/tool/test_tool/invoke", json={"param2": 123}, - headers={"test-auth-source_token": "test-token"}, + headers={"test-auth-service_token": "test-token"}, ) async def test_toolbox_tool_call_with_invalid_input(self, toolbox_tool): diff --git a/tests/test_e2e.py b/tests/test_e2e.py index 6eef88da..2cbbff84 100644 --- a/tests/test_e2e.py +++ b/tests/test_e2e.py @@ -153,7 +153,7 @@ async def test_run_tool_param_auth_no_auth(self, toolbox): tool = await toolbox.aload_tool("get-row-by-email-auth") with pytest.raises( PermissionError, - match="Parameter\(s\) `email` of tool get-row-by-email-auth require authentication\, but no valid authentication sources are registered\. Please register the required sources before use\.", + match="Parameter\(s\) `email` of tool get-row-by-email-auth require authentication\, but no valid authentication services are registered\. Please register the required services before use\.", ): await tool.ainvoke({"email": ""}) @@ -287,7 +287,7 @@ def test_run_tool_param_auth_no_auth(self, toolbox): tool = toolbox.load_tool("get-row-by-email-auth") with pytest.raises( PermissionError, - match="Parameter\(s\) `email` of tool get-row-by-email-auth require authentication\, but no valid authentication sources are registered\. Please register the required sources before use\.", + match="Parameter\(s\) `email` of tool get-row-by-email-auth require authentication\, but no valid authentication services are registered\. Please register the required services before use\.", ): tool.invoke({"email": ""}) diff --git a/tests/test_tools.py b/tests/test_tools.py index a866f9be..38534143 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -43,7 +43,7 @@ def auth_tool_schema(self): "name": "param1", "type": "string", "description": "Param 1", - "authSources": ["test-auth-source"], + "authServices": ["test-auth-service"], }, {"name": "param2", "type": "integer", "description": "Param 2"}, ], @@ -158,17 +158,17 @@ def test_toolbox_tool_bind_param(self, mock_async_tool, toolbox_tool): "auth_tokens, expected_auth_tokens", [ ( - {"test-auth-source": lambda: "test-token"}, - {"test-auth-source": lambda: "test-token"}, + {"test-auth-service": lambda: "test-token"}, + {"test-auth-service": lambda: "test-token"}, ), ( { - "test-auth-source": lambda: "test-token", - "another-auth-source": lambda: "another-token", + "test-auth-service": lambda: "test-token", + "another-auth-service": lambda: "another-token", }, { - "test-auth-source": lambda: "test-token", - "another-auth-source": lambda: "another-token", + "test-auth-service": lambda: "test-token", + "another-auth-service": lambda: "another-token", }, ), ], @@ -189,16 +189,16 @@ def test_toolbox_tool_add_auth_tokens( tool = auth_toolbox_tool.add_auth_tokens(auth_tokens) mock_async_auth_tool.add_auth_tokens.assert_called_once_with(auth_tokens, True) - for source, getter in expected_auth_tokens.items(): + for service, getter in expected_auth_tokens.items(): assert ( - tool._ToolboxTool__async_tool._AsyncToolboxTool__auth_tokens[source]() + tool._ToolboxTool__async_tool._AsyncToolboxTool__auth_tokens[service]() == getter() ) assert isinstance(tool, ToolboxTool) def test_toolbox_tool_add_auth_token(self, mock_async_auth_tool, auth_toolbox_tool): get_id_token = lambda: "test-token" - expected_auth_tokens = {"test-auth-source": get_id_token} + expected_auth_tokens = {"test-auth-service": get_id_token} auth_toolbox_tool._ToolboxTool__async_tool._AsyncToolboxTool__auth_tokens = ( expected_auth_tokens ) @@ -206,14 +206,14 @@ def test_toolbox_tool_add_auth_token(self, mock_async_auth_tool, auth_toolbox_to mock_async_auth_tool ) - tool = auth_toolbox_tool.add_auth_token("test-auth-source", get_id_token) + tool = auth_toolbox_tool.add_auth_token("test-auth-service", get_id_token) mock_async_auth_tool.add_auth_token.assert_called_once_with( - "test-auth-source", get_id_token, True + "test-auth-service", get_id_token, True ) assert ( tool._ToolboxTool__async_tool._AsyncToolboxTool__auth_tokens[ - "test-auth-source" + "test-auth-service" ]() == "test-token" ) diff --git a/tests/test_utils.py b/tests/test_utils.py index 8e5139ed..d48c4525 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -288,4 +288,4 @@ def test_get_auth_headers_deprecation_warning(self): DeprecationWarning, match=r"Call to deprecated function \(or staticmethod\) _get_auth_headers\. \(Please use `_get_auth_tokens` instead\.\)$", ): - _get_auth_headers({"auth_source1": lambda: "test_token"}) + _get_auth_headers({"auth_service1": lambda: "test_token"})