Skip to content

chore(toolbox-langchain): rename auth_tokens to auth_token_getters #182

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions packages/toolbox-langchain/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -225,21 +225,21 @@ async def get_auth_token():
toolbox = ToolboxClient("http://127.0.0.1:5000")
tools = toolbox.load_toolset()

auth_tool = tools[0].add_auth_token("my_auth", get_auth_token) # Single token
auth_tool = tools[0].add_auth_token_getter("my_auth", get_auth_token) # Single token

multi_auth_tool = tools[0].add_auth_tokens({"my_auth", get_auth_token}) # Multiple tokens
multi_auth_tool = tools[0].add_auth_token_getters({"my_auth", get_auth_token}) # Multiple tokens

# OR

auth_tools = [tool.add_auth_token("my_auth", get_auth_token) for tool in tools]
auth_tools = [tool.add_auth_token_getter("my_auth", get_auth_token) for tool in tools]
```

#### Add Authentication While Loading

```py
auth_tool = toolbox.load_tool(auth_tokens={"my_auth": get_auth_token})
auth_tool = toolbox.load_tool(auth_token_getters={"my_auth": get_auth_token})

auth_tools = toolbox.load_toolset(auth_tokens={"my_auth": get_auth_token})
auth_tools = toolbox.load_toolset(auth_token_getters={"my_auth": get_auth_token})
```

> [!NOTE]
Expand All @@ -260,7 +260,7 @@ async def get_auth_token():
toolbox = ToolboxClient("http://127.0.0.1:5000")
tool = toolbox.load_tool("my-tool")

auth_tool = tool.add_auth_token("my_auth", get_auth_token)
auth_tool = tool.add_auth_token_getter("my_auth", get_auth_token)
result = auth_tool.invoke({"input": "some input"})
print(result)
```
Expand Down
72 changes: 52 additions & 20 deletions packages/toolbox-langchain/src/toolbox_langchain/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ def __init__(
async def aload_tool(
self,
tool_name: str,
auth_tokens: dict[str, Callable[[], str]] = {},
auth_token_getters: dict[str, Callable[[], str]] = {},
auth_tokens: Optional[dict[str, Callable[[], str]]] = None,
auth_headers: Optional[dict[str, Callable[[], str]]] = None,
bound_params: dict[str, Union[Any, Callable[[], Any]]] = {},
strict: bool = True,
Expand All @@ -54,9 +55,10 @@ async def aload_tool(

Args:
tool_name: The name of the tool to load.
auth_tokens: An optional mapping of authentication source names to
functions that retrieve ID tokens.
auth_headers: Deprecated. Use `auth_tokens` instead.
auth_token_getters: An optional mapping of authentication source
names to functions that retrieve ID tokens.
auth_tokens: Deprecated. Use `auth_token_getters` instead.
auth_headers: Deprecated. Use `auth_token_getters` instead.
bound_params: An optional mapping of parameter names to their
bound values.
strict: If True, raises a ValueError if any of the given bound
Expand All @@ -67,17 +69,30 @@ async def aload_tool(
A tool loaded from the Toolbox.
"""
if auth_headers:
if auth_tokens:
if auth_token_getters:
warn(
"Both `auth_tokens` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_tokens` will be used.",
"Both `auth_token_getters` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_token_getters` will be used.",
DeprecationWarning,
)
else:
warn(
"Argument `auth_headers` is deprecated. Use `auth_tokens` instead.",
"Argument `auth_headers` is deprecated. Use `auth_token_getters` instead.",
DeprecationWarning,
)
auth_tokens = auth_headers
auth_token_getters = auth_headers

if auth_tokens:
if auth_token_getters:
warn(
"Both `auth_token_getters` and `auth_tokens` are provided. `auth_tokens` is deprecated, and `auth_token_getters` will be used.",
DeprecationWarning,
)
else:
warn(
"Argument `auth_tokens` is deprecated. Use `auth_token_getters` instead.",
DeprecationWarning,
)
auth_token_getters = auth_tokens

url = f"{self.__url}/api/tool/{tool_name}"
manifest: ManifestSchema = await _load_manifest(url, self.__session)
Expand All @@ -87,15 +102,16 @@ async def aload_tool(
manifest.tools[tool_name],
self.__url,
self.__session,
auth_tokens,
auth_token_getters,
bound_params,
strict,
)

async def aload_toolset(
self,
toolset_name: Optional[str] = None,
auth_tokens: dict[str, Callable[[], str]] = {},
auth_token_getters: dict[str, Callable[[], str]] = {},
auth_tokens: Optional[dict[str, Callable[[], str]]] = None,
auth_headers: Optional[dict[str, Callable[[], str]]] = None,
bound_params: dict[str, Union[Any, Callable[[], Any]]] = {},
strict: bool = True,
Expand All @@ -107,9 +123,10 @@ async def aload_toolset(
Args:
toolset_name: The name of the toolset to load. If not provided,
all tools are loaded.
auth_tokens: An optional mapping of authentication source names to
functions that retrieve ID tokens.
auth_headers: Deprecated. Use `auth_tokens` instead.
auth_token_getters: An optional mapping of authentication source
names to functions that retrieve ID tokens.
auth_tokens: Deprecated. Use `auth_token_getters` instead.
auth_headers: Deprecated. Use `auth_token_getters` instead.
bound_params: An optional mapping of parameter names to their
bound values.
strict: If True, raises a ValueError if any of the given bound
Expand All @@ -120,17 +137,30 @@ async def aload_toolset(
A list of all tools loaded from the Toolbox.
"""
if auth_headers:
if auth_tokens:
if auth_token_getters:
warn(
"Both `auth_token_getters` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_token_getters` will be used.",
DeprecationWarning,
)
else:
warn(
"Argument `auth_headers` is deprecated. Use `auth_token_getters` instead.",
DeprecationWarning,
)
auth_token_getters = auth_headers

if auth_tokens:
if auth_token_getters:
warn(
"Both `auth_tokens` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_tokens` will be used.",
"Both `auth_token_getters` and `auth_tokens` are provided. `auth_tokens` is deprecated, and `auth_token_getters` will be used.",
DeprecationWarning,
)
else:
warn(
"Argument `auth_headers` is deprecated. Use `auth_tokens` instead.",
"Argument `auth_tokens` is deprecated. Use `auth_token_getters` instead.",
DeprecationWarning,
)
auth_tokens = auth_headers
auth_token_getters = auth_tokens

url = f"{self.__url}/api/toolset/{toolset_name or ''}"
manifest: ManifestSchema = await _load_manifest(url, self.__session)
Expand All @@ -143,7 +173,7 @@ async def aload_toolset(
tool_schema,
self.__url,
self.__session,
auth_tokens,
auth_token_getters,
bound_params,
strict,
)
Expand All @@ -153,7 +183,8 @@ async def aload_toolset(
def load_tool(
self,
tool_name: str,
auth_tokens: dict[str, Callable[[], str]] = {},
auth_token_getters: dict[str, Callable[[], str]] = {},
auth_tokens: Optional[dict[str, Callable[[], str]]] = None,
auth_headers: Optional[dict[str, Callable[[], str]]] = None,
bound_params: dict[str, Union[Any, Callable[[], Any]]] = {},
strict: bool = True,
Expand All @@ -163,7 +194,8 @@ def load_tool(
def load_toolset(
self,
toolset_name: Optional[str] = None,
auth_tokens: dict[str, Callable[[], str]] = {},
auth_token_getters: dict[str, Callable[[], str]] = {},
auth_tokens: Optional[dict[str, Callable[[], str]]] = None,
auth_headers: Optional[dict[str, Callable[[], str]]] = None,
bound_params: dict[str, Union[Any, Callable[[], Any]]] = {},
strict: bool = True,
Expand Down
40 changes: 20 additions & 20 deletions packages/toolbox-langchain/src/toolbox_langchain/async_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __init__(
schema: ToolSchema,
url: str,
session: ClientSession,
auth_tokens: dict[str, Callable[[], str]] = {},
auth_token_getters: dict[str, Callable[[], str]] = {},
bound_params: dict[str, Union[Any, Callable[[], Any]]] = {},
strict: bool = True,
) -> None:
Expand All @@ -57,8 +57,8 @@ def __init__(
schema: The tool schema.
url: The base URL of the Toolbox service.
session: The HTTP client session.
auth_tokens: A mapping of authentication source names to functions
that retrieve ID tokens.
auth_token_getters: A mapping of authentication source names to
functions that retrieve ID tokens.
bound_params: A mapping of parameter names to their bound
values.
strict: If True, raises a ValueError if any of the given bound
Expand Down Expand Up @@ -132,7 +132,7 @@ def __init__(
self.__schema = schema
self.__url = url
self.__session = session
self.__auth_tokens = auth_tokens
self.__auth_token_getters = auth_token_getters
self.__auth_params = auth_params
self.__bound_params = bound_params

Expand Down Expand Up @@ -172,7 +172,7 @@ async def _arun(self, **kwargs: Any) -> dict[str, Any]:
kwargs.update(evaluated_params)

return await _invoke_tool(
self.__url, self.__session, self.__name, kwargs, self.__auth_tokens
self.__url, self.__session, self.__name, kwargs, self.__auth_token_getters
)

def __validate_auth(self, strict: bool = True) -> None:
Expand All @@ -199,7 +199,7 @@ def __validate_auth(self, strict: bool = True) -> None:

# Check tool for at least 1 required auth source
for src in self.__schema.authRequired:
if src in self.__auth_tokens:
if src in self.__auth_token_getters:
is_authenticated = True
break

Expand All @@ -211,7 +211,7 @@ def __validate_auth(self, strict: bool = True) -> None:
for src in param.authSources:

# Find first auth source that is specified
if src in self.__auth_tokens:
if src in self.__auth_token_getters:
has_auth = True
break
if not has_auth:
Expand All @@ -238,7 +238,7 @@ def __validate_auth(self, strict: bool = True) -> None:
def __create_copy(
self,
*,
auth_tokens: dict[str, Callable[[], str]] = {},
auth_token_getters: dict[str, Callable[[], str]] = {},
bound_params: dict[str, Union[Any, Callable[[], Any]]] = {},
strict: bool,
) -> "AsyncToolboxTool":
Expand All @@ -253,8 +253,8 @@ def __create_copy(
original instance, ensuring immutability.

Args:
auth_tokens: A dictionary of auth source names to functions that
retrieve ID tokens. These tokens will be merged with the
auth_token_getters: A dictionary of auth source names to functions
that retrieve ID tokens. These tokens will be merged with the
existing auth tokens.
bound_params: A dictionary of parameter names to their
bound values or functions to retrieve the values. These params
Expand All @@ -281,21 +281,21 @@ def __create_copy(
schema=new_schema,
url=self.__url,
session=self.__session,
auth_tokens={**self.__auth_tokens, **auth_tokens},
auth_token_getters={**self.__auth_token_getters, **auth_token_getters},
bound_params={**self.__bound_params, **bound_params},
strict=strict,
)

def add_auth_tokens(
self, auth_tokens: dict[str, Callable[[], str]], strict: bool = True
def add_auth_token_getters(
self, auth_token_getters: dict[str, Callable[[], str]], strict: bool = True
) -> "AsyncToolboxTool":
"""
Registers functions to retrieve ID tokens for the corresponding
authentication sources.

Args:
auth_tokens: A dictionary of authentication source names to the
functions that return corresponding ID token.
auth_token_getters: A dictionary of authentication source names to
the functions that return corresponding ID token getters.
strict: If True, a ValueError is raised if any of the provided auth
parameters is already bound. If False, only a warning is issued.

Expand All @@ -313,18 +313,18 @@ def add_auth_tokens(

# Check if the authentication source is already registered.
dupe_tokens: list[str] = []
for auth_token, _ in auth_tokens.items():
if auth_token in self.__auth_tokens:
for auth_token, _ in auth_token_getters.items():
if auth_token in self.__auth_token_getters:
dupe_tokens.append(auth_token)

if dupe_tokens:
raise ValueError(
f"Authentication source(s) `{', '.join(dupe_tokens)}` already registered in tool `{self.__name}`."
)

return self.__create_copy(auth_tokens=auth_tokens, strict=strict)
return self.__create_copy(auth_token_getters=auth_token_getters, strict=strict)

def add_auth_token(
def add_auth_token_getter(
self, auth_source: str, get_id_token: Callable[[], str], strict: bool = True
) -> "AsyncToolboxTool":
"""
Expand All @@ -346,7 +346,7 @@ def add_auth_token(
ValueError: If the provided auth parameter is already bound and
strict is True.
"""
return self.add_auth_tokens({auth_source: get_id_token}, strict=strict)
return self.add_auth_token_getters({auth_source: get_id_token}, strict=strict)

def bind_params(
self,
Expand Down
Loading