Skip to content

feat: Add bind_param (singular) to align with other packages #210

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 23 additions & 5 deletions packages/toolbox-core/src/toolbox_core/sync_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,13 +159,31 @@ def bind_params(
"""
Binds parameters to values or callables that produce values.

Args:
bound_params: A mapping of parameter names to values or callables that
produce values.
Args:
bound_params: A mapping of parameter names to values or callables that
produce values.

Returns:
A new ToolboxSyncTool instance with the specified parameters bound.
Returns:
A new ToolboxSyncTool instance with the specified parameters bound.
"""

new_async_tool = self.__async_tool.bind_params(bound_params)
return ToolboxSyncTool(new_async_tool, self.__loop, self.__thread)

def bind_param(
self,
param_name: str,
param_value: Union[Callable[[], Any], Any],
) -> "ToolboxSyncTool":
"""
Binds a parameter to the value or callables that produce it.

Args:
param_name: The name of the bound parameter.
param_value: The value of the bound parameter, or a callable that
returns the value.

Returns:
A new ToolboxSyncTool instance with the specified parameter bound.
"""
return self.bind_params({param_name: param_value})
28 changes: 23 additions & 5 deletions packages/toolbox-core/src/toolbox_core/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,12 +307,12 @@ def bind_params(
"""
Binds parameters to values or callables that produce values.

Args:
bound_params: A mapping of parameter names to values or callables that
produce values.
Args:
bound_params: A mapping of parameter names to values or callables that
produce values.

Returns:
A new ToolboxTool instance with the specified parameters bound.
Returns:
A new ToolboxTool instance with the specified parameters bound.
"""
param_names = set(p.name for p in self.__params)
for name in bound_params.keys():
Expand All @@ -335,3 +335,21 @@ def bind_params(
params=new_params,
bound_params=MappingProxyType(all_bound_params),
)

def bind_param(
self,
param_name: str,
param_value: Union[Callable[[], Any], Any],
) -> "ToolboxTool":
"""
Binds a parameter to the value or callables that produce it.

Args:
param_name: The name of the bound parameter.
param_value: The value of the bound parameter, or a callable that
returns the value.

Returns:
A new ToolboxTool instance with the specified parameters bound.
"""
return self.bind_params({param_name: param_value})
149 changes: 140 additions & 9 deletions packages/toolbox-core/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,8 +435,8 @@ async def test_load_toolset_success(self, tool_name, client):
assert "argB" in res

@pytest.mark.asyncio
async def test_bind_param_success(self, tool_name, client):
"""Tests 'bind_param' with a bound parameter specified."""
async def test_bind_params_success(self, tool_name, client):
"""Tests 'bind_params' with a bound parameter specified."""
tool = await client.load_tool(tool_name)

assert len(tool.__signature__.parameters) == 2
Expand All @@ -451,8 +451,8 @@ async def test_bind_param_success(self, tool_name, client):
assert "argA" in res

@pytest.mark.asyncio
async def test_bind_callable_param_success(self, tool_name, client):
"""Tests 'bind_param' with a bound parameter specified."""
async def test_bind_callable_params_success(self, tool_name, client):
"""Tests 'bind_params' with a bound parameter specified."""
tool = await client.load_tool(tool_name)

assert len(tool.__signature__.parameters) == 2
Expand All @@ -467,7 +467,7 @@ async def test_bind_callable_param_success(self, tool_name, client):
assert "argA" in res

@pytest.mark.asyncio
async def test_bind_param_fail(self, tool_name, client):
async def test_bind_params_fail(self, tool_name, client):
"""Tests 'bind_params' with a bound parameter that doesn't exist."""
tool = await client.load_tool(tool_name)

Expand All @@ -479,7 +479,7 @@ async def test_bind_param_fail(self, tool_name, client):
assert "unable to bind parameters: no parameter named argC" in str(e.value)

@pytest.mark.asyncio
async def test_rebind_param_fail(self, tool_name, client):
async def test_rebind_params_fail(self, tool_name, client):
"""
Tests that 'bind_params' fails when attempting to re-bind a
parameter that has already been bound.
Expand All @@ -502,7 +502,7 @@ async def test_rebind_param_fail(self, tool_name, client):
)

@pytest.mark.asyncio
async def test_bind_param_static_value_success(self, tool_name, client):
async def test_bind_params_static_value_success(self, tool_name, client):
"""
Tests bind_params method with a static value.
"""
Expand All @@ -522,7 +522,7 @@ async def test_bind_param_static_value_success(self, tool_name, client):
assert res_payload == {"argA": passed_value_a, "argB": bound_value}

@pytest.mark.asyncio
async def test_bind_param_sync_callable_value_success(self, tool_name, client):
async def test_bind_params_sync_callable_value_success(self, tool_name, client):
"""
Tests bind_params method with a sync callable value.
"""
Expand All @@ -544,7 +544,7 @@ async def test_bind_param_sync_callable_value_success(self, tool_name, client):
bound_sync_callable.assert_called_once()

@pytest.mark.asyncio
async def test_bind_param_async_callable_value_success(self, tool_name, client):
async def test_bind_params_async_callable_value_success(self, tool_name, client):
"""
Tests bind_params method with an async callable value.
"""
Expand All @@ -565,6 +565,137 @@ async def test_bind_param_async_callable_value_success(self, tool_name, client):
assert res_payload == {"argA": passed_value_a, "argB": bound_value_result}
bound_async_callable.assert_awaited_once()

@pytest.mark.asyncio
async def test_bind_param_success(self, tool_name, client):
"""Tests 'bind_param' with a bound parameter specified."""
tool = await client.load_tool(tool_name)

assert len(tool.__signature__.parameters) == 2
assert "argA" in tool.__signature__.parameters

tool = tool.bind_param("argA", 5)

assert len(tool.__signature__.parameters) == 1
assert "argA" not in tool.__signature__.parameters

res = await tool(True)
assert "argA" in res

@pytest.mark.asyncio
async def test_bind_callable_param_success(self, tool_name, client):
"""Tests 'bind_param' with a bound parameter specified."""
tool = await client.load_tool(tool_name)

assert len(tool.__signature__.parameters) == 2
assert "argA" in tool.__signature__.parameters

tool = tool.bind_param("argA", lambda: 5)

assert len(tool.__signature__.parameters) == 1
assert "argA" not in tool.__signature__.parameters

res = await tool(True)
assert "argA" in res

@pytest.mark.asyncio
async def test_bind_param_fail(self, tool_name, client):
"""Tests 'bind_param' with a bound parameter that doesn't exist."""
tool = await client.load_tool(tool_name)

assert len(tool.__signature__.parameters) == 2
assert "argA" in tool.__signature__.parameters

with pytest.raises(Exception) as e:
tool.bind_param("argC", lambda: 5)
assert "unable to bind parameters: no parameter named argC" in str(e.value)

@pytest.mark.asyncio
async def test_rebind_param_fail(self, tool_name, client):
"""
Tests that 'bind_param' fails when attempting to re-bind a
parameter that has already been bound.
"""
tool = await client.load_tool(tool_name)

assert len(tool.__signature__.parameters) == 2
assert "argA" in tool.__signature__.parameters

tool_with_bound_param = tool.bind_param("argA", lambda: 10)

assert len(tool_with_bound_param.__signature__.parameters) == 1
assert "argA" not in tool_with_bound_param.__signature__.parameters

with pytest.raises(ValueError) as e:
tool_with_bound_param.bind_param("argA", lambda: 20)

assert "cannot re-bind parameter: parameter 'argA' is already bound" in str(
e.value
)

@pytest.mark.asyncio
async def test_bind_param_static_value_success(self, tool_name, client):
"""
Tests bind_param method with a static value.
"""

bound_value = "Test value"

tool = await client.load_tool(tool_name)
bound_tool = tool.bind_param("argB", bound_value)

assert bound_tool is not tool
assert "argB" not in bound_tool.__signature__.parameters
assert "argA" in bound_tool.__signature__.parameters

passed_value_a = 42
res_payload = await bound_tool(argA=passed_value_a)

assert res_payload == {"argA": passed_value_a, "argB": bound_value}

@pytest.mark.asyncio
async def test_bind_param_sync_callable_value_success(self, tool_name, client):
"""
Tests bind_param method with a sync callable value.
"""

bound_value_result = True
bound_sync_callable = Mock(return_value=bound_value_result)

tool = await client.load_tool(tool_name)
bound_tool = tool.bind_param("argB", bound_sync_callable)

assert bound_tool is not tool
assert "argB" not in bound_tool.__signature__.parameters
assert "argA" in bound_tool.__signature__.parameters

passed_value_a = 42
res_payload = await bound_tool(argA=passed_value_a)

assert res_payload == {"argA": passed_value_a, "argB": bound_value_result}
bound_sync_callable.assert_called_once()

@pytest.mark.asyncio
async def test_bind_param_async_callable_value_success(self, tool_name, client):
"""
Tests bind_param method with an async callable value.
"""

bound_value_result = True
bound_async_callable = AsyncMock(return_value=bound_value_result)

tool = await client.load_tool(tool_name)
bound_tool = tool.bind_param("argB", bound_async_callable)

assert bound_tool is not tool
assert "argB" not in bound_tool.__signature__.parameters
assert "argA" in bound_tool.__signature__.parameters

passed_value_a = 42
res_payload = await bound_tool(argA=passed_value_a)

assert res_payload == {"argA": passed_value_a, "argB": bound_value_result}
bound_async_callable.assert_awaited_once()


class TestUnusedParameterValidation:
"""
Expand Down