Skip to content

feat(toolbox-core): add authenticated parameters support #119

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Apr 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion packages/toolbox-core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ test = [
"isort==6.0.1",
"mypy==1.15.0",
"pytest==8.3.5",
"pytest-aioresponses==0.3.0"
"pytest-aioresponses==0.3.0",
"pytest-asyncio==0.25.3",
]
[build-system]
requires = ["setuptools"]
Expand Down
48 changes: 41 additions & 7 deletions packages/toolbox-core/src/toolbox_core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional
import re
import types
from typing import Any, Callable, Optional

from aiohttp import ClientSession

from .protocol import ManifestSchema, ToolSchema
from .tool import ToolboxTool
from .tool import ToolboxTool, identify_required_authn_params


class ToolboxClient:
Expand Down Expand Up @@ -53,14 +54,37 @@ def __init__(
session = ClientSession()
self.__session = session

def __parse_tool(self, name: str, schema: ToolSchema) -> ToolboxTool:
def __parse_tool(
self,
name: str,
schema: ToolSchema,
auth_token_getters: dict[str, Callable[[], str]],
) -> ToolboxTool:
"""Internal helper to create a callable tool from its schema."""
# sort into authenticated and reg params
params = []
authn_params: dict[str, list[str]] = {}
auth_sources: set[str] = set()
for p in schema.parameters:
if not p.authSources:
params.append(p)
else:
authn_params[p.name] = p.authSources
auth_sources.update(p.authSources)

authn_params = identify_required_authn_params(
authn_params, auth_token_getters.keys()
)

tool = ToolboxTool(
session=self.__session,
base_url=self.__base_url,
name=name,
desc=schema.description,
params=[p.to_param() for p in schema.parameters],
params=[p.to_param() for p in params],
# create a read-only values for the maps to prevent mutation
required_authn_params=types.MappingProxyType(authn_params),
auth_service_token_getters=types.MappingProxyType(auth_token_getters),
)
return tool

Expand Down Expand Up @@ -99,6 +123,7 @@ async def close(self):
async def load_tool(
self,
name: str,
auth_token_getters: dict[str, Callable[[], str]] = {},
) -> ToolboxTool:
"""
Asynchronously loads a tool from the server.
Expand All @@ -109,6 +134,8 @@ async def load_tool(

Args:
name: The unique name or identifier of the tool to load.
auth_token_getters: A mapping of authentication service names to
callables that return the corresponding authentication token.

Returns:
ToolboxTool: A callable object representing the loaded tool, ready
Expand All @@ -127,19 +154,23 @@ async def load_tool(
if name not in manifest.tools:
# TODO: Better exception
raise Exception(f"Tool '{name}' not found!")
tool = self.__parse_tool(name, manifest.tools[name])
tool = self.__parse_tool(name, manifest.tools[name], auth_token_getters)

return tool

async def load_toolset(
self,
name: str,
auth_token_getters: dict[str, Callable[[], str]] = {},
) -> list[ToolboxTool]:
"""
Asynchronously fetches a toolset and loads all tools defined within it.

Args:
name: Name of the toolset to load tools.
auth_token_getters: A mapping of authentication service names to
callables that return the corresponding authentication token.


Returns:
list[ToolboxTool]: A list of callables, one for each tool defined
Expand All @@ -152,5 +183,8 @@ async def load_toolset(
manifest: ManifestSchema = ManifestSchema(**json)

# parse each tools name and schema into a list of ToolboxTools
tools = [self.__parse_tool(n, s) for n, s in manifest.tools.items()]
tools = [
self.__parse_tool(n, s, auth_token_getters)
for n, s in manifest.tools.items()
]
return tools
171 changes: 157 additions & 14 deletions packages/toolbox-core/src/toolbox_core/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,13 @@
# limitations under the License.


import types
from collections import defaultdict
from inspect import Parameter, Signature
from typing import Any
from typing import Any, Callable, DefaultDict, Iterable, Mapping, Optional, Sequence

from aiohttp import ClientSession
from pytest import Session


class ToolboxTool:
Expand All @@ -32,20 +35,19 @@ class ToolboxTool:
and `inspect` work as expected.
"""

__url: str
__session: ClientSession
__signature__: Signature

def __init__(
self,
session: ClientSession,
base_url: str,
name: str,
desc: str,
params: list[Parameter],
params: Sequence[Parameter],
required_authn_params: Mapping[str, list[str]],
auth_service_token_getters: Mapping[str, Callable[[], str]],
):
"""
Initializes a callable that will trigger the tool invocation through the Toolbox server.
Initializes a callable that will trigger the tool invocation through the
Toolbox server.

Args:
session: The `aiohttp.ClientSession` used for making API requests.
Expand All @@ -54,19 +56,73 @@ def __init__(
desc: The description of the remote tool (used as its docstring).
params: A list of `inspect.Parameter` objects defining the tool's
arguments and their types/defaults.
required_authn_params: A dict of required authenticated parameters to a list
of services that provide values for them.
auth_service_token_getters: A dict of authService -> token (or callables that
produce a token)
"""

# used to invoke the toolbox API
self.__session = session
self.__session: ClientSession = session
self.__base_url: str = base_url
self.__url = f"{base_url}/api/tool/{name}/invoke"

# the following properties are set to help anyone that might inspect it determine
self.__desc = desc
self.__params = params

# the following properties are set to help anyone that might inspect it determine usage
self.__name__ = name
self.__doc__ = desc
self.__signature__ = Signature(parameters=params, return_annotation=str)
self.__annotations__ = {p.name: p.annotation for p in params}
# TODO: self.__qualname__ ??

# map of parameter name to auth service required by it
self.__required_authn_params = required_authn_params
# map of authService -> token_getter
self.__auth_service_token_getters = auth_service_token_getters

def __copy(
self,
session: Optional[ClientSession] = None,
base_url: Optional[str] = None,
name: Optional[str] = None,
desc: Optional[str] = None,
params: Optional[list[Parameter]] = None,
required_authn_params: Optional[Mapping[str, list[str]]] = None,
auth_service_token_getters: Optional[Mapping[str, Callable[[], str]]] = None,
) -> "ToolboxTool":
"""
Creates a copy of the ToolboxTool, overriding specific fields.

Args:
session: The `aiohttp.ClientSession` used for making API requests.
base_url: The base URL of the Toolbox server API.
name: The name of the remote tool.
desc: The description of the remote tool (used as its docstring).
params: A list of `inspect.Parameter` objects defining the tool's
arguments and their types/defaults.
required_authn_params: A dict of required authenticated parameters that need
a auth_service_token_getter set for them yet.
auth_service_token_getters: A dict of authService -> token (or callables
that produce a token)

"""
check = lambda val, default: val if val is not None else default
return ToolboxTool(
session=check(session, self.__session),
base_url=check(base_url, self.__base_url),
name=check(name, self.__name__),
desc=check(desc, self.__desc),
params=check(params, self.__params),
required_authn_params=check(
required_authn_params, self.__required_authn_params
),
auth_service_token_getters=check(
auth_service_token_getters, self.__auth_service_token_getters
),
)

async def __call__(self, *args: Any, **kwargs: Any) -> str:
"""
Asynchronously calls the remote tool with the provided arguments.
Expand All @@ -81,16 +137,103 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:
Returns:
The string result returned by the remote tool execution.
"""

# check if any auth services need to be specified yet
if len(self.__required_authn_params) > 0:
# Gather all the required auth services into a set
req_auth_services = set()
for s in self.__required_authn_params.values():
req_auth_services.update(s)
raise Exception(
f"One or more of the following authn services are required to invoke this tool: {','.join(req_auth_services)}"
)

# validate inputs to this call using the signature
all_args = self.__signature__.bind(*args, **kwargs)
all_args.apply_defaults() # Include default values if not provided
payload = all_args.arguments

# create headers for auth services
headers = {}
for auth_service, token_getter in self.__auth_service_token_getters.items():
headers[f"{auth_service}_token"] = token_getter()

async with self.__session.post(
self.__url,
json=payload,
headers=headers,
) as resp:
ret = await resp.json()
if "error" in ret:
# TODO: better error
raise Exception(ret["error"])
return ret.get("result", ret)
body = await resp.json()
if resp.status < 200 or resp.status >= 300:
err = body.get("error", f"unexpected status from server: {resp.status}")
raise Exception(err)
return body.get("result", body)

def add_auth_token_getters(
self,
auth_token_getters: Mapping[str, Callable[[], str]],
) -> "ToolboxTool":
"""
Registers an auth token getter function that is used for AuthServices when tools
are invoked.

Args:
auth_token_getters: A mapping of authentication service names to
callables that return the corresponding authentication token.

Returns:
A new ToolboxTool instance with the specified authentication token
getters registered.
"""

# throw an error if the authentication source is already registered
existing_services = self.__auth_service_token_getters.keys()
incoming_services = auth_token_getters.keys()
duplicates = existing_services & incoming_services
if duplicates:
raise ValueError(
f"Authentication source(s) `{', '.join(duplicates)}` already registered in tool `{self.__name__}`."
)

# create a read-only updated value for new_getters
new_getters = types.MappingProxyType(
dict(self.__auth_service_token_getters, **auth_token_getters)
)
# create a read-only updated for params that are still required
new_req_authn_params = types.MappingProxyType(
identify_required_authn_params(
self.__required_authn_params, auth_token_getters.keys()
)
)

return self.__copy(
auth_service_token_getters=new_getters,
required_authn_params=new_req_authn_params,
)


def identify_required_authn_params(
req_authn_params: Mapping[str, list[str]], auth_service_names: Iterable[str]
) -> dict[str, list[str]]:
"""
Identifies authentication parameters that are still required; or not covered by
the provided `auth_service_names`.

Args:
req_authn_params: A mapping of parameter names to sets of required
authentication services.
auth_service_names: An iterable of authentication service names for which
token getters are available.

Returns:
A new dictionary representing the subset of required authentication
parameters that are not covered by the provided `auth_service_names`.
"""
required_params = {} # params that are still required with provided auth_services
for param, services in req_authn_params.items():
# if we don't have a token_getter for any of the services required by the param,
# the param is still required
required = not any(s in services for s in auth_service_names)
if required:
required_params[param] = services
return required_params
Loading