Skip to content

chore(toolbox-core): Consolidate auth header creation logic #213

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions packages/toolbox-core/src/toolbox_core/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __init__(
# Validate conflicting Headers/Auth Tokens
request_header_names = client_headers.keys()
auth_token_names = [
auth_token_name + "_token"
self.__get_auth_header(auth_token_name)
for auth_token_name in auth_service_token_getters.keys()
]
duplicates = request_header_names & auth_token_names
Expand Down Expand Up @@ -187,6 +187,10 @@ def __copy(
client_headers=check(client_headers, self.__client_headers),
)

def __get_auth_header(self, auth_token_name: str) -> str:
"""Returns the formatted auth token header name."""
return f"{auth_token_name}_token"

async def __call__(self, *args: Any, **kwargs: Any) -> str:
"""
Asynchronously calls the remote tool with the provided arguments.
Expand Down Expand Up @@ -228,7 +232,9 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:
# create headers for auth services
headers = {}
for auth_service, token_getter in self.__auth_service_token_getters.items():
headers[f"{auth_service}_token"] = await resolve_value(token_getter)
headers[self.__get_auth_header(auth_service)] = await resolve_value(
token_getter
)
for client_header_name, client_header_val in self.__client_headers.items():
headers[client_header_name] = await resolve_value(client_header_val)

Expand Down Expand Up @@ -276,7 +282,8 @@ def add_auth_token_getters(
# Validate duplicates with client headers
request_header_names = self.__client_headers.keys()
auth_token_names = [
auth_token_name + "_token" for auth_token_name in incoming_services
self.__get_auth_header(auth_token_name)
for auth_token_name in incoming_services
]
duplicates = request_header_names & auth_token_names
if duplicates:
Expand Down
131 changes: 131 additions & 0 deletions packages/toolbox-core/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,137 @@ async def test_bind_param_async_callable_value_success(self, tool_name, client):
assert res_payload == {"argA": passed_value_a, "argB": bound_value_result}
bound_async_callable.assert_awaited_once()

@pytest.mark.asyncio
async def test_bind_param_success(self, tool_name, client):
"""Tests 'bind_param' with a bound parameter specified."""
tool = await client.load_tool(tool_name)

assert len(tool.__signature__.parameters) == 2
assert "argA" in tool.__signature__.parameters

tool = tool.bind_param("argA", 5)

assert len(tool.__signature__.parameters) == 1
assert "argA" not in tool.__signature__.parameters

res = await tool(True)
assert "argA" in res

@pytest.mark.asyncio
async def test_bind_callable_param_success(self, tool_name, client):
"""Tests 'bind_param' with a bound parameter specified."""
tool = await client.load_tool(tool_name)

assert len(tool.__signature__.parameters) == 2
assert "argA" in tool.__signature__.parameters

tool = tool.bind_param("argA", lambda: 5)

assert len(tool.__signature__.parameters) == 1
assert "argA" not in tool.__signature__.parameters

res = await tool(True)
assert "argA" in res

@pytest.mark.asyncio
async def test_bind_param_fail(self, tool_name, client):
"""Tests 'bind_param' with a bound parameter that doesn't exist."""
tool = await client.load_tool(tool_name)

assert len(tool.__signature__.parameters) == 2
assert "argA" in tool.__signature__.parameters

with pytest.raises(Exception) as e:
tool.bind_param("argC", lambda: 5)
assert "unable to bind parameters: no parameter named argC" in str(e.value)

@pytest.mark.asyncio
async def test_rebind_param_fail(self, tool_name, client):
"""
Tests that 'bind_param' fails when attempting to re-bind a
parameter that has already been bound.
"""
tool = await client.load_tool(tool_name)

assert len(tool.__signature__.parameters) == 2
assert "argA" in tool.__signature__.parameters

tool_with_bound_param = tool.bind_param("argA", lambda: 10)

assert len(tool_with_bound_param.__signature__.parameters) == 1
assert "argA" not in tool_with_bound_param.__signature__.parameters

with pytest.raises(ValueError) as e:
tool_with_bound_param.bind_param("argA", lambda: 20)

assert "cannot re-bind parameter: parameter 'argA' is already bound" in str(
e.value
)

@pytest.mark.asyncio
async def test_bind_param_static_value_success(self, tool_name, client):
"""
Tests bind_param method with a static value.
"""

bound_value = "Test value"

tool = await client.load_tool(tool_name)
bound_tool = tool.bind_param("argB", bound_value)

assert bound_tool is not tool
assert "argB" not in bound_tool.__signature__.parameters
assert "argA" in bound_tool.__signature__.parameters

passed_value_a = 42
res_payload = await bound_tool(argA=passed_value_a)

assert res_payload == {"argA": passed_value_a, "argB": bound_value}

@pytest.mark.asyncio
async def test_bind_param_sync_callable_value_success(self, tool_name, client):
"""
Tests bind_param method with a sync callable value.
"""

bound_value_result = True
bound_sync_callable = Mock(return_value=bound_value_result)

tool = await client.load_tool(tool_name)
bound_tool = tool.bind_param("argB", bound_sync_callable)

assert bound_tool is not tool
assert "argB" not in bound_tool.__signature__.parameters
assert "argA" in bound_tool.__signature__.parameters

passed_value_a = 42
res_payload = await bound_tool(argA=passed_value_a)

assert res_payload == {"argA": passed_value_a, "argB": bound_value_result}
bound_sync_callable.assert_called_once()

@pytest.mark.asyncio
async def test_bind_param_async_callable_value_success(self, tool_name, client):
"""
Tests bind_param method with an async callable value.
"""

bound_value_result = True
bound_async_callable = AsyncMock(return_value=bound_value_result)

tool = await client.load_tool(tool_name)
bound_tool = tool.bind_param("argB", bound_async_callable)

assert bound_tool is not tool
assert "argB" not in bound_tool.__signature__.parameters
assert "argA" in bound_tool.__signature__.parameters

passed_value_a = 42
res_payload = await bound_tool(argA=passed_value_a)

assert res_payload == {"argA": passed_value_a, "argB": bound_value_result}
bound_async_callable.assert_awaited_once()


class TestUnusedParameterValidation:
"""
Expand Down