Skip to content

Commit 4d19c3e

Browse files
authored
add format to ToolCalls (#8455)
1 parent 547aa3e commit 4d19c3e

File tree

2 files changed

+86
-1
lines changed

2 files changed

+86
-1
lines changed

dspy/adapters/types/tool.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,23 @@ def description(cls) -> str:
291291
"Arguments must be provided in JSON format."
292292
)
293293

294+
def format(self) -> list[dict[str, Any]]:
295+
# The tool_call field is compatible with OpenAI's tool calls schema.
296+
return [
297+
{
298+
"type": "tool_calls",
299+
"tool_calls": [
300+
{
301+
"type": "function",
302+
"function": {
303+
"name": tool_call.name,
304+
"arguments": tool_call.args,
305+
},
306+
} for tool_call in self.tool_calls
307+
],
308+
}
309+
]
310+
294311

295312
def _resolve_json_schema_reference(schema: dict) -> dict:
296313
"""Recursively resolve json model schema, expanding all references."""

tests/adapters/test_tool.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from pydantic import BaseModel
66

77
import dspy
8-
from dspy.adapters.types.tool import Tool
8+
from dspy.adapters.types.tool import Tool, ToolCalls
99

1010

1111
# Test fixtures
@@ -375,3 +375,71 @@ def test_async_tool_call_in_sync_mode():
375375
with dspy.context(allow_tool_async_sync_conversion=True):
376376
result = tool(x=1, y="hello")
377377
assert result == "hello 1"
378+
379+
380+
TOOL_CALL_TEST_CASES = [
381+
([], [{"type": "tool_calls", "tool_calls": []}]),
382+
(
383+
[{"name": "search", "args": {"query": "hello"}}],
384+
[{
385+
"type": "tool_calls",
386+
"tool_calls": [{
387+
"type": "function",
388+
"function": {"name": "search", "arguments": {"query": "hello"}}
389+
}]
390+
}],
391+
),
392+
(
393+
[
394+
{"name": "search", "args": {"query": "hello"}},
395+
{"name": "translate", "args": {"text": "world", "lang": "fr"}}
396+
],
397+
[{
398+
"type": "tool_calls",
399+
"tool_calls": [
400+
{
401+
"type": "function",
402+
"function": {"name": "search", "arguments": {"query": "hello"}}
403+
},
404+
{
405+
"type": "function",
406+
"function": {"name": "translate", "arguments": {"text": "world", "lang": "fr"}}
407+
}
408+
]
409+
}],
410+
),
411+
(
412+
[{"name": "get_time", "args": {}}],
413+
[{
414+
"type": "tool_calls",
415+
"tool_calls": [{
416+
"type": "function",
417+
"function": {"name": "get_time", "arguments": {}}
418+
}]
419+
}],
420+
),
421+
]
422+
423+
424+
@pytest.mark.parametrize("tool_calls_data,expected", TOOL_CALL_TEST_CASES)
425+
def test_tool_calls_format_basic(tool_calls_data, expected):
426+
"""Test ToolCalls.format with various basic scenarios."""
427+
tool_calls_list = [ToolCalls.ToolCall(**data) for data in tool_calls_data]
428+
tool_calls = ToolCalls(tool_calls=tool_calls_list)
429+
result = tool_calls.format()
430+
431+
assert result == expected
432+
433+
def test_tool_calls_format_from_dict_list():
434+
"""Test format works with ToolCalls created from from_dict_list."""
435+
tool_calls_dicts = [
436+
{"name": "search", "args": {"query": "hello"}},
437+
{"name": "translate", "args": {"text": "world", "lang": "fr"}}
438+
]
439+
440+
tool_calls = ToolCalls.from_dict_list(tool_calls_dicts)
441+
result = tool_calls.format()
442+
443+
assert len(result[0]["tool_calls"]) == 2
444+
assert result[0]["tool_calls"][0]["function"]["name"] == "search"
445+
assert result[0]["tool_calls"][1]["function"]["name"] == "translate"

0 commit comments

Comments
 (0)