|
5 | 5 | from pydantic import BaseModel
|
6 | 6 |
|
7 | 7 | import dspy
|
8 |
| -from dspy.adapters.types.tool import Tool |
| 8 | +from dspy.adapters.types.tool import Tool, ToolCalls |
9 | 9 |
|
10 | 10 |
|
11 | 11 | # Test fixtures
|
@@ -375,3 +375,71 @@ def test_async_tool_call_in_sync_mode():
|
375 | 375 | with dspy.context(allow_tool_async_sync_conversion=True):
|
376 | 376 | result = tool(x=1, y="hello")
|
377 | 377 | assert result == "hello 1"
|
| 378 | + |
| 379 | + |
| 380 | +TOOL_CALL_TEST_CASES = [ |
| 381 | + ([], [{"type": "tool_calls", "tool_calls": []}]), |
| 382 | + ( |
| 383 | + [{"name": "search", "args": {"query": "hello"}}], |
| 384 | + [{ |
| 385 | + "type": "tool_calls", |
| 386 | + "tool_calls": [{ |
| 387 | + "type": "function", |
| 388 | + "function": {"name": "search", "arguments": {"query": "hello"}} |
| 389 | + }] |
| 390 | + }], |
| 391 | + ), |
| 392 | + ( |
| 393 | + [ |
| 394 | + {"name": "search", "args": {"query": "hello"}}, |
| 395 | + {"name": "translate", "args": {"text": "world", "lang": "fr"}} |
| 396 | + ], |
| 397 | + [{ |
| 398 | + "type": "tool_calls", |
| 399 | + "tool_calls": [ |
| 400 | + { |
| 401 | + "type": "function", |
| 402 | + "function": {"name": "search", "arguments": {"query": "hello"}} |
| 403 | + }, |
| 404 | + { |
| 405 | + "type": "function", |
| 406 | + "function": {"name": "translate", "arguments": {"text": "world", "lang": "fr"}} |
| 407 | + } |
| 408 | + ] |
| 409 | + }], |
| 410 | + ), |
| 411 | + ( |
| 412 | + [{"name": "get_time", "args": {}}], |
| 413 | + [{ |
| 414 | + "type": "tool_calls", |
| 415 | + "tool_calls": [{ |
| 416 | + "type": "function", |
| 417 | + "function": {"name": "get_time", "arguments": {}} |
| 418 | + }] |
| 419 | + }], |
| 420 | + ), |
| 421 | +] |
| 422 | + |
| 423 | + |
| 424 | +@pytest.mark.parametrize("tool_calls_data,expected", TOOL_CALL_TEST_CASES) |
| 425 | +def test_tool_calls_format_basic(tool_calls_data, expected): |
| 426 | + """Test ToolCalls.format with various basic scenarios.""" |
| 427 | + tool_calls_list = [ToolCalls.ToolCall(**data) for data in tool_calls_data] |
| 428 | + tool_calls = ToolCalls(tool_calls=tool_calls_list) |
| 429 | + result = tool_calls.format() |
| 430 | + |
| 431 | + assert result == expected |
| 432 | + |
| 433 | +def test_tool_calls_format_from_dict_list(): |
| 434 | + """Test format works with ToolCalls created from from_dict_list.""" |
| 435 | + tool_calls_dicts = [ |
| 436 | + {"name": "search", "args": {"query": "hello"}}, |
| 437 | + {"name": "translate", "args": {"text": "world", "lang": "fr"}} |
| 438 | + ] |
| 439 | + |
| 440 | + tool_calls = ToolCalls.from_dict_list(tool_calls_dicts) |
| 441 | + result = tool_calls.format() |
| 442 | + |
| 443 | + assert len(result[0]["tool_calls"]) == 2 |
| 444 | + assert result[0]["tool_calls"][0]["function"]["name"] == "search" |
| 445 | + assert result[0]["tool_calls"][1]["function"]["name"] == "translate" |
0 commit comments