Skip to content

feat: Enhance authorization token validation with authRequired #222

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions packages/toolbox-core/src/toolbox_core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from .protocol import ManifestSchema, ToolSchema
from .tool import ToolboxTool
from .utils import identify_required_authn_params, resolve_value
from .utils import identify_auth_requirements, resolve_value


class ToolboxClient:
Expand Down Expand Up @@ -79,10 +79,9 @@ def __parse_tool(
else: # regular parameter
params.append(p)

authn_params, _, used_auth_keys = identify_required_authn_params(
# TODO: Add schema.authRequired as second arg
authn_params, authz_tokens, used_auth_keys = identify_auth_requirements(
authn_params,
[],
schema.authRequired,
auth_token_getters.keys(),
)

Expand All @@ -94,6 +93,7 @@ def __parse_tool(
# create a read-only values to prevent mutation
params=tuple(params),
required_authn_params=types.MappingProxyType(authn_params),
required_authz_tokens=authz_tokens,
auth_service_token_getters=types.MappingProxyType(auth_token_getters),
bound_params=types.MappingProxyType(bound_params),
client_headers=types.MappingProxyType(client_headers),
Expand Down
68 changes: 42 additions & 26 deletions packages/toolbox-core/src/toolbox_core/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from .protocol import ParameterSchema
from .utils import (
create_func_docstring,
identify_required_authn_params,
identify_auth_requirements,
params_to_pydantic_model,
resolve_value,
)
Expand All @@ -49,6 +49,7 @@ def __init__(
description: str,
params: Sequence[ParameterSchema],
required_authn_params: Mapping[str, list[str]],
required_authz_tokens: Sequence[str],
auth_service_token_getters: Mapping[str, Callable[[], str]],
bound_params: Mapping[str, Union[Callable[[], Any], Any]],
client_headers: Mapping[str, Union[Callable, Coroutine, str]],
Expand All @@ -63,12 +64,14 @@ def __init__(
name: The name of the remote tool.
description: The description of the remote tool.
params: The args of the tool.
required_authn_params: A map of required authenticated parameters to a list
of alternative services that can provide values for them.
auth_service_token_getters: A dict of authService -> token (or callables that
produce a token)
bound_params: A mapping of parameter names to bind to specific values or
callables that are called to produce values as needed.
required_authn_params: A map of required authenticated parameters to
a list of alternative services that can provide values for them.
required_authz_tokens: A sequence of alternative services for
providing authorization token for the tool invocation.
auth_service_token_getters: A dict of authService -> token (or
callables that produce a token)
bound_params: A mapping of parameter names to bind to specific
values or callables that are called to produce values as needed.
client_headers: Client specific headers bound to the tool.
"""
# used to invoke the toolbox API
Expand Down Expand Up @@ -106,6 +109,8 @@ def __init__(

# map of parameter name to auth service required by it
self.__required_authn_params = required_authn_params
# sequence of authorization tokens required by it
self.__required_authz_tokens = required_authz_tokens
# map of authService -> token_getter
self.__auth_service_token_getters = auth_service_token_getters
# map of parameter name to value (or callable that produces that value)
Expand Down Expand Up @@ -149,6 +154,7 @@ def __copy(
description: Optional[str] = None,
params: Optional[Sequence[ParameterSchema]] = None,
required_authn_params: Optional[Mapping[str, list[str]]] = None,
required_authz_tokens: Optional[Sequence[str]] = None,
auth_service_token_getters: Optional[Mapping[str, Callable[[], str]]] = None,
bound_params: Optional[Mapping[str, Union[Callable[[], Any], Any]]] = None,
client_headers: Optional[Mapping[str, Union[Callable, Coroutine, str]]] = None,
Expand All @@ -162,12 +168,14 @@ def __copy(
name: The name of the remote tool.
description: The description of the remote tool.
params: The args of the tool.
required_authn_params: A map of required authenticated parameters to a list
of alternative services that can provide values for them.
auth_service_token_getters: A dict of authService -> token (or callables
that produce a token)
bound_params: A mapping of parameter names to bind to specific values or
callables that are called to produce values as needed.
required_authn_params: A map of required authenticated parameters to
a list of alternative services that can provide values for them.
required_authz_tokens: A sequence of alternative services for
providing authorization token for the tool invocation.
auth_service_token_getters: A dict of authService -> token (or
callables that produce a token)
bound_params: A mapping of parameter names to bind to specific
values or callables that are called to produce values as needed.
client_headers: Client specific headers bound to the tool.
"""
check = lambda val, default: val if val is not None else default
Expand All @@ -180,6 +188,9 @@ def __copy(
required_authn_params=check(
required_authn_params, self.__required_authn_params
),
required_authz_tokens=check(
required_authz_tokens, self.__required_authz_tokens
),
auth_service_token_getters=check(
auth_service_token_getters, self.__auth_service_token_getters
),
Expand Down Expand Up @@ -207,11 +218,15 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:
"""

# check if any auth services need to be specified yet
if len(self.__required_authn_params) > 0:
if (
len(self.__required_authn_params) > 0
or len(self.__required_authz_tokens) > 0
):
# Gather all the required auth services into a set
req_auth_services = set()
for s in self.__required_authn_params.values():
req_auth_services.update(s)
req_auth_services.update(self.__required_authz_tokens)
raise ValueError(
f"One or more of the following authn services are required to invoke this tool"
f": {','.join(req_auth_services)}"
Expand Down Expand Up @@ -292,23 +307,24 @@ def add_auth_token_getters(
f"Cannot register client the same headers in the client as well as tool."
)

# create a read-only updated value for new_getters
new_getters = MappingProxyType(
dict(self.__auth_service_token_getters, **auth_token_getters)
)
# create a read-only updated for params that are still required
new_req_authn_params = MappingProxyType(
identify_required_authn_params(
# TODO: Add authRequired
new_getters = dict(self.__auth_service_token_getters, **auth_token_getters)

# find the updated requirements
new_req_authn_params, new_req_authz_tokens, used_auth_token_getters = (
identify_auth_requirements(
self.__required_authn_params,
[],
self.__required_authz_tokens,
auth_token_getters.keys(),
)[0]
)
)

# TODO: Add validation for used_auth_token_getters

return self.__copy(
auth_service_token_getters=new_getters,
required_authn_params=new_req_authn_params,
# create a read-only map for updated getters, params and tokens that are still required
auth_service_token_getters=MappingProxyType(new_getters),
required_authn_params=MappingProxyType(new_req_authn_params),
required_authz_tokens=tuple(new_req_authz_tokens),
)

def bind_params(
Expand Down
6 changes: 3 additions & 3 deletions packages/toolbox-core/src/toolbox_core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ def create_func_docstring(description: str, params: Sequence[ParameterSchema]) -
return docstring


def identify_required_authn_params(
def identify_auth_requirements(
req_authn_params: Mapping[str, list[str]],
req_authz_tokens: list[str],
req_authz_tokens: Sequence[str],
auth_service_names: Iterable[str],
) -> tuple[dict[str, list[str]], list[str], set[str]]:
"""
Expand Down Expand Up @@ -100,7 +100,7 @@ def identify_required_authn_params(
if matched_authz_services:
used_services.update(matched_authz_services)
else:
required_authz_tokens = req_authz_tokens
required_authz_tokens = list(req_authz_tokens)

return required_authn_params, required_authz_tokens, used_services

Expand Down
2 changes: 1 addition & 1 deletion packages/toolbox-core/tests/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ async def test_run_tool_no_auth(self, toolbox: ToolboxClient):
tool = await toolbox.load_tool("get-row-by-id-auth")
with pytest.raises(
Exception,
match="tool invocation not authorized. Please make sure your specify correct auth headers",
match="One or more of the following authn services are required to invoke this tool: my-test-auth",
):
await tool(id="2")

Expand Down
2 changes: 1 addition & 1 deletion packages/toolbox-core/tests/test_sync_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def test_run_tool_no_auth(self, toolbox: ToolboxSyncClient):
tool = toolbox.load_tool("get-row-by-id-auth")
with pytest.raises(
Exception,
match="tool invocation not authorized. Please make sure your specify correct auth headers",
match="One or more of the following authn services are required to invoke this tool: my-test-auth",
):
tool(id="2")

Expand Down
6 changes: 6 additions & 0 deletions packages/toolbox-core/tests/test_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ async def test_tool_creation_callable_and_run(
description=sample_tool_description,
params=sample_tool_params,
required_authn_params={},
required_authz_tokens=[],
auth_service_token_getters={},
bound_params={},
client_headers={},
Expand Down Expand Up @@ -250,6 +251,7 @@ async def test_tool_run_with_pydantic_validation_error(
description=sample_tool_description,
params=sample_tool_params,
required_authn_params={},
required_authz_tokens=[],
auth_service_token_getters={},
bound_params={},
client_headers={},
Expand Down Expand Up @@ -337,6 +339,7 @@ def test_tool_init_basic(http_session, sample_tool_params, sample_tool_descripti
description=sample_tool_description,
params=sample_tool_params,
required_authn_params={},
required_authz_tokens=[],
auth_service_token_getters={},
bound_params={},
client_headers={},
Expand All @@ -361,6 +364,7 @@ def test_tool_init_with_client_headers(
description=sample_tool_description,
params=sample_tool_params,
required_authn_params={},
required_authz_tokens=[],
auth_service_token_getters={},
bound_params={},
client_headers=static_client_header,
Expand Down Expand Up @@ -388,6 +392,7 @@ def test_tool_init_header_auth_conflict(
description=sample_tool_description,
params=sample_tool_auth_params,
required_authn_params={},
required_authz_tokens=[],
auth_service_token_getters=auth_getters,
bound_params={},
client_headers=conflicting_client_header,
Expand All @@ -410,6 +415,7 @@ def test_tool_add_auth_token_getters_conflict_with_existing_client_header(
description=sample_tool_description,
params=sample_tool_params,
required_authn_params={},
required_authz_tokens=[],
auth_service_token_getters={},
bound_params={},
client_headers={
Expand Down
Loading