Skip to content

Commit 99a07e9

Browse files
Support dspy.Tool as input field type and dspy.ToolCall as output field type (#8242)
* init
* Support tool and toolcall in input and output field types
* move file
* increment
* better ways
* return a list from lm for backward compatibility
* fix tests
* better name
* fix tests
* fix invalid import path
1 parent 0e68568 commit 99a07e9

19 files changed

+402
-55
lines changed

dspy/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from dspy.evaluate import Evaluate # isort: skip
1010
from dspy.clients import * # isort: skip
11-
from dspy.adapters import Adapter, ChatAdapter, JSONAdapter, TwoStepAdapter, Image, Audio, History, BaseType # isort: skip
11+
from dspy.adapters import Adapter, ChatAdapter, JSONAdapter, TwoStepAdapter, Image, Audio, History, BaseType, Tool, ToolCalls # isort: skip
1212
from dspy.utils.logging_utils import configure_dspy_loggers, disable_logging, enable_logging
1313
from dspy.utils.asyncify import asyncify
1414
from dspy.utils.saving import load

dspy/adapters/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from dspy.adapters.chat_adapter import ChatAdapter
33
from dspy.adapters.json_adapter import JSONAdapter
44
from dspy.adapters.two_step_adapter import TwoStepAdapter
5-
from dspy.adapters.types import History, Image, Audio, BaseType
5+
from dspy.adapters.types import History, Image, Audio, BaseType, Tool, ToolCalls
66

77
__all__ = [
88
"Adapter",
@@ -13,4 +13,6 @@
1313
"Audio",
1414
"JSONAdapter",
1515
"TwoStepAdapter",
16+
"Tool",
17+
"ToolCalls",
1618
]

dspy/adapters/base.py

Lines changed: 97 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,17 @@
1-
from typing import TYPE_CHECKING, Any, Optional, Type
1+
import logging
2+
from typing import TYPE_CHECKING, Any, Optional, Type, get_origin
3+
4+
import json_repair
5+
import litellm
26

37
from dspy.adapters.types import History
48
from dspy.adapters.types.base_type import split_message_content_for_custom_types
9+
from dspy.adapters.types.tool import Tool, ToolCalls
510
from dspy.signatures.signature import Signature
611
from dspy.utils.callback import BaseCallback, with_callbacks
712

13+
logger = logging.getLogger(__name__)
14+
815
if TYPE_CHECKING:
916
from dspy.clients.lm import LM
1017

@@ -20,18 +27,78 @@ def __init_subclass__(cls, **kwargs) -> None:
2027
cls.format = with_callbacks(cls.format)
2128
cls.parse = with_callbacks(cls.parse)
2229

23-
def _call_post_process(self, outputs: list[dict[str, Any]], signature: Type[Signature]) -> list[dict[str, Any]]:
30+
def _call_preprocess(
31+
self,
32+
lm: "LM",
33+
lm_kwargs: dict[str, Any],
34+
signature: Type[Signature],
35+
inputs: dict[str, Any],
36+
use_native_function_calling: bool = False,
37+
) -> dict[str, Any]:
38+
if use_native_function_calling:
39+
tool_call_input_field_name = self._get_tool_call_input_field_name(signature)
40+
tool_call_output_field_name = self._get_tool_call_output_field_name(signature)
41+
42+
if tool_call_output_field_name and tool_call_input_field_name is None:
43+
raise ValueError(
44+
f"You provided an output field {tool_call_output_field_name} to receive the tool calls information, "
45+
"but did not provide any tools as the input. Please provide a list of tools as the input by adding an "
46+
"input field with type `list[dspy.Tool]`."
47+
)
48+
49+
if tool_call_output_field_name and litellm.supports_function_calling(model=lm.model):
50+
tools = inputs[tool_call_input_field_name]
51+
tools = tools if isinstance(tools, list) else [tools]
52+
53+
litellm_tools = []
54+
for tool in tools:
55+
litellm_tools.append(tool.format_as_litellm_function_call())
56+
57+
lm_kwargs["tools"] = litellm_tools
58+
59+
signature_for_native_function_calling = signature.delete(tool_call_output_field_name)
60+
61+
return signature_for_native_function_calling
62+
63+
return signature
64+
65+
def _call_postprocess(
66+
self,
67+
signature: Type[Signature],
68+
outputs: list[dict[str, Any]],
69+
) -> list[dict[str, Any]]:
2470
values = []
2571

72+
tool_call_output_field_name = self._get_tool_call_output_field_name(signature)
73+
2674
for output in outputs:
2775
output_logprobs = None
76+
tool_calls = None
77+
text = output
2878

2979
if isinstance(output, dict):
30-
output, output_logprobs = output["text"], output["logprobs"]
31-
32-
value = self.parse(signature, output)
33-
34-
if output_logprobs is not None:
80+
text = output["text"]
81+
output_logprobs = output.get("logprobs")
82+
tool_calls = output.get("tool_calls")
83+
84+
if text:
85+
value = self.parse(signature, text)
86+
else:
87+
value = {}
88+
for field_name in signature.output_fields.keys():
89+
value[field_name] = None
90+
91+
if tool_calls and tool_call_output_field_name:
92+
tool_calls = [
93+
{
94+
"name": v["function"]["name"],
95+
"args": json_repair.loads(v["function"]["arguments"]),
96+
}
97+
for v in tool_calls
98+
]
99+
value[tool_call_output_field_name] = ToolCalls.from_dict_list(tool_calls)
100+
101+
if output_logprobs:
35102
value["logprobs"] = output_logprobs
36103

37104
values.append(value)
@@ -46,10 +113,11 @@ def __call__(
46113
demos: list[dict[str, Any]],
47114
inputs: dict[str, Any],
48115
) -> list[dict[str, Any]]:
49-
inputs = self.format(signature, demos, inputs)
116+
processed_signature = self._call_preprocess(lm, lm_kwargs, signature, inputs)
117+
inputs = self.format(processed_signature, demos, inputs)
50118

51119
outputs = lm(messages=inputs, **lm_kwargs)
52-
return self._call_post_process(outputs, signature)
120+
return self._call_postprocess(signature, outputs)
53121

54122
async def acall(
55123
self,
@@ -59,10 +127,11 @@ async def acall(
59127
demos: list[dict[str, Any]],
60128
inputs: dict[str, Any],
61129
) -> list[dict[str, Any]]:
62-
inputs = self.format(signature, demos, inputs)
130+
processed_signature = self._call_preprocess(lm, lm_kwargs, signature, inputs)
131+
inputs = self.format(processed_signature, demos, inputs)
63132

64133
outputs = await lm.acall(messages=inputs, **lm_kwargs)
65-
return self._call_post_process(outputs, signature)
134+
return self._call_postprocess(signature, outputs)
66135

67136
def format(
68137
self,
@@ -297,6 +366,22 @@ def _get_history_field_name(self, signature: Type[Signature]) -> bool:
297366
return name
298367
return None
299368

369+
def _get_tool_call_input_field_name(self, signature: Type[Signature]) -> bool:
370+
for name, field in signature.input_fields.items():
371+
# Look for annotation `list[dspy.Tool]` or `dspy.Tool`
372+
origin = get_origin(field.annotation)
373+
if origin is list and field.annotation.__args__[0] == Tool:
374+
return name
375+
if field.annotation == Tool:
376+
return name
377+
return None
378+
379+
def _get_tool_call_output_field_name(self, signature: Type[Signature]) -> bool:
380+
for name, field in signature.output_fields.items():
381+
if field.annotation == ToolCalls:
382+
return name
383+
return None
384+
300385
def format_conversation_history(
301386
self,
302387
signature: Type[Signature],
@@ -352,4 +437,4 @@ def parse(self, signature: Type[Signature], completion: str) -> dict[str, Any]:
352437
Returns:
353438
A dictionary of the output fields.
354439
"""
355-
raise NotImplementedError
440+
raise NotImplementedError

dspy/adapters/json_adapter.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,16 @@ def __call__(
7777
f"`response_format` argument. Original error: {e}"
7878
) from e
7979

80+
def _call_preprocess(
81+
self,
82+
lm: "LM",
83+
lm_kwargs: dict[str, Any],
84+
signature: Type[Signature],
85+
inputs: dict[str, Any],
86+
use_native_function_calling: bool = True,
87+
) -> dict[str, Any]:
88+
return super()._call_preprocess(lm, lm_kwargs, signature, inputs, use_native_function_calling)
89+
8090
def format_field_structure(self, signature: Type[Signature]) -> str:
8191
parts = []
8292
parts.append("All interactions will be structured in the following way, with the appropriate values filled in.")

dspy/adapters/two_step_adapter.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
from typing import Any, Optional, Type
22

3+
import json_repair
4+
35
from dspy.adapters.base import Adapter
46
from dspy.adapters.chat_adapter import ChatAdapter
7+
from dspy.adapters.types import ToolCalls
58
from dspy.adapters.utils import get_field_description_string
69
from dspy.clients import LM
710
from dspy.signatures.field import InputField
@@ -115,11 +118,16 @@ async def acall(
115118

116119
values = []
117120

121+
tool_call_output_field_name = self._get_tool_call_output_field_name(signature)
118122
for output in outputs:
119123
output_logprobs = None
124+
tool_calls = None
125+
text = output
120126

121127
if isinstance(output, dict):
122-
output, output_logprobs = output["text"], output["logprobs"]
128+
text = output["text"]
129+
output_logprobs = output.get("logprobs")
130+
tool_calls = output.get("tool_calls")
123131

124132
try:
125133
# Call the smaller LM to extract structured data from the raw completion text with ChatAdapter
@@ -128,13 +136,23 @@ async def acall(
128136
lm_kwargs={},
129137
signature=extractor_signature,
130138
demos=[],
131-
inputs={"text": output},
139+
inputs={"text": text},
132140
)
133141
value = value[0]
134142

135143
except Exception as e:
136144
raise ValueError(f"Failed to parse response from the original completion: {output}") from e
137145

146+
if tool_calls and tool_call_output_field_name:
147+
tool_calls = [
148+
{
149+
"name": v["function"]["name"],
150+
"args": json_repair.loads(v["function"]["arguments"]),
151+
}
152+
for v in tool_calls
153+
]
154+
value[tool_call_output_field_name] = ToolCalls.from_dict_list(tool_calls)
155+
138156
if output_logprobs is not None:
139157
value["logprobs"] = output_logprobs
140158

dspy/adapters/types/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@
22
from dspy.adapters.types.image import Image
33
from dspy.adapters.types.audio import Audio
44
from dspy.adapters.types.base_type import BaseType
5+
from dspy.adapters.types.tool import Tool, ToolCalls
56

6-
__all__ = ["History", "Image", "Audio", "BaseType"]
7+
__all__ = ["History", "Image", "Audio", "BaseType", "Tool", "ToolCalls"]

dspy/adapters/types/base_type.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import json
22
import re
3-
from typing import Any
3+
from typing import Any, Union, get_args, get_origin
44

55
import json_repair
66
import pydantic
@@ -26,12 +26,42 @@ def format(self) -> list[dict[str, Any]]:
2626
```
2727
"""
2828

29-
def format(self) -> list[dict[str, Any]]:
29+
def format(self) -> Union[list[dict[str, Any]], str]:
3030
raise NotImplementedError
3131

32+
@classmethod
33+
def description(cls) -> str:
34+
"""Description of the custom type"""
35+
return ""
36+
37+
@classmethod
38+
def extract_custom_type_from_annotation(cls, annotation):
39+
"""Extract all custom types from the annotation.
40+
41+
This is used to extract all custom types from the annotation of a field, while the annotation can
42+
have arbitrary level of nesting. For example, we detect `Tool` is in `list[dict[str, Tool]]`.
43+
"""
44+
# Direct match
45+
if isinstance(annotation, type) and issubclass(annotation, cls):
46+
return [annotation]
47+
48+
origin = get_origin(annotation)
49+
if origin is None:
50+
return []
51+
52+
result = []
53+
# Recurse into all type args
54+
for arg in get_args(annotation):
55+
result.extend(cls.extract_custom_type_from_annotation(arg))
56+
57+
return result
58+
3259
@pydantic.model_serializer()
3360
def serialize_model(self):
34-
return f"{CUSTOM_TYPE_START_IDENTIFIER}{self.format()}{CUSTOM_TYPE_END_IDENTIFIER}"
61+
formatted = self.format()
62+
if isinstance(formatted, list):
63+
return f"{CUSTOM_TYPE_START_IDENTIFIER}{self.format()}{CUSTOM_TYPE_END_IDENTIFIER}"
64+
return formatted
3565

3666

3767
def split_message_content_for_custom_types(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:

0 commit comments

Comments
 (0)