1
- from typing import TYPE_CHECKING , Any , Optional , Type
1
+ import logging
2
+ from typing import TYPE_CHECKING , Any , Optional , Type , get_origin
3
+
4
+ import json_repair
5
+ import litellm
2
6
3
7
from dspy .adapters .types import History
4
8
from dspy .adapters .types .base_type import split_message_content_for_custom_types
9
+ from dspy .adapters .types .tool import Tool , ToolCalls
5
10
from dspy .signatures .signature import Signature
6
11
from dspy .utils .callback import BaseCallback , with_callbacks
7
12
13
+ logger = logging .getLogger (__name__ )
14
+
8
15
if TYPE_CHECKING :
9
16
from dspy .clients .lm import LM
10
17
@@ -20,18 +27,78 @@ def __init_subclass__(cls, **kwargs) -> None:
20
27
cls .format = with_callbacks (cls .format )
21
28
cls .parse = with_callbacks (cls .parse )
22
29
23
- def _call_post_process (self , outputs : list [dict [str , Any ]], signature : Type [Signature ]) -> list [dict [str , Any ]]:
30
+ def _call_preprocess (
31
+ self ,
32
+ lm : "LM" ,
33
+ lm_kwargs : dict [str , Any ],
34
+ signature : Type [Signature ],
35
+ inputs : dict [str , Any ],
36
+ use_native_function_calling : bool = False ,
37
+ ) -> dict [str , Any ]:
38
+ if use_native_function_calling :
39
+ tool_call_input_field_name = self ._get_tool_call_input_field_name (signature )
40
+ tool_call_output_field_name = self ._get_tool_call_output_field_name (signature )
41
+
42
+ if tool_call_output_field_name and tool_call_input_field_name is None :
43
+ raise ValueError (
44
+ f"You provided an output field { tool_call_output_field_name } to receive the tool calls information, "
45
+ "but did not provide any tools as the input. Please provide a list of tools as the input by adding an "
46
+ "input field with type `list[dspy.Tool]`."
47
+ )
48
+
49
+ if tool_call_output_field_name and litellm .supports_function_calling (model = lm .model ):
50
+ tools = inputs [tool_call_input_field_name ]
51
+ tools = tools if isinstance (tools , list ) else [tools ]
52
+
53
+ litellm_tools = []
54
+ for tool in tools :
55
+ litellm_tools .append (tool .format_as_litellm_function_call ())
56
+
57
+ lm_kwargs ["tools" ] = litellm_tools
58
+
59
+ signature_for_native_function_calling = signature .delete (tool_call_output_field_name )
60
+
61
+ return signature_for_native_function_calling
62
+
63
+ return signature
64
+
65
+ def _call_postprocess (
66
+ self ,
67
+ signature : Type [Signature ],
68
+ outputs : list [dict [str , Any ]],
69
+ ) -> list [dict [str , Any ]]:
24
70
values = []
25
71
72
+ tool_call_output_field_name = self ._get_tool_call_output_field_name (signature )
73
+
26
74
for output in outputs :
27
75
output_logprobs = None
76
+ tool_calls = None
77
+ text = output
28
78
29
79
if isinstance (output , dict ):
30
- output , output_logprobs = output ["text" ], output ["logprobs" ]
31
-
32
- value = self .parse (signature , output )
33
-
34
- if output_logprobs is not None :
80
+ text = output ["text" ]
81
+ output_logprobs = output .get ("logprobs" )
82
+ tool_calls = output .get ("tool_calls" )
83
+
84
+ if text :
85
+ value = self .parse (signature , text )
86
+ else :
87
+ value = {}
88
+ for field_name in signature .output_fields .keys ():
89
+ value [field_name ] = None
90
+
91
+ if tool_calls and tool_call_output_field_name :
92
+ tool_calls = [
93
+ {
94
+ "name" : v ["function" ]["name" ],
95
+ "args" : json_repair .loads (v ["function" ]["arguments" ]),
96
+ }
97
+ for v in tool_calls
98
+ ]
99
+ value [tool_call_output_field_name ] = ToolCalls .from_dict_list (tool_calls )
100
+
101
+ if output_logprobs :
35
102
value ["logprobs" ] = output_logprobs
36
103
37
104
values .append (value )
@@ -46,10 +113,11 @@ def __call__(
46
113
demos : list [dict [str , Any ]],
47
114
inputs : dict [str , Any ],
48
115
) -> list [dict [str , Any ]]:
49
- inputs = self .format (signature , demos , inputs )
116
+ processed_signature = self ._call_preprocess (lm , lm_kwargs , signature , inputs )
117
+ inputs = self .format (processed_signature , demos , inputs )
50
118
51
119
outputs = lm (messages = inputs , ** lm_kwargs )
52
- return self ._call_post_process ( outputs , signature )
120
+ return self ._call_postprocess ( signature , outputs )
53
121
54
122
async def acall (
55
123
self ,
@@ -59,10 +127,11 @@ async def acall(
59
127
demos : list [dict [str , Any ]],
60
128
inputs : dict [str , Any ],
61
129
) -> list [dict [str , Any ]]:
62
- inputs = self .format (signature , demos , inputs )
130
+ processed_signature = self ._call_preprocess (lm , lm_kwargs , signature , inputs )
131
+ inputs = self .format (processed_signature , demos , inputs )
63
132
64
133
outputs = await lm .acall (messages = inputs , ** lm_kwargs )
65
- return self ._call_post_process ( outputs , signature )
134
+ return self ._call_postprocess ( signature , outputs )
66
135
67
136
def format (
68
137
self ,
@@ -297,6 +366,22 @@ def _get_history_field_name(self, signature: Type[Signature]) -> bool:
297
366
return name
298
367
return None
299
368
369
+ def _get_tool_call_input_field_name (self , signature : Type [Signature ]) -> bool :
370
+ for name , field in signature .input_fields .items ():
371
+ # Look for annotation `list[dspy.Tool]` or `dspy.Tool`
372
+ origin = get_origin (field .annotation )
373
+ if origin is list and field .annotation .__args__ [0 ] == Tool :
374
+ return name
375
+ if field .annotation == Tool :
376
+ return name
377
+ return None
378
+
379
+ def _get_tool_call_output_field_name (self , signature : Type [Signature ]) -> bool :
380
+ for name , field in signature .output_fields .items ():
381
+ if field .annotation == ToolCalls :
382
+ return name
383
+ return None
384
+
300
385
def format_conversation_history (
301
386
self ,
302
387
signature : Type [Signature ],
@@ -352,4 +437,4 @@ def parse(self, signature: Type[Signature], completion: str) -> dict[str, Any]:
352
437
Returns:
353
438
A dictionary of the output fields.
354
439
"""
355
- raise NotImplementedError
440
+ raise NotImplementedError
0 commit comments