Skip to content

Commit a2810c1

Browse files
authored
feat: Add bind_param (singular) to align with other packages (#210)
* feat: Add bind_param (singular) to align with other packages * chore: Add unit test coverage. * chore: Delint * chore: Delint
1 parent a36ae42 commit a2810c1

File tree

3 files changed

+186
-19
lines changed

3 files changed

+186
-19
lines changed

packages/toolbox-core/src/toolbox_core/sync_tool.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -159,13 +159,31 @@ def bind_params(
159159
"""
160160
Binds parameters to values or callables that produce values.
161161
162-
Args:
163-
bound_params: A mapping of parameter names to values or callables that
164-
produce values.
162+
Args:
163+
bound_params: A mapping of parameter names to values or callables that
164+
produce values.
165165
166-
Returns:
167-
A new ToolboxSyncTool instance with the specified parameters bound.
166+
Returns:
167+
A new ToolboxSyncTool instance with the specified parameters bound.
168168
"""
169169

170170
new_async_tool = self.__async_tool.bind_params(bound_params)
171171
return ToolboxSyncTool(new_async_tool, self.__loop, self.__thread)
172+
173+
def bind_param(
174+
self,
175+
param_name: str,
176+
param_value: Union[Callable[[], Any], Any],
177+
) -> "ToolboxSyncTool":
178+
"""
179+
Binds a parameter to the value or callables that produce it.
180+
181+
Args:
182+
param_name: The name of the bound parameter.
183+
param_value: The value of the bound parameter, or a callable that
184+
returns the value.
185+
186+
Returns:
187+
A new ToolboxSyncTool instance with the specified parameter bound.
188+
"""
189+
return self.bind_params({param_name: param_value})

packages/toolbox-core/src/toolbox_core/tool.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -307,12 +307,12 @@ def bind_params(
307307
"""
308308
Binds parameters to values or callables that produce values.
309309
310-
Args:
311-
bound_params: A mapping of parameter names to values or callables that
312-
produce values.
310+
Args:
311+
bound_params: A mapping of parameter names to values or callables that
312+
produce values.
313313
314-
Returns:
315-
A new ToolboxTool instance with the specified parameters bound.
314+
Returns:
315+
A new ToolboxTool instance with the specified parameters bound.
316316
"""
317317
param_names = set(p.name for p in self.__params)
318318
for name in bound_params.keys():
@@ -335,3 +335,21 @@ def bind_params(
335335
params=new_params,
336336
bound_params=MappingProxyType(all_bound_params),
337337
)
338+
339+
def bind_param(
340+
self,
341+
param_name: str,
342+
param_value: Union[Callable[[], Any], Any],
343+
) -> "ToolboxTool":
344+
"""
345+
Binds a parameter to the value or callables that produce it.
346+
347+
Args:
348+
param_name: The name of the bound parameter.
349+
param_value: The value of the bound parameter, or a callable that
350+
returns the value.
351+
352+
Returns:
353+
A new ToolboxTool instance with the specified parameters bound.
354+
"""
355+
return self.bind_params({param_name: param_value})

packages/toolbox-core/tests/test_client.py

Lines changed: 140 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -435,8 +435,8 @@ async def test_load_toolset_success(self, tool_name, client):
435435
assert "argB" in res
436436

437437
@pytest.mark.asyncio
438-
async def test_bind_param_success(self, tool_name, client):
439-
"""Tests 'bind_param' with a bound parameter specified."""
438+
async def test_bind_params_success(self, tool_name, client):
439+
"""Tests 'bind_params' with a bound parameter specified."""
440440
tool = await client.load_tool(tool_name)
441441

442442
assert len(tool.__signature__.parameters) == 2
@@ -451,8 +451,8 @@ async def test_bind_param_success(self, tool_name, client):
451451
assert "argA" in res
452452

453453
@pytest.mark.asyncio
454-
async def test_bind_callable_param_success(self, tool_name, client):
455-
"""Tests 'bind_param' with a bound parameter specified."""
454+
async def test_bind_callable_params_success(self, tool_name, client):
455+
"""Tests 'bind_params' with a bound parameter specified."""
456456
tool = await client.load_tool(tool_name)
457457

458458
assert len(tool.__signature__.parameters) == 2
@@ -467,7 +467,7 @@ async def test_bind_callable_param_success(self, tool_name, client):
467467
assert "argA" in res
468468

469469
@pytest.mark.asyncio
470-
async def test_bind_param_fail(self, tool_name, client):
470+
async def test_bind_params_fail(self, tool_name, client):
471471
"""Tests 'bind_params' with a bound parameter that doesn't exist."""
472472
tool = await client.load_tool(tool_name)
473473

@@ -479,7 +479,7 @@ async def test_bind_param_fail(self, tool_name, client):
479479
assert "unable to bind parameters: no parameter named argC" in str(e.value)
480480

481481
@pytest.mark.asyncio
482-
async def test_rebind_param_fail(self, tool_name, client):
482+
async def test_rebind_params_fail(self, tool_name, client):
483483
"""
484484
Tests that 'bind_params' fails when attempting to re-bind a
485485
parameter that has already been bound.
@@ -502,7 +502,7 @@ async def test_rebind_param_fail(self, tool_name, client):
502502
)
503503

504504
@pytest.mark.asyncio
505-
async def test_bind_param_static_value_success(self, tool_name, client):
505+
async def test_bind_params_static_value_success(self, tool_name, client):
506506
"""
507507
Tests bind_params method with a static value.
508508
"""
@@ -522,7 +522,7 @@ async def test_bind_param_static_value_success(self, tool_name, client):
522522
assert res_payload == {"argA": passed_value_a, "argB": bound_value}
523523

524524
@pytest.mark.asyncio
525-
async def test_bind_param_sync_callable_value_success(self, tool_name, client):
525+
async def test_bind_params_sync_callable_value_success(self, tool_name, client):
526526
"""
527527
Tests bind_params method with a sync callable value.
528528
"""
@@ -544,7 +544,7 @@ async def test_bind_param_sync_callable_value_success(self, tool_name, client):
544544
bound_sync_callable.assert_called_once()
545545

546546
@pytest.mark.asyncio
547-
async def test_bind_param_async_callable_value_success(self, tool_name, client):
547+
async def test_bind_params_async_callable_value_success(self, tool_name, client):
548548
"""
549549
Tests bind_params method with an async callable value.
550550
"""
@@ -565,6 +565,137 @@ async def test_bind_param_async_callable_value_success(self, tool_name, client):
565565
assert res_payload == {"argA": passed_value_a, "argB": bound_value_result}
566566
bound_async_callable.assert_awaited_once()
567567

568+
@pytest.mark.asyncio
569+
async def test_bind_param_success(self, tool_name, client):
570+
"""Tests 'bind_param' with a bound parameter specified."""
571+
tool = await client.load_tool(tool_name)
572+
573+
assert len(tool.__signature__.parameters) == 2
574+
assert "argA" in tool.__signature__.parameters
575+
576+
tool = tool.bind_param("argA", 5)
577+
578+
assert len(tool.__signature__.parameters) == 1
579+
assert "argA" not in tool.__signature__.parameters
580+
581+
res = await tool(True)
582+
assert "argA" in res
583+
584+
@pytest.mark.asyncio
585+
async def test_bind_callable_param_success(self, tool_name, client):
586+
"""Tests 'bind_param' with a bound parameter specified."""
587+
tool = await client.load_tool(tool_name)
588+
589+
assert len(tool.__signature__.parameters) == 2
590+
assert "argA" in tool.__signature__.parameters
591+
592+
tool = tool.bind_param("argA", lambda: 5)
593+
594+
assert len(tool.__signature__.parameters) == 1
595+
assert "argA" not in tool.__signature__.parameters
596+
597+
res = await tool(True)
598+
assert "argA" in res
599+
600+
@pytest.mark.asyncio
601+
async def test_bind_param_fail(self, tool_name, client):
602+
"""Tests 'bind_param' with a bound parameter that doesn't exist."""
603+
tool = await client.load_tool(tool_name)
604+
605+
assert len(tool.__signature__.parameters) == 2
606+
assert "argA" in tool.__signature__.parameters
607+
608+
with pytest.raises(Exception) as e:
609+
tool.bind_param("argC", lambda: 5)
610+
assert "unable to bind parameters: no parameter named argC" in str(e.value)
611+
612+
@pytest.mark.asyncio
613+
async def test_rebind_param_fail(self, tool_name, client):
614+
"""
615+
Tests that 'bind_param' fails when attempting to re-bind a
616+
parameter that has already been bound.
617+
"""
618+
tool = await client.load_tool(tool_name)
619+
620+
assert len(tool.__signature__.parameters) == 2
621+
assert "argA" in tool.__signature__.parameters
622+
623+
tool_with_bound_param = tool.bind_param("argA", lambda: 10)
624+
625+
assert len(tool_with_bound_param.__signature__.parameters) == 1
626+
assert "argA" not in tool_with_bound_param.__signature__.parameters
627+
628+
with pytest.raises(ValueError) as e:
629+
tool_with_bound_param.bind_param("argA", lambda: 20)
630+
631+
assert "cannot re-bind parameter: parameter 'argA' is already bound" in str(
632+
e.value
633+
)
634+
635+
@pytest.mark.asyncio
636+
async def test_bind_param_static_value_success(self, tool_name, client):
637+
"""
638+
Tests bind_param method with a static value.
639+
"""
640+
641+
bound_value = "Test value"
642+
643+
tool = await client.load_tool(tool_name)
644+
bound_tool = tool.bind_param("argB", bound_value)
645+
646+
assert bound_tool is not tool
647+
assert "argB" not in bound_tool.__signature__.parameters
648+
assert "argA" in bound_tool.__signature__.parameters
649+
650+
passed_value_a = 42
651+
res_payload = await bound_tool(argA=passed_value_a)
652+
653+
assert res_payload == {"argA": passed_value_a, "argB": bound_value}
654+
655+
@pytest.mark.asyncio
656+
async def test_bind_param_sync_callable_value_success(self, tool_name, client):
657+
"""
658+
Tests bind_param method with a sync callable value.
659+
"""
660+
661+
bound_value_result = True
662+
bound_sync_callable = Mock(return_value=bound_value_result)
663+
664+
tool = await client.load_tool(tool_name)
665+
bound_tool = tool.bind_param("argB", bound_sync_callable)
666+
667+
assert bound_tool is not tool
668+
assert "argB" not in bound_tool.__signature__.parameters
669+
assert "argA" in bound_tool.__signature__.parameters
670+
671+
passed_value_a = 42
672+
res_payload = await bound_tool(argA=passed_value_a)
673+
674+
assert res_payload == {"argA": passed_value_a, "argB": bound_value_result}
675+
bound_sync_callable.assert_called_once()
676+
677+
@pytest.mark.asyncio
678+
async def test_bind_param_async_callable_value_success(self, tool_name, client):
679+
"""
680+
Tests bind_param method with an async callable value.
681+
"""
682+
683+
bound_value_result = True
684+
bound_async_callable = AsyncMock(return_value=bound_value_result)
685+
686+
tool = await client.load_tool(tool_name)
687+
bound_tool = tool.bind_param("argB", bound_async_callable)
688+
689+
assert bound_tool is not tool
690+
assert "argB" not in bound_tool.__signature__.parameters
691+
assert "argA" in bound_tool.__signature__.parameters
692+
693+
passed_value_a = 42
694+
res_payload = await bound_tool(argA=passed_value_a)
695+
696+
assert res_payload == {"argA": passed_value_a, "argB": bound_value_result}
697+
bound_async_callable.assert_awaited_once()
698+
568699

569700
class TestUnusedParameterValidation:
570701
"""

0 commit comments

Comments
 (0)