Skip to content

feat!: manifest updates to remove authSource in favor of authService #63

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/toolbox_langchain/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ async def aload_tool(

Args:
tool_name: The name of the tool to load.
auth_tokens: An optional mapping of authentication source names to
auth_tokens: An optional mapping of authentication service names to
functions that retrieve ID tokens.
auth_headers: Deprecated. Use `auth_tokens` instead.
bound_params: An optional mapping of parameter names to their
Expand Down Expand Up @@ -107,7 +107,7 @@ async def aload_toolset(
Args:
toolset_name: The name of the toolset to load. If not provided,
all tools are loaded.
auth_tokens: An optional mapping of authentication source names to
auth_tokens: An optional mapping of authentication service names to
functions that retrieve ID tokens.
auth_headers: Deprecated. Use `auth_tokens` instead.
bound_params: An optional mapping of parameter names to their
Expand Down
40 changes: 20 additions & 20 deletions src/toolbox_langchain/async_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(
schema: The tool schema.
url: The base URL of the Toolbox service.
session: The HTTP client session.
auth_tokens: A mapping of authentication source names to functions
auth_tokens: A mapping of authentication service names to functions
that retrieve ID tokens.
bound_params: A mapping of parameter names to their bound
values.
Expand Down Expand Up @@ -157,7 +157,7 @@ async def _arun(self, **kwargs: Any) -> dict[str, Any]:

# If the tool had parameters that require authentication, then right
# before invoking that tool, we check whether all these required
# authentication sources have been registered or not.
# authentication services have been registered or not.
self.__validate_auth()

# Evaluate dynamic parameter values if any
Expand All @@ -182,36 +182,36 @@ def __validate_auth(self, strict: bool = True) -> None:
A tool is considered authenticated if all of its parameters meet at
least one of the following conditions:

* The parameter has at least one registered authentication source.
* The parameter has at least one registered authentication service.
* The parameter requires no authentication.

Args:
strict: If True, raises a PermissionError if any required
authentication sources are not registered. If False, only issues
authentication services are not registered. If False, only issues
a warning.

Raises:
PermissionError: If strict is True and any required authentication
sources are not registered.
services are not registered.
"""
params_missing_auth: list[str] = []

# Check each parameter for at least 1 required auth source
# Check each parameter for at least 1 required auth service
for param in self.__auth_params:
if not param.authSources:
raise ValueError("Auth sources cannot be None.")
if not param.authServices:
raise ValueError("Auth services cannot be None.")
has_auth = False
for src in param.authSources:
for src in param.authServices:

# Find first auth source that is specified
# Find first auth service that is specified
if src in self.__auth_tokens:
has_auth = True
break
if not has_auth:
params_missing_auth.append(param.name)

if params_missing_auth:
message = f"Parameter(s) `{', '.join(params_missing_auth)}` of tool {self.__name} require authentication, but no valid authentication sources are registered. Please register the required sources before use."
message = f"Parameter(s) `{', '.join(params_missing_auth)}` of tool {self.__name} require authentication, but no valid authentication services are registered. Please register the required services before use."

if strict:
raise PermissionError(message)
Expand All @@ -235,7 +235,7 @@ def __create_copy(
original instance, ensuring immutability.

Args:
auth_tokens: A dictionary of auth source names to functions that
auth_tokens: A dictionary of auth service names to functions that
retrieve ID tokens. These tokens will be merged with the
existing auth tokens.
bound_params: A dictionary of parameter names to their
Expand Down Expand Up @@ -273,10 +273,10 @@ def add_auth_tokens(
) -> "AsyncToolboxTool":
"""
Registers functions to retrieve ID tokens for the corresponding
authentication sources.
authentication services.

Args:
auth_tokens: A dictionary of authentication source names to the
auth_tokens: A dictionary of authentication service names to the
functions that return corresponding ID token.
strict: If True, a ValueError is raised if any of the provided auth
tokens are already bound. If False, only a warning is issued.
Expand All @@ -291,28 +291,28 @@ def add_auth_tokens(
is True.
"""

# Check if the authentication source is already registered.
# Check if the authentication service is already registered.
dupe_tokens: list[str] = []
for auth_token, _ in auth_tokens.items():
if auth_token in self.__auth_tokens:
dupe_tokens.append(auth_token)

if dupe_tokens:
raise ValueError(
f"Authentication source(s) `{', '.join(dupe_tokens)}` already registered in tool `{self.__name}`."
f"Authentication service(s) `{', '.join(dupe_tokens)}` already registered in tool `{self.__name}`."
)

return self.__create_copy(auth_tokens=auth_tokens, strict=strict)

def add_auth_token(
self, auth_source: str, get_id_token: Callable[[], str], strict: bool = True
self, auth_service: str, get_id_token: Callable[[], str], strict: bool = True
) -> "AsyncToolboxTool":
"""
Registers a function to retrieve an ID token for a given authentication
source.
service.

Args:
auth_source: The name of the authentication source.
auth_service: The name of the authentication service.
get_id_token: A function that returns the ID token.
strict: If True, a ValueError is raised if any of the provided auth
token is already bound. If False, only a warning is issued.
Expand All @@ -326,7 +326,7 @@ def add_auth_token(
ValueError: If the provided auth token is already bound and strict
is True.
"""
return self.add_auth_tokens({auth_source: get_id_token}, strict=strict)
return self.add_auth_tokens({auth_service: get_id_token}, strict=strict)

def bind_params(
self,
Expand Down
8 changes: 4 additions & 4 deletions src/toolbox_langchain/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ async def aload_tool(

Args:
tool_name: The name of the tool to load.
auth_tokens: An optional mapping of authentication source names to
auth_tokens: An optional mapping of authentication service names to
functions that retrieve ID tokens.
auth_headers: Deprecated. Use `auth_tokens` instead.
bound_params: An optional mapping of parameter names to their
Expand Down Expand Up @@ -135,7 +135,7 @@ async def aload_toolset(
Args:
toolset_name: The name of the toolset to load. If not provided,
all tools are loaded.
auth_tokens: An optional mapping of authentication source names to
auth_tokens: An optional mapping of authentication service names to
functions that retrieve ID tokens.
auth_headers: Deprecated. Use `auth_tokens` instead.
bound_params: An optional mapping of parameter names to their
Expand Down Expand Up @@ -174,7 +174,7 @@ def load_tool(

Args:
tool_name: The name of the tool to load.
auth_tokens: An optional mapping of authentication source names to
auth_tokens: An optional mapping of authentication service names to
functions that retrieve ID tokens.
auth_headers: Deprecated. Use `auth_tokens` instead.
bound_params: An optional mapping of parameter names to their
Expand Down Expand Up @@ -211,7 +211,7 @@ def load_toolset(
Args:
toolset_name: The name of the toolset to load. If not provided,
all tools are loaded.
auth_tokens: An optional mapping of authentication source names to
auth_tokens: An optional mapping of authentication service names to
functions that retrieve ID tokens.
auth_headers: Deprecated. Use `auth_tokens` instead.
bound_params: An optional mapping of parameter names to their
Expand Down
12 changes: 6 additions & 6 deletions src/toolbox_langchain/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,10 @@ def add_auth_tokens(
) -> "ToolboxTool":
"""
Registers functions to retrieve ID tokens for the corresponding
authentication sources.
authentication services.

Args:
auth_tokens: A dictionary of authentication source names to the
auth_tokens: A dictionary of authentication service names to the
functions that return corresponding ID token.
strict: If True, a ValueError is raised if any of the provided auth
tokens are already bound. If False, only a warning is issued.
Expand All @@ -112,14 +112,14 @@ def add_auth_tokens(
)

def add_auth_token(
self, auth_source: str, get_id_token: Callable[[], str], strict: bool = True
self, auth_service: str, get_id_token: Callable[[], str], strict: bool = True
) -> "ToolboxTool":
"""
Registers a function to retrieve an ID token for a given authentication
source.
service.

Args:
auth_source: The name of the authentication source.
auth_service: The name of the authentication service.
get_id_token: A function that returns the ID token.
strict: If True, a ValueError is raised if any of the provided auth
token is already bound. If False, only a warning is issued.
Expand All @@ -134,7 +134,7 @@ def add_auth_token(
is True.
"""
return ToolboxTool(
self.__async_tool.add_auth_token(auth_source, get_id_token, strict),
self.__async_tool.add_auth_token(auth_service, get_id_token, strict),
self.__loop,
self.__thread,
)
Expand Down
14 changes: 7 additions & 7 deletions src/toolbox_langchain/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class ParameterSchema(BaseModel):
name: str
type: str
description: str
authSources: Optional[list[str]] = None
authServices: Optional[list[str]] = None
items: Optional["ParameterSchema"] = None


Expand Down Expand Up @@ -149,19 +149,19 @@ def _get_auth_headers(id_token_getters: dict[str, Callable[[], str]]) -> dict[st

def _get_auth_tokens(id_token_getters: dict[str, Callable[[], str]]) -> dict[str, str]:
"""
Gets ID tokens for the given auth sources in the getters map and returns
Gets ID tokens for the given auth services in the getters map and returns
tokens to be included in tool invocation.

Args:
id_token_getters: A dict that maps auth source names to the functions
id_token_getters: A dict that maps auth service names to the functions
that return its ID token.

Returns:
A dictionary of tokens to be included in the tool invocation.
"""
auth_tokens = {}
for auth_source, get_id_token in id_token_getters.items():
auth_tokens[f"{auth_source}_token"] = get_id_token()
for auth_service, get_id_token in id_token_getters.items():
auth_tokens[f"{auth_service}_token"] = get_id_token()
return auth_tokens


Expand All @@ -180,7 +180,7 @@ async def _invoke_tool(
session: The HTTP client session.
tool_name: The name of the tool to invoke.
data: The input data for the tool.
id_token_getters: A dict that maps auth source names to the functions
id_token_getters: A dict that maps auth service names to the functions
that return its ID token.

Returns:
Expand Down Expand Up @@ -226,7 +226,7 @@ def _find_auth_params(
_non_auth_params: list[ParameterSchema] = []

for param in params:
if param.authSources:
if param.authServices:
_auth_params.append(param)
else:
_non_auth_params.append(param)
Expand Down
32 changes: 16 additions & 16 deletions tests/test_async_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def auth_tool_schema(self):
"name": "param1",
"type": "string",
"description": "Param 1",
"authSources": ["test-auth-source"],
"authServices": ["test-auth-service"],
},
{"name": "param2", "type": "integer", "description": "Param 2"},
],
Expand Down Expand Up @@ -154,17 +154,17 @@ async def test_toolbox_tool_bind_params_invalid_params(self, auth_toolbox_tool):
"auth_tokens, expected_auth_tokens",
[
(
{"test-auth-source": lambda: "test-token"},
{"test-auth-source": lambda: "test-token"},
{"test-auth-service": lambda: "test-token"},
{"test-auth-service": lambda: "test-token"},
),
(
{
"test-auth-source": lambda: "test-token",
"another-auth-source": lambda: "another-token",
"test-auth-service": lambda: "test-token",
"another-auth-service": lambda: "another-token",
},
{
"test-auth-source": lambda: "test-token",
"another-auth-source": lambda: "another-token",
"test-auth-service": lambda: "test-token",
"another-auth-service": lambda: "another-token",
},
),
],
Expand All @@ -173,17 +173,17 @@ async def test_toolbox_tool_add_auth_tokens(
self, auth_toolbox_tool, auth_tokens, expected_auth_tokens
):
tool = auth_toolbox_tool.add_auth_tokens(auth_tokens)
for source, getter in expected_auth_tokens.items():
assert tool._AsyncToolboxTool__auth_tokens[source]() == getter()
for service, getter in expected_auth_tokens.items():
assert tool._AsyncToolboxTool__auth_tokens[service]() == getter()

async def test_toolbox_tool_add_auth_tokens_duplicate(self, auth_toolbox_tool):
tool = auth_toolbox_tool.add_auth_tokens(
{"test-auth-source": lambda: "test-token"}
{"test-auth-service": lambda: "test-token"}
)
with pytest.raises(ValueError) as e:
tool = tool.add_auth_tokens({"test-auth-source": lambda: "test-token"})
tool = tool.add_auth_tokens({"test-auth-service": lambda: "test-token"})
assert (
"Authentication source(s) `test-auth-source` already registered in tool `test_tool`."
"Authentication service(s) `test-auth-service` already registered in tool `test_tool`."
in str(e.value)
)

Expand Down Expand Up @@ -224,14 +224,14 @@ async def test_toolbox_tool_call_with_bound_params(

async def test_toolbox_tool_call_with_auth_tokens(self, auth_toolbox_tool):
tool = auth_toolbox_tool.add_auth_tokens(
{"test-auth-source": lambda: "test-token"}
{"test-auth-service": lambda: "test-token"}
)
result = await tool.ainvoke({"param2": 123})
assert result == {"result": "test-result"}
auth_toolbox_tool._AsyncToolboxTool__session.post.assert_called_once_with(
"https://test-url/api/tool/test_tool/invoke",
json={"param2": 123},
headers={"test-auth-source_token": "test-token"},
headers={"test-auth-service_token": "test-token"},
)

async def test_toolbox_tool_call_with_auth_tokens_insecure(self, auth_toolbox_tool):
Expand All @@ -241,14 +241,14 @@ async def test_toolbox_tool_call_with_auth_tokens_insecure(self, auth_toolbox_to
):
auth_toolbox_tool._AsyncToolboxTool__url = "http://test-url"
tool = auth_toolbox_tool.add_auth_tokens(
{"test-auth-source": lambda: "test-token"}
{"test-auth-service": lambda: "test-token"}
)
result = await tool.ainvoke({"param2": 123})
assert result == {"result": "test-result"}
auth_toolbox_tool._AsyncToolboxTool__session.post.assert_called_once_with(
"http://test-url/api/tool/test_tool/invoke",
json={"param2": 123},
headers={"test-auth-source_token": "test-token"},
headers={"test-auth-service_token": "test-token"},
)

async def test_toolbox_tool_call_with_invalid_input(self, toolbox_tool):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ async def test_run_tool_param_auth_no_auth(self, toolbox):
tool = await toolbox.aload_tool("get-row-by-email-auth")
with pytest.raises(
PermissionError,
match="Parameter\(s\) `email` of tool get-row-by-email-auth require authentication\, but no valid authentication sources are registered\. Please register the required sources before use\.",
match="Parameter\(s\) `email` of tool get-row-by-email-auth require authentication\, but no valid authentication services are registered\. Please register the required services before use\.",
):
await tool.ainvoke({"email": ""})

Expand Down Expand Up @@ -287,7 +287,7 @@ def test_run_tool_param_auth_no_auth(self, toolbox):
tool = toolbox.load_tool("get-row-by-email-auth")
with pytest.raises(
PermissionError,
match="Parameter\(s\) `email` of tool get-row-by-email-auth require authentication\, but no valid authentication sources are registered\. Please register the required sources before use\.",
match="Parameter\(s\) `email` of tool get-row-by-email-auth require authentication\, but no valid authentication services are registered\. Please register the required services before use\.",
):
tool.invoke({"email": ""})

Expand Down
Loading