15
15
16
16
import inspect
17
17
import json
18
+ from unittest .mock import AsyncMock , Mock
18
19
19
20
import pytest
20
21
import pytest_asyncio
@@ -130,6 +131,60 @@ async def test_load_toolset_success(aioresponses, test_tool_str, test_tool_int_b
130
131
assert {t .__name__ for t in tools } == manifest .tools .keys ()
131
132
132
133
134
+ @pytest .mark .asyncio
135
+ async def test_invoke_tool_server_error (aioresponses , test_tool_str ):
136
+ """Tests that invoking a tool raises an Exception when the server returns an
137
+ error status."""
138
+ TOOL_NAME = "server_error_tool"
139
+ ERROR_MESSAGE = "Simulated Server Error"
140
+ manifest = ManifestSchema (serverVersion = "0.0.0" , tools = {TOOL_NAME : test_tool_str })
141
+
142
+ aioresponses .get (
143
+ f"{ TEST_BASE_URL } /api/tool/{ TOOL_NAME } " ,
144
+ payload = manifest .model_dump (),
145
+ status = 200 ,
146
+ )
147
+ aioresponses .post (
148
+ f"{ TEST_BASE_URL } /api/tool/{ TOOL_NAME } /invoke" ,
149
+ payload = {"error" : ERROR_MESSAGE },
150
+ status = 500 ,
151
+ )
152
+
153
+ async with ToolboxClient (TEST_BASE_URL ) as client :
154
+ loaded_tool = await client .load_tool (TOOL_NAME )
155
+
156
+ with pytest .raises (Exception , match = ERROR_MESSAGE ):
157
+ await loaded_tool (param1 = "some input" )
158
+
159
+
160
+ @pytest .mark .asyncio
161
+ async def test_load_tool_not_found_in_manifest (aioresponses , test_tool_str ):
162
+ """
163
+ Tests that load_tool raises an Exception when the requested tool name
164
+ is not found in the manifest returned by the server, using existing fixtures.
165
+ """
166
+ ACTUAL_TOOL_IN_MANIFEST = "actual_tool_abc"
167
+ REQUESTED_TOOL_NAME = "non_existent_tool_xyz"
168
+
169
+ manifest = ManifestSchema (
170
+ serverVersion = "0.0.0" , tools = {ACTUAL_TOOL_IN_MANIFEST : test_tool_str }
171
+ )
172
+
173
+ aioresponses .get (
174
+ f"{ TEST_BASE_URL } /api/tool/{ REQUESTED_TOOL_NAME } " ,
175
+ payload = manifest .model_dump (),
176
+ status = 200 ,
177
+ )
178
+
179
+ async with ToolboxClient (TEST_BASE_URL ) as client :
180
+ with pytest .raises (Exception , match = f"Tool '{ REQUESTED_TOOL_NAME } ' not found!" ):
181
+ await client .load_tool (REQUESTED_TOOL_NAME )
182
+
183
+ aioresponses .assert_called_once_with (
184
+ f"{ TEST_BASE_URL } /api/tool/{ REQUESTED_TOOL_NAME } " , method = "GET"
185
+ )
186
+
187
+
133
188
class TestAuth :
134
189
135
190
@pytest .fixture
@@ -182,7 +237,7 @@ def token_handler():
182
237
tool = await client .load_tool (
183
238
tool_name , auth_token_getters = {"my-auth-service" : token_handler }
184
239
)
185
- res = await tool (5 )
240
+ await tool (5 )
186
241
187
242
@pytest .mark .asyncio
188
243
async def test_auth_with_add_token_success (
@@ -195,20 +250,35 @@ def token_handler():
195
250
196
251
tool = await client .load_tool (tool_name )
197
252
tool = tool .add_auth_token_getters ({"my-auth-service" : token_handler })
198
- res = await tool (5 )
253
+ await tool (5 )
199
254
200
255
@pytest .mark .asyncio
201
256
async def test_auth_with_load_tool_fail_no_token (
202
257
self , tool_name , expected_header , client
203
258
):
204
259
"""Tests 'load_tool' with auth token is specified."""
205
260
206
- def token_handler ():
207
- return expected_header
208
-
209
261
tool = await client .load_tool (tool_name )
210
262
with pytest .raises (Exception ):
211
- res = await tool (5 )
263
+ await tool (5 )
264
+
265
+ @pytest .mark .asyncio
266
+ async def test_add_auth_token_getters_duplicate_fail (self , tool_name , client ):
267
+ """
268
+ Tests that adding a duplicate auth token getter raises ValueError.
269
+ """
270
+ AUTH_SERVICE = "my-auth-service"
271
+
272
+ tool = await client .load_tool (tool_name )
273
+
274
+ authed_tool = tool .add_auth_token_getters ({AUTH_SERVICE : {}})
275
+ assert AUTH_SERVICE in authed_tool ._ToolboxTool__auth_service_token_getters
276
+
277
+ with pytest .raises (
278
+ ValueError ,
279
+ match = f"Authentication source\\ (s\\ ) `{ AUTH_SERVICE } ` already registered in tool `{ tool_name } `." ,
280
+ ):
281
+ authed_tool .add_auth_token_getters ({AUTH_SERVICE : {}})
212
282
213
283
214
284
class TestBoundParameter :
@@ -283,6 +353,22 @@ async def test_bind_param_success(self, tool_name, client):
283
353
assert len (tool .__signature__ .parameters ) == 2
284
354
assert "argA" in tool .__signature__ .parameters
285
355
356
+ tool = tool .bind_parameters ({"argA" : 5 })
357
+
358
+ assert len (tool .__signature__ .parameters ) == 1
359
+ assert "argA" not in tool .__signature__ .parameters
360
+
361
+ res = await tool (True )
362
+ assert "argA" in res
363
+
364
+ @pytest .mark .asyncio
365
+ async def test_bind_callable_param_success (self , tool_name , client ):
366
+ """Tests 'bind_param' with a bound parameter specified."""
367
+ tool = await client .load_tool (tool_name )
368
+
369
+ assert len (tool .__signature__ .parameters ) == 2
370
+ assert "argA" in tool .__signature__ .parameters
371
+
286
372
tool = tool .bind_parameters ({"argA" : lambda : 5 })
287
373
288
374
assert len (tool .__signature__ .parameters ) == 1
@@ -301,3 +387,67 @@ async def test_bind_param_fail(self, tool_name, client):
301
387
302
388
with pytest .raises (Exception ):
303
389
tool = tool .bind_parameters ({"argC" : lambda : 5 })
390
+
391
+ @pytest .mark .asyncio
392
+ async def test_bind_param_static_value_success (self , tool_name , client ):
393
+ """
394
+ Tests bind_parameters method with a static value.
395
+ """
396
+
397
+ bound_value = "Test value"
398
+
399
+ tool = await client .load_tool (tool_name )
400
+ bound_tool = tool .bind_parameters ({"argB" : bound_value })
401
+
402
+ assert bound_tool is not tool
403
+ assert "argB" not in bound_tool .__signature__ .parameters
404
+ assert "argA" in bound_tool .__signature__ .parameters
405
+
406
+ passed_value_a = 42
407
+ res_payload = await bound_tool (argA = passed_value_a )
408
+
409
+ assert res_payload == {"argA" : passed_value_a , "argB" : bound_value }
410
+
411
+ @pytest .mark .asyncio
412
+ async def test_bind_param_sync_callable_value_success (self , tool_name , client ):
413
+ """
414
+ Tests bind_parameters method with a sync callable value.
415
+ """
416
+
417
+ bound_value_result = True
418
+ bound_sync_callable = Mock (return_value = bound_value_result )
419
+
420
+ tool = await client .load_tool (tool_name )
421
+ bound_tool = tool .bind_parameters ({"argB" : bound_sync_callable })
422
+
423
+ assert bound_tool is not tool
424
+ assert "argB" not in bound_tool .__signature__ .parameters
425
+ assert "argA" in bound_tool .__signature__ .parameters
426
+
427
+ passed_value_a = 42
428
+ res_payload = await bound_tool (argA = passed_value_a )
429
+
430
+ assert res_payload == {"argA" : passed_value_a , "argB" : bound_value_result }
431
+ bound_sync_callable .assert_called_once ()
432
+
433
+ @pytest .mark .asyncio
434
+ async def test_bind_param_async_callable_value_success (self , tool_name , client ):
435
+ """
436
+ Tests bind_parameters method with an async callable value.
437
+ """
438
+
439
+ bound_value_result = True
440
+ bound_async_callable = AsyncMock (return_value = bound_value_result )
441
+
442
+ tool = await client .load_tool (tool_name )
443
+ bound_tool = tool .bind_parameters ({"argB" : bound_async_callable })
444
+
445
+ assert bound_tool is not tool
446
+ assert "argB" not in bound_tool .__signature__ .parameters
447
+ assert "argA" in bound_tool .__signature__ .parameters
448
+
449
+ passed_value_a = 42
450
+ res_payload = await bound_tool (argA = passed_value_a )
451
+
452
+ assert res_payload == {"argA" : passed_value_a , "argB" : bound_value_result }
453
+ bound_async_callable .assert_awaited_once ()
0 commit comments