Skip to content

Commit c0506d1

Browse files
Fix native function calling in adapters (#8479)
* fix
* fix test
* better test
* fix
* fix tests
* lint fix
* better test
1 parent b6ae529 commit c0506d1

File tree

6 files changed

+207
-29
lines changed

6 files changed

+207
-29
lines changed

dspy/adapters/base.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@
1717

1818

1919
class Adapter:
20-
def __init__(self, callbacks: list[BaseCallback] | None = None):
20+
def __init__(self, callbacks: list[BaseCallback] | None = None, use_native_function_calling: bool = False):
2121
self.callbacks = callbacks or []
22+
self.use_native_function_calling = use_native_function_calling
2223

2324
def __init_subclass__(cls, **kwargs) -> None:
2425
super().__init_subclass__(**kwargs)
@@ -33,9 +34,8 @@ def _call_preprocess(
3334
lm_kwargs: dict[str, Any],
3435
signature: Type[Signature],
3536
inputs: dict[str, Any],
36-
use_native_function_calling: bool = False,
3737
) -> dict[str, Any]:
38-
if use_native_function_calling:
38+
if self.use_native_function_calling:
3939
tool_call_input_field_name = self._get_tool_call_input_field_name(signature)
4040
tool_call_output_field_name = self._get_tool_call_output_field_name(signature)
4141

@@ -57,19 +57,23 @@ def _call_preprocess(
5757
lm_kwargs["tools"] = litellm_tools
5858

5959
signature_for_native_function_calling = signature.delete(tool_call_output_field_name)
60+
signature_for_native_function_calling = signature_for_native_function_calling.delete(
61+
tool_call_input_field_name
62+
)
6063

6164
return signature_for_native_function_calling
6265

6366
return signature
6467

6568
def _call_postprocess(
6669
self,
67-
signature: Type[Signature],
70+
processed_signature: Type[Signature],
71+
original_signature: Type[Signature],
6872
outputs: list[dict[str, Any]],
6973
) -> list[dict[str, Any]]:
7074
values = []
7175

72-
tool_call_output_field_name = self._get_tool_call_output_field_name(signature)
76+
tool_call_output_field_name = self._get_tool_call_output_field_name(original_signature)
7377

7478
for output in outputs:
7579
output_logprobs = None
@@ -82,10 +86,14 @@ def _call_postprocess(
8286
tool_calls = output.get("tool_calls")
8387

8488
if text:
85-
value = self.parse(signature, text)
89+
value = self.parse(processed_signature, text)
90+
for field_name in original_signature.output_fields.keys():
91+
if field_name not in value:
92+
# Output fields that were removed from the processed signature are absent from the parsed value; set them to None so every output field of the original signature is present.
93+
value[field_name] = None
8694
else:
8795
value = {}
88-
for field_name in signature.output_fields.keys():
96+
for field_name in original_signature.output_fields.keys():
8997
value[field_name] = None
9098

9199
if tool_calls and tool_call_output_field_name:
@@ -117,7 +125,7 @@ def __call__(
117125
inputs = self.format(processed_signature, demos, inputs)
118126

119127
outputs = lm(messages=inputs, **lm_kwargs)
120-
return self._call_postprocess(signature, outputs)
128+
return self._call_postprocess(processed_signature, signature, outputs)
121129

122130
async def acall(
123131
self,
@@ -131,7 +139,7 @@ async def acall(
131139
inputs = self.format(processed_signature, demos, inputs)
132140

133141
outputs = await lm.acall(messages=inputs, **lm_kwargs)
134-
return self._call_postprocess(signature, outputs)
142+
return self._call_postprocess(processed_signature, signature, outputs)
135143

136144
def format(
137145
self,

dspy/adapters/chat_adapter.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
)
1616
from dspy.clients.lm import LM
1717
from dspy.signatures.signature import Signature
18-
from dspy.utils.callback import BaseCallback
1918
from dspy.utils.exceptions import AdapterParseError
2019

2120
field_header_pattern = re.compile(r"\[\[ ## (\w+) ## \]\]")
@@ -27,9 +26,6 @@ class FieldInfoWithName(NamedTuple):
2726

2827

2928
class ChatAdapter(Adapter):
30-
def __init__(self, callbacks: list[BaseCallback] | None = None):
31-
super().__init__(callbacks)
32-
3329
def __call__(
3430
self,
3531
lm: LM,

dspy/adapters/json_adapter.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from pydantic.fields import FieldInfo
1010

1111
from dspy.adapters.chat_adapter import ChatAdapter, FieldInfoWithName
12+
from dspy.adapters.types.tool import ToolCalls
1213
from dspy.adapters.utils import (
1314
format_field_value,
1415
get_annotation_name,
@@ -18,6 +19,7 @@
1819
)
1920
from dspy.clients.lm import LM
2021
from dspy.signatures.signature import Signature, SignatureMeta
22+
from dspy.utils.callback import BaseCallback
2123
from dspy.utils.exceptions import AdapterParseError
2224

2325
logger = logging.getLogger(__name__)
@@ -37,6 +39,10 @@ def _has_open_ended_mapping(signature: SignatureMeta) -> bool:
3739

3840

3941
class JSONAdapter(ChatAdapter):
42+
def __init__(self, callbacks: list[BaseCallback] | None = None, use_native_function_calling: bool = True):
43+
# JSONAdapter uses native function calling by default.
44+
super().__init__(callbacks=callbacks, use_native_function_calling=use_native_function_calling)
45+
4046
def _json_adapter_call_common(self, lm, lm_kwargs, signature, demos, inputs, call_fn):
4147
"""Common call logic to be used for both sync and async calls."""
4248
provider = lm.model.split("/", 1)[0] or "openai"
@@ -45,7 +51,10 @@ def _json_adapter_call_common(self, lm, lm_kwargs, signature, demos, inputs, cal
4551
if not params or "response_format" not in params:
4652
return call_fn(lm, lm_kwargs, signature, demos, inputs)
4753

48-
if _has_open_ended_mapping(signature):
54+
has_tool_calls = any(field.annotation == ToolCalls for field in signature.output_fields.values())
55+
if _has_open_ended_mapping(signature) or (not self.use_native_function_calling and has_tool_calls):
56+
# We found that structured output mode doesn't work well with dspy.ToolCalls as an output field.
57+
# So we fall back to JSON mode when native function calling is disabled and a ToolCalls output field is present.
4958
lm_kwargs["response_format"] = {"type": "json_object"}
5059
return call_fn(lm, lm_kwargs, signature, demos, inputs)
5160

@@ -62,7 +71,9 @@ def __call__(
6271
return result
6372

6473
try:
65-
structured_output_model = _get_structured_outputs_response_format(signature)
74+
structured_output_model = _get_structured_outputs_response_format(
75+
signature, self.use_native_function_calling
76+
)
6677
lm_kwargs["response_format"] = structured_output_model
6778
return super().__call__(lm, lm_kwargs, signature, demos, inputs)
6879
except Exception:
@@ -91,16 +102,6 @@ async def acall(
91102
lm_kwargs["response_format"] = {"type": "json_object"}
92103
return await super().acall(lm, lm_kwargs, signature, demos, inputs)
93104

94-
def _call_preprocess(
95-
self,
96-
lm: "LM",
97-
lm_kwargs: dict[str, Any],
98-
signature: Type[Signature],
99-
inputs: dict[str, Any],
100-
use_native_function_calling: bool = True,
101-
) -> dict[str, Any]:
102-
return super()._call_preprocess(lm, lm_kwargs, signature, inputs, use_native_function_calling)
103-
104105
def format_field_structure(self, signature: Type[Signature]) -> str:
105106
parts = []
106107
parts.append("All interactions will be structured in the following way, with the appropriate values filled in.")
@@ -206,7 +207,10 @@ def format_finetune_data(
206207
raise NotImplementedError
207208

208209

209-
def _get_structured_outputs_response_format(signature: SignatureMeta) -> type[pydantic.BaseModel]:
210+
def _get_structured_outputs_response_format(
211+
signature: SignatureMeta,
212+
use_native_function_calling: bool = True,
213+
) -> type[pydantic.BaseModel]:
210214
"""
211215
Builds a Pydantic model from a DSPy signature's output_fields and ensures the generated JSON schema
212216
is compatible with OpenAI Structured Outputs (all objects have a "required" key listing every property,
@@ -227,6 +231,9 @@ def _get_structured_outputs_response_format(signature: SignatureMeta) -> type[py
227231
fields = {}
228232
for name, field in signature.output_fields.items():
229233
annotation = field.annotation
234+
if use_native_function_calling and annotation == ToolCalls:
235+
# Skip ToolCalls field if native function calling is enabled.
236+
continue
230237
default = field.default if hasattr(field, "default") else ...
231238
fields[name] = (annotation, default)
232239

dspy/adapters/two_step_adapter.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ class TwoStepAdapter(Adapter):
3939
```
4040
"""
4141

42-
def __init__(self, extraction_model: LM):
42+
def __init__(self, extraction_model: LM, **kwargs):
43+
super().__init__(**kwargs)
4344
if not isinstance(extraction_model, LM):
4445
raise ValueError("extraction_model must be an instance of LM")
4546
self.extraction_model = extraction_model

tests/adapters/test_chat_adapter.py

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import pydantic
55
import pytest
6-
from litellm.utils import Choices, Message, ModelResponse
6+
from litellm.utils import ChatCompletionMessageToolCall, Choices, Function, Message, ModelResponse
77

88
import dspy
99

@@ -422,3 +422,70 @@ async def test_chat_adapter_fallback_to_json_adapter_on_exception_async():
422422
# The parse should succeed
423423
result = await adapter.acall(lm, {}, signature, [], {"question": "What is the capital of France?"})
424424
assert result == [{"answer": "Paris"}]
425+
426+
427+
def test_chat_adapter_toolcalls_native_function_calling():
428+
class MySignature(dspy.Signature):
429+
question: str = dspy.InputField()
430+
tools: list[dspy.Tool] = dspy.InputField()
431+
answer: str = dspy.OutputField()
432+
tool_calls: dspy.ToolCalls = dspy.OutputField()
433+
434+
def get_weather(city: str) -> str:
435+
return f"The weather in {city} is sunny"
436+
437+
tools = [dspy.Tool(get_weather)]
438+
439+
adapter = dspy.JSONAdapter(use_native_function_calling=True)
440+
441+
# Case 1: Tool calls are present in the response, while content is None.
442+
with mock.patch("litellm.completion") as mock_completion:
443+
mock_completion.return_value = ModelResponse(
444+
choices=[
445+
Choices(
446+
finish_reason="tool_calls",
447+
index=0,
448+
message=Message(
449+
content=None,
450+
role="assistant",
451+
tool_calls=[
452+
ChatCompletionMessageToolCall(
453+
function=Function(arguments='{"city":"Paris"}', name="get_weather"),
454+
id="call_pQm8ajtSMxgA0nrzK2ivFmxG",
455+
type="function",
456+
)
457+
],
458+
),
459+
),
460+
],
461+
model="openai/gpt-4o-mini",
462+
)
463+
result = adapter(
464+
dspy.LM(model="openai/gpt-4o-mini", cache=False),
465+
{},
466+
MySignature,
467+
[],
468+
{"question": "What is the weather in Paris?", "tools": tools},
469+
)
470+
471+
assert result[0]["tool_calls"] == dspy.ToolCalls(
472+
tool_calls=[dspy.ToolCalls.ToolCall(name="get_weather", args={"city": "Paris"})]
473+
)
474+
# The response has no text content, so the adapter sets `answer` to None.
475+
assert result[0]["answer"] is None
476+
477+
# Case 2: Tool calls are not present in the response, while content is present.
478+
with mock.patch("litellm.completion") as mock_completion:
479+
mock_completion.return_value = ModelResponse(
480+
choices=[Choices(message=Message(content="{'answer': 'Paris'}"))],
481+
model="openai/gpt-4o-mini",
482+
)
483+
result = adapter(
484+
dspy.LM(model="openai/gpt-4o-mini", cache=False),
485+
{},
486+
MySignature,
487+
[],
488+
{"question": "What is the weather in Paris?", "tools": tools},
489+
)
490+
assert result[0]["answer"] == "Paris"
491+
assert result[0]["tool_calls"] is None

tests/adapters/test_json_adapter.py

Lines changed: 100 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import pydantic
44
import pytest
5-
from litellm.utils import Choices, Message, ModelResponse
5+
from litellm.utils import ChatCompletionMessageToolCall, Choices, Function, Message, ModelResponse
66

77
import dspy
88

@@ -650,3 +650,102 @@ class TestSignature(dspy.Signature):
650650
await program.acall(question="Dummy question!")
651651

652652
assert "ValueError!" in str(error.value)
653+
654+
655+
def test_json_adapter_toolcalls_native_function_calling():
656+
class MySignature(dspy.Signature):
657+
question: str = dspy.InputField()
658+
tools: list[dspy.Tool] = dspy.InputField()
659+
answer: str = dspy.OutputField()
660+
tool_calls: dspy.ToolCalls = dspy.OutputField()
661+
662+
def get_weather(city: str) -> str:
663+
return f"The weather in {city} is sunny"
664+
665+
tools = [dspy.Tool(get_weather)]
666+
667+
adapter = dspy.JSONAdapter(use_native_function_calling=True)
668+
669+
# Case 1: Tool calls are present in the response, while content is None.
670+
with mock.patch("litellm.completion") as mock_completion:
671+
mock_completion.return_value = ModelResponse(
672+
choices=[
673+
Choices(
674+
finish_reason="tool_calls",
675+
index=0,
676+
message=Message(
677+
content=None,
678+
role="assistant",
679+
tool_calls=[
680+
ChatCompletionMessageToolCall(
681+
function=Function(arguments='{"city":"Paris"}', name="get_weather"),
682+
id="call_pQm8ajtSMxgA0nrzK2ivFmxG",
683+
type="function",
684+
)
685+
],
686+
),
687+
),
688+
],
689+
model="openai/gpt-4o-mini",
690+
)
691+
result = adapter(
692+
dspy.LM(model="openai/gpt-4o-mini", cache=False),
693+
{},
694+
MySignature,
695+
[],
696+
{"question": "What is the weather in Paris?", "tools": tools},
697+
)
698+
699+
assert result[0]["tool_calls"] == dspy.ToolCalls(
700+
tool_calls=[dspy.ToolCalls.ToolCall(name="get_weather", args={"city": "Paris"})]
701+
)
702+
# The response has no text content, so the adapter sets `answer` to None.
703+
assert result[0]["answer"] is None
704+
705+
# Case 2: Tool calls are not present in the response, while content is present.
706+
with mock.patch("litellm.completion") as mock_completion:
707+
mock_completion.return_value = ModelResponse(
708+
choices=[Choices(message=Message(content="{'answer': 'Paris'}"))],
709+
model="openai/gpt-4o-mini",
710+
)
711+
result = adapter(
712+
dspy.LM(model="openai/gpt-4o-mini", cache=False),
713+
{},
714+
MySignature,
715+
[],
716+
{"question": "What is the weather in Paris?", "tools": tools},
717+
)
718+
assert result[0]["answer"] == "Paris"
719+
assert result[0]["tool_calls"] is None
720+
721+
722+
def test_json_adapter_toolcalls_no_native_function_calling():
723+
class MySignature(dspy.Signature):
724+
question: str = dspy.InputField()
725+
tools: list[dspy.Tool] = dspy.InputField()
726+
answer: str = dspy.OutputField()
727+
tool_calls: dspy.ToolCalls = dspy.OutputField()
728+
729+
def get_weather(city: str) -> str:
730+
return f"The weather in {city} is sunny"
731+
732+
tools = [dspy.Tool(get_weather)]
733+
734+
# Patch _get_structured_outputs_response_format to track calls
735+
with mock.patch("dspy.adapters.json_adapter._get_structured_outputs_response_format") as mock_structured:
736+
# Patch litellm.completion to return a dummy response
737+
with mock.patch("litellm.completion") as mock_completion:
738+
mock_completion.return_value = ModelResponse(
739+
choices=[Choices(message=Message(content="{'answer': 'sunny', 'tool_calls': {'tool_calls': []}}"))],
740+
model="openai/gpt-4o-mini",
741+
)
742+
adapter = dspy.JSONAdapter(use_native_function_calling=False)
743+
lm = dspy.LM(model="openai/gpt-4o-mini", cache=False)
744+
adapter(lm, {}, MySignature, [], {"question": "What is the weather in Tokyo?", "tools": tools})
745+
746+
# _get_structured_outputs_response_format is not called: with native function calling disabled and a ToolCalls output field present,
747+
# JSONAdapter falls back to JSON mode for stable quality.
748+
mock_structured.assert_not_called()
749+
mock_completion.assert_called_once()
750+
_, call_kwargs = mock_completion.call_args
751+
assert call_kwargs["response_format"] == {"type": "json_object"}

0 commit comments

Comments
 (0)