diff --git a/dspy/adapters/base.py b/dspy/adapters/base.py
index 1cdc1edfd9..3e0372c0eb 100644
--- a/dspy/adapters/base.py
+++ b/dspy/adapters/base.py
@@ -17,8 +17,9 @@ class Adapter:
-    def __init__(self, callbacks: list[BaseCallback] | None = None):
+    def __init__(self, callbacks: list[BaseCallback] | None = None, use_native_function_calling: bool = False):
         self.callbacks = callbacks or []
+        self.use_native_function_calling = use_native_function_calling

     def __init_subclass__(cls, **kwargs) -> None:
         super().__init_subclass__(**kwargs)
@@ -33,9 +34,8 @@ def _call_preprocess(
         lm_kwargs: dict[str, Any],
         signature: Type[Signature],
         inputs: dict[str, Any],
-        use_native_function_calling: bool = False,
     ) -> dict[str, Any]:
-        if use_native_function_calling:
+        if self.use_native_function_calling:
             tool_call_input_field_name = self._get_tool_call_input_field_name(signature)
             tool_call_output_field_name = self._get_tool_call_output_field_name(signature)
@@ -57,6 +57,9 @@ def _call_preprocess(
             lm_kwargs["tools"] = litellm_tools

             signature_for_native_function_calling = signature.delete(tool_call_output_field_name)
+            signature_for_native_function_calling = signature_for_native_function_calling.delete(
+                tool_call_input_field_name
+            )

             return signature_for_native_function_calling
@@ -64,12 +67,13 @@ def _call_preprocess(

     def _call_postprocess(
         self,
-        signature: Type[Signature],
+        processed_signature: Type[Signature],
+        original_signature: Type[Signature],
         outputs: list[dict[str, Any]],
     ) -> list[dict[str, Any]]:
         values = []

-        tool_call_output_field_name = self._get_tool_call_output_field_name(signature)
+        tool_call_output_field_name = self._get_tool_call_output_field_name(original_signature)

         for output in outputs:
             output_logprobs = None
@@ -82,10 +86,14 @@ def _call_postprocess(
             tool_calls = output.get("tool_calls")

             if text:
-                value = self.parse(signature, text)
+                value = self.parse(processed_signature, text)
+                for field_name in original_signature.output_fields.keys():
+                    if field_name not in value:
+                        # Set fields that are absent from the processed signature to None for consistency.
+                        value[field_name] = None
             else:
                 value = {}
-                for field_name in signature.output_fields.keys():
+                for field_name in original_signature.output_fields.keys():
                     value[field_name] = None

             if tool_calls and tool_call_output_field_name:
@@ -117,7 +125,7 @@ def __call__(
         inputs = self.format(processed_signature, demos, inputs)

         outputs = lm(messages=inputs, **lm_kwargs)
-        return self._call_postprocess(signature, outputs)
+        return self._call_postprocess(processed_signature, signature, outputs)

     async def acall(
         self,
@@ -131,7 +139,7 @@ async def acall(
         inputs = self.format(processed_signature, demos, inputs)

         outputs = await lm.acall(messages=inputs, **lm_kwargs)
-        return self._call_postprocess(signature, outputs)
+        return self._call_postprocess(processed_signature, signature, outputs)

     def format(
         self,
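
With this change, native function calling becomes adapter state instead of a per-call argument, so it is chosen once at construction time. A minimal usage sketch (not part of the patch; the model name is illustrative):

    import dspy

    # Adapter (and therefore ChatAdapter) defaults to use_native_function_calling=False.
    adapter = dspy.ChatAdapter(use_native_function_calling=True)
    dspy.configure(lm=dspy.LM("openai/gpt-4o-mini"), adapter=adapter)
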
diff --git a/dspy/adapters/chat_adapter.py b/dspy/adapters/chat_adapter.py
index ac407e4755..4ace4214de 100644
--- a/dspy/adapters/chat_adapter.py
+++ b/dspy/adapters/chat_adapter.py
@@ -15,7 +15,6 @@
 )
 from dspy.clients.lm import LM
 from dspy.signatures.signature import Signature
-from dspy.utils.callback import BaseCallback
 from dspy.utils.exceptions import AdapterParseError

 field_header_pattern = re.compile(r"\[\[ ## (\w+) ## \]\]")
@@ -27,9 +26,6 @@ class FieldInfoWithName(NamedTuple):


 class ChatAdapter(Adapter):
-    def __init__(self, callbacks: list[BaseCallback] | None = None):
-        super().__init__(callbacks)
-
     def __call__(
         self,
         lm: LM,
diff --git a/dspy/adapters/json_adapter.py b/dspy/adapters/json_adapter.py
index 9b14691c7c..51470e2ecd 100644
--- a/dspy/adapters/json_adapter.py
+++ b/dspy/adapters/json_adapter.py
@@ -9,6 +9,7 @@
 from pydantic.fields import FieldInfo

 from dspy.adapters.chat_adapter import ChatAdapter, FieldInfoWithName
+from dspy.adapters.types.tool import ToolCalls
 from dspy.adapters.utils import (
     format_field_value,
     get_annotation_name,
@@ -18,6 +19,7 @@
 )
 from dspy.clients.lm import LM
 from dspy.signatures.signature import Signature, SignatureMeta
+from dspy.utils.callback import BaseCallback
 from dspy.utils.exceptions import AdapterParseError

 logger = logging.getLogger(__name__)
@@ -37,6 +39,10 @@ def _has_open_ended_mapping(signature: SignatureMeta) -> bool:


 class JSONAdapter(ChatAdapter):
+    def __init__(self, callbacks: list[BaseCallback] | None = None, use_native_function_calling: bool = True):
+        # JSONAdapter uses native function calling by default.
+        super().__init__(callbacks=callbacks, use_native_function_calling=use_native_function_calling)
+
     def _json_adapter_call_common(self, lm, lm_kwargs, signature, demos, inputs, call_fn):
         """Common call logic to be used for both sync and async calls."""
         provider = lm.model.split("/", 1)[0] or "openai"
@@ -45,7 +51,10 @@ def _json_adapter_call_common(self, lm, lm_kwargs, signature, demos, inputs, cal
         if not params or "response_format" not in params:
             return call_fn(lm, lm_kwargs, signature, demos, inputs)

-        if _has_open_ended_mapping(signature):
+        has_tool_calls = any(field.annotation == ToolCalls for field in signature.output_fields.values())
+        if _has_open_ended_mapping(signature) or (not self.use_native_function_calling and has_tool_calls):
+            # We found that structured output mode doesn't work well with dspy.ToolCalls as an output field,
+            # so we fall back to JSON mode when native function calling is disabled and ToolCalls is present.
             lm_kwargs["response_format"] = {"type": "json_object"}
             return call_fn(lm, lm_kwargs, signature, demos, inputs)
@@ -62,7 +71,9 @@ def __call__(
             return result

         try:
-            structured_output_model = _get_structured_outputs_response_format(signature)
+            structured_output_model = _get_structured_outputs_response_format(
+                signature, self.use_native_function_calling
+            )
             lm_kwargs["response_format"] = structured_output_model
             return super().__call__(lm, lm_kwargs, signature, demos, inputs)
         except Exception:
@@ -91,16 +102,6 @@ async def acall(
             lm_kwargs["response_format"] = {"type": "json_object"}
             return await super().acall(lm, lm_kwargs, signature, demos, inputs)

-    def _call_preprocess(
-        self,
-        lm: "LM",
-        lm_kwargs: dict[str, Any],
-        signature: Type[Signature],
-        inputs: dict[str, Any],
-        use_native_function_calling: bool = True,
-    ) -> dict[str, Any]:
-        return super()._call_preprocess(lm, lm_kwargs, signature, inputs, use_native_function_calling)
-
     def format_field_structure(self, signature: Type[Signature]) -> str:
         parts = []
         parts.append("All interactions will be structured in the following way, with the appropriate values filled in.")
@@ -206,7 +207,10 @@ def format_finetune_data(
         raise NotImplementedError


-def _get_structured_outputs_response_format(signature: SignatureMeta) -> type[pydantic.BaseModel]:
+def _get_structured_outputs_response_format(
+    signature: SignatureMeta,
+    use_native_function_calling: bool = True,
+) -> type[pydantic.BaseModel]:
     """
     Builds a Pydantic model from a DSPy signature's output_fields and ensures the generated JSON schema
     is compatible with OpenAI Structured Outputs (all objects have a "required" key listing every property,
@@ -227,6 +231,9 @@ def _get_structured_outputs_response_format(signature: SignatureMeta) -> type[py
     fields = {}
     for name, field in signature.output_fields.items():
         annotation = field.annotation
+        if use_native_function_calling and annotation == ToolCalls:
+            # Skip the ToolCalls field when native function calling is enabled.
+            continue
         default = field.default if hasattr(field, "default") else ...
         fields[name] = (annotation, default)
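
Since JSONAdapter now flips the default, the two construction paths behave differently when a dspy.ToolCalls output field is present. A hedged sketch of the two paths (assumes a provider that supports response_format):

    import dspy

    # Default: native function calling. A dspy.ToolCalls output field is served by the
    # provider's tool-calling API and skipped in the structured-outputs schema.
    native = dspy.JSONAdapter()

    # Opt out: structured outputs are bypassed and the adapter falls back to JSON mode.
    manual = dspy.JSONAdapter(use_native_function_calling=False)
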
lm_kwargs["response_format"] = {"type": "json_object"} return call_fn(lm, lm_kwargs, signature, demos, inputs) @@ -62,7 +71,9 @@ def __call__( return result try: - structured_output_model = _get_structured_outputs_response_format(signature) + structured_output_model = _get_structured_outputs_response_format( + signature, self.use_native_function_calling + ) lm_kwargs["response_format"] = structured_output_model return super().__call__(lm, lm_kwargs, signature, demos, inputs) except Exception: @@ -91,16 +102,6 @@ async def acall( lm_kwargs["response_format"] = {"type": "json_object"} return await super().acall(lm, lm_kwargs, signature, demos, inputs) - def _call_preprocess( - self, - lm: "LM", - lm_kwargs: dict[str, Any], - signature: Type[Signature], - inputs: dict[str, Any], - use_native_function_calling: bool = True, - ) -> dict[str, Any]: - return super()._call_preprocess(lm, lm_kwargs, signature, inputs, use_native_function_calling) - def format_field_structure(self, signature: Type[Signature]) -> str: parts = [] parts.append("All interactions will be structured in the following way, with the appropriate values filled in.") @@ -206,7 +207,10 @@ def format_finetune_data( raise NotImplementedError -def _get_structured_outputs_response_format(signature: SignatureMeta) -> type[pydantic.BaseModel]: +def _get_structured_outputs_response_format( + signature: SignatureMeta, + use_native_function_calling: bool = True, +) -> type[pydantic.BaseModel]: """ Builds a Pydantic model from a DSPy signature's output_fields and ensures the generated JSON schema is compatible with OpenAI Structured Outputs (all objects have a "required" key listing every property, @@ -227,6 +231,9 @@ def _get_structured_outputs_response_format(signature: SignatureMeta) -> type[py fields = {} for name, field in signature.output_fields.items(): annotation = field.annotation + if use_native_function_calling and annotation == ToolCalls: + # Skip ToolCalls field if native function calling is enabled. + continue default = field.default if hasattr(field, "default") else ... fields[name] = (annotation, default) diff --git a/dspy/adapters/two_step_adapter.py b/dspy/adapters/two_step_adapter.py index 405cc99e29..27b8d5e72b 100644 --- a/dspy/adapters/two_step_adapter.py +++ b/dspy/adapters/two_step_adapter.py @@ -39,7 +39,8 @@ class TwoStepAdapter(Adapter): ``` """ - def __init__(self, extraction_model: LM): + def __init__(self, extraction_model: LM, **kwargs): + super().__init__(**kwargs) if not isinstance(extraction_model, LM): raise ValueError("extraction_model must be an instance of LM") self.extraction_model = extraction_model diff --git a/dspy/adapters/types/tool.py b/dspy/adapters/types/tool.py index c2ca4990cb..14004ad749 100644 --- a/dspy/adapters/types/tool.py +++ b/dspy/adapters/types/tool.py @@ -2,6 +2,7 @@ import inspect from typing import TYPE_CHECKING, Any, Callable, Tuple, Type, get_origin, get_type_hints +import pydantic from jsonschema import ValidationError, validate from pydantic import BaseModel, TypeAdapter, create_model @@ -288,7 +289,7 @@ def from_dict_list(cls, tool_calls_dicts: list[dict[str, Any]]) -> "ToolCalls": def description(cls) -> str: return ( "Tool calls information, including the name of the tools and the arguments to be passed to it. " - "Arguments must be provided in JSON format." 
diff --git a/dspy/predict/react.py b/dspy/predict/react.py
index 212f7b2161..57a87ca669 100644
--- a/dspy/predict/react.py
+++ b/dspy/predict/react.py
@@ -1,10 +1,11 @@
 import logging
-from typing import TYPE_CHECKING, Any, Callable, Literal, Type
+from typing import TYPE_CHECKING, Any, Callable, Type

 from litellm import ContextWindowExceededError
+from pydantic import ValidationError

 import dspy
-from dspy.adapters.types.tool import Tool
+from dspy.adapters.types.tool import Tool, ToolCalls
 from dspy.primitives.module import Module
 from dspy.signatures.signature import ensure_signature

@@ -53,10 +54,10 @@ def get_weather(city: str) -> str:
             [
                 f"You are an Agent. In each episode, you will be given the fields {inputs} as input. And you can see your past trajectory so far.",
                 f"Your goal is to use one or more of the supplied tools to collect any necessary information for producing {outputs}.\n",
-                "To do this, you will interleave next_thought, next_tool_name, and next_tool_args in each turn, and also when finishing the task.",
+                "To do this, you will interleave next_thought and next_tool_calls in each turn, and also when finishing the task.",
                 "After each tool call, you receive a resulting observation, which gets appended to your trajectory.\n",
                 "When writing next_thought, you may reason about the current situation and plan for future steps.",
-                "When selecting the next_tool_name and its next_tool_args, the tool must be one of:\n",
+                "When selecting the tools to call, each tool must be one of:\n",
             ]
         )

@@ -69,14 +70,14 @@ def get_weather(city: str) -> str:
         for idx, tool in enumerate(tools.values()):
             instr.append(f"({idx + 1}) {tool}")
-        instr.append("When providing `next_tool_args`, the value inside the field must be in JSON format")
+        instr.append("When providing `next_tool_calls`, the args value inside the field must be in JSON format")

         react_signature = (
             dspy.Signature({**signature.input_fields}, "\n".join(instr))
             .append("trajectory", dspy.InputField(), type_=str)
+            .append("tools", dspy.InputField(), type_=list[Tool])
             .append("next_thought", dspy.OutputField(), type_=str)
-            .append("next_tool_name", dspy.OutputField(), type_=Literal[tuple(tools.keys())])
-            .append("next_tool_args", dspy.OutputField(), type_=dict[str, Any])
+            .append("next_tool_calls", dspy.OutputField(), type_=ToolCalls)
         )

         fallback_signature = dspy.Signature(
@@ -96,23 +97,34 @@ def _format_trajectory(self, trajectory: dict[str, Any]):
     def forward(self, **input_args):
         trajectory = {}
         max_iters = input_args.pop("max_iters", self.max_iters)
+        dspy.settings.adapter.use_native_function_calling = False
         for idx in range(max_iters):
             try:
                 pred = self._call_with_potential_trajectory_truncation(self.react, trajectory, **input_args)
+            except ValidationError as err:
+                trajectory[f"thought_{idx}"] = (
+                    "Encountered an error when parsing the LM response; please try fixing it based on the error message."
+                )
+                trajectory[f"tool_calls_{idx}"] = None
+                trajectory[f"observation_{idx}"] = [f"Error: {err}"]
+                continue
             except ValueError as err:
                 logger.warning(f"Ending the trajectory: Agent failed to select a valid tool: {_fmt_exc(err)}")
                 break

             trajectory[f"thought_{idx}"] = pred.next_thought
-            trajectory[f"tool_name_{idx}"] = pred.next_tool_name
-            trajectory[f"tool_args_{idx}"] = pred.next_tool_args
-
-            try:
-                trajectory[f"observation_{idx}"] = self.tools[pred.next_tool_name](**pred.next_tool_args)
-            except Exception as err:
-                trajectory[f"observation_{idx}"] = f"Execution error in {pred.next_tool_name}: {_fmt_exc(err)}"
-
-            if pred.next_tool_name == "finish":
+            trajectory[f"tool_calls_{idx}"] = str(pred.next_tool_calls)
+            trajectory[f"observation_{idx}"] = []
+
+            tool_calls = [] if pred.next_tool_calls is None else pred.next_tool_calls.tool_calls
+            tool_names = [tool_call.name for tool_call in tool_calls]
+            for tool_call in tool_calls:
+                try:
+                    trajectory[f"observation_{idx}"].append(f"{self.tools[tool_call.name](**tool_call.args)}")
+                except Exception as err:
+                    trajectory[f"observation_{idx}"].append(f"Execution error in {tool_call.name}: {_fmt_exc(err)}")
+
+            if "finish" in tool_names or len(tool_names) == 0:
                 break

         extract = self._call_with_potential_trajectory_truncation(self.extract, trajectory, **input_args)
@@ -121,23 +133,35 @@ def forward(self, **input_args):
     async def aforward(self, **input_args):
         trajectory = {}
         max_iters = input_args.pop("max_iters", self.max_iters)
+        dspy.settings.adapter.use_native_function_calling = False
         for idx in range(max_iters):
             try:
                 pred = await self._async_call_with_potential_trajectory_truncation(self.react, trajectory, **input_args)
+            except ValidationError as err:
+                trajectory[f"thought_{idx}"] = (
+                    "Encountered an error when parsing the LM response; please try fixing it based on the error message."
+                )
+                trajectory[f"tool_calls_{idx}"] = None
+                trajectory[f"observation_{idx}"] = [f"Error: {err}"]
+                continue
             except ValueError as err:
                 logger.warning(f"Ending the trajectory: Agent failed to select a valid tool: {_fmt_exc(err)}")
                 break

             trajectory[f"thought_{idx}"] = pred.next_thought
-            trajectory[f"tool_name_{idx}"] = pred.next_tool_name
-            trajectory[f"tool_args_{idx}"] = pred.next_tool_args
-
-            try:
-                trajectory[f"observation_{idx}"] = await self.tools[pred.next_tool_name].acall(**pred.next_tool_args)
-            except Exception as err:
-                trajectory[f"observation_{idx}"] = f"Execution error in {pred.next_tool_name}: {_fmt_exc(err)}"
-
-            if pred.next_tool_name == "finish":
+            trajectory[f"tool_calls_{idx}"] = str(pred.next_tool_calls)
+            trajectory[f"observation_{idx}"] = []
+
+            tool_calls = [] if pred.next_tool_calls is None else pred.next_tool_calls.tool_calls
+            tool_names = [tool_call.name for tool_call in tool_calls]
+            for tool_call in tool_calls:
+                try:
+                    result = await self.tools[tool_call.name].acall(**tool_call.args)
+                    trajectory[f"observation_{idx}"].append(result)
+                except Exception as err:
+                    trajectory[f"observation_{idx}"].append(f"Execution error in {tool_call.name}: {_fmt_exc(err)}")
+
+            if "finish" in tool_names or len(tool_names) == 0:
                 break

         extract = await self._async_call_with_potential_trajectory_truncation(self.extract, trajectory, **input_args)
@@ -149,6 +173,7 @@ def _call_with_potential_trajectory_truncation(self, module, trajectory, **input
             return module(
                 **input_args,
                 trajectory=self._format_trajectory(trajectory),
+                tools=list(self.tools.values()),
             )
         except ContextWindowExceededError:
             logger.warning("Trajectory exceeded the context window, truncating the oldest tool call information.")
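
With this rewrite, each ReAct step records thought_{i}, tool_calls_{i} (the stringified ToolCalls), and observation_{i} (a list with one entry per executed call), replacing the old thought/tool_name/tool_args/observation quartet. A hedged end-to-end sketch (requires a configured LM; model, tool, and question are illustrative):

    import dspy

    def get_weather(city: str) -> str:
        return f"The weather in {city} is sunny"

    dspy.configure(lm=dspy.LM("openai/gpt-4o-mini"), adapter=dspy.JSONAdapter())
    react = dspy.ReAct("question -> answer", tools=[get_weather])
    pred = react(question="What is the weather in Paris?")
    # e.g. {"thought_0": ..., "tool_calls_0": "...", "observation_0": [...], ...}
    print(pred.trajectory)
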
diff --git a/tests/adapters/test_chat_adapter.py b/tests/adapters/test_chat_adapter.py
index e599da5858..457ba0b0ea 100644
--- a/tests/adapters/test_chat_adapter.py
+++ b/tests/adapters/test_chat_adapter.py
@@ -3,7 +3,7 @@
 import pydantic
 import pytest
-from litellm.utils import Choices, Message, ModelResponse
+from litellm.utils import ChatCompletionMessageToolCall, Choices, Function, Message, ModelResponse

 import dspy

@@ -422,3 +422,70 @@ async def test_chat_adapter_fallback_to_json_adapter_on_exception_async():
     # The parse should succeed
     result = await adapter.acall(lm, {}, signature, [], {"question": "What is the capital of France?"})
     assert result == [{"answer": "Paris"}]
+
+
+def test_chat_adapter_toolcalls_native_function_calling():
+    class MySignature(dspy.Signature):
+        question: str = dspy.InputField()
+        tools: list[dspy.Tool] = dspy.InputField()
+        answer: str = dspy.OutputField()
+        tool_calls: dspy.ToolCalls = dspy.OutputField()
+
+    def get_weather(city: str) -> str:
+        return f"The weather in {city} is sunny"
+
+    tools = [dspy.Tool(get_weather)]
+
+    adapter = dspy.ChatAdapter(use_native_function_calling=True)
+
+    # Case 1: Tool calls are present in the response, while content is None.
+    with mock.patch("litellm.completion") as mock_completion:
+        mock_completion.return_value = ModelResponse(
+            choices=[
+                Choices(
+                    finish_reason="tool_calls",
+                    index=0,
+                    message=Message(
+                        content=None,
+                        role="assistant",
+                        tool_calls=[
+                            ChatCompletionMessageToolCall(
+                                function=Function(arguments='{"city":"Paris"}', name="get_weather"),
+                                id="call_pQm8ajtSMxgA0nrzK2ivFmxG",
+                                type="function",
+                            )
+                        ],
+                    ),
+                ),
+            ],
+            model="openai/gpt-4o-mini",
+        )
+        result = adapter(
+            dspy.LM(model="openai/gpt-4o-mini", cache=False),
+            {},
+            MySignature,
+            [],
+            {"question": "What is the weather in Paris?", "tools": tools},
+        )
+
+    assert result[0]["tool_calls"] == dspy.ToolCalls(
+        tool_calls=[dspy.ToolCalls.ToolCall(name="get_weather", args={"city": "Paris"})]
+    )
+    # `answer` is not present, so we set it to None
+    assert result[0]["answer"] is None
+
+    # Case 2: Tool calls are not present in the response, while content is present.
+    with mock.patch("litellm.completion") as mock_completion:
+        mock_completion.return_value = ModelResponse(
+            choices=[Choices(message=Message(content="[[ ## answer ## ]]\nParis"))],
+            model="openai/gpt-4o-mini",
+        )
+        result = adapter(
+            dspy.LM(model="openai/gpt-4o-mini", cache=False),
+            {},
+            MySignature,
+            [],
+            {"question": "What is the weather in Paris?", "tools": tools},
+        )
+    assert result[0]["answer"] == "Paris"
+    assert result[0]["tool_calls"] is None
diff --git a/tests/adapters/test_json_adapter.py b/tests/adapters/test_json_adapter.py
index eb90a754d1..d58e5f9efc 100644
--- a/tests/adapters/test_json_adapter.py
+++ b/tests/adapters/test_json_adapter.py
@@ -2,7 +2,7 @@
 import pydantic
 import pytest
-from litellm.utils import Choices, Message, ModelResponse
+from litellm.utils import ChatCompletionMessageToolCall, Choices, Function, Message, ModelResponse

 import dspy

@@ -650,3 +650,102 @@ class TestSignature(dspy.Signature):
         await program.acall(question="Dummy question!")

     assert "ValueError!" in str(error.value)
+
+
+def test_json_adapter_toolcalls_native_function_calling():
+    class MySignature(dspy.Signature):
+        question: str = dspy.InputField()
+        tools: list[dspy.Tool] = dspy.InputField()
+        answer: str = dspy.OutputField()
+        tool_calls: dspy.ToolCalls = dspy.OutputField()
+
+    def get_weather(city: str) -> str:
+        return f"The weather in {city} is sunny"
+
+    tools = [dspy.Tool(get_weather)]
+
+    adapter = dspy.JSONAdapter(use_native_function_calling=True)
+
+    # Case 1: Tool calls are present in the response, while content is None.
+    with mock.patch("litellm.completion") as mock_completion:
+        mock_completion.return_value = ModelResponse(
+            choices=[
+                Choices(
+                    finish_reason="tool_calls",
+                    index=0,
+                    message=Message(
+                        content=None,
+                        role="assistant",
+                        tool_calls=[
+                            ChatCompletionMessageToolCall(
+                                function=Function(arguments='{"city":"Paris"}', name="get_weather"),
+                                id="call_pQm8ajtSMxgA0nrzK2ivFmxG",
+                                type="function",
+                            )
+                        ],
+                    ),
+                ),
+            ],
+            model="openai/gpt-4o-mini",
+        )
+        result = adapter(
+            dspy.LM(model="openai/gpt-4o-mini", cache=False),
+            {},
+            MySignature,
+            [],
+            {"question": "What is the weather in Paris?", "tools": tools},
+        )
+
+    assert result[0]["tool_calls"] == dspy.ToolCalls(
+        tool_calls=[dspy.ToolCalls.ToolCall(name="get_weather", args={"city": "Paris"})]
+    )
+    # `answer` is not present, so we set it to None
+    assert result[0]["answer"] is None
+
+    # Case 2: Tool calls are not present in the response, while content is present.
+    with mock.patch("litellm.completion") as mock_completion:
+        mock_completion.return_value = ModelResponse(
+            choices=[Choices(message=Message(content="{'answer': 'Paris'}"))],
+            model="openai/gpt-4o-mini",
+        )
+        result = adapter(
+            dspy.LM(model="openai/gpt-4o-mini", cache=False),
+            {},
+            MySignature,
+            [],
+            {"question": "What is the weather in Paris?", "tools": tools},
+        )
+    assert result[0]["answer"] == "Paris"
+    assert result[0]["tool_calls"] is None
+
+
+def test_json_adapter_toolcalls_no_native_function_calling():
+    class MySignature(dspy.Signature):
+        question: str = dspy.InputField()
+        tools: list[dspy.Tool] = dspy.InputField()
+        answer: str = dspy.OutputField()
+        tool_calls: dspy.ToolCalls = dspy.OutputField()
+
+    def get_weather(city: str) -> str:
+        return f"The weather in {city} is sunny"
+
+    tools = [dspy.Tool(get_weather)]
+
+    # Patch _get_structured_outputs_response_format to track calls
+    with mock.patch("dspy.adapters.json_adapter._get_structured_outputs_response_format") as mock_structured:
+        # Patch litellm.completion to return a dummy response
+        with mock.patch("litellm.completion") as mock_completion:
+            mock_completion.return_value = ModelResponse(
+                choices=[Choices(message=Message(content="{'answer': 'sunny', 'tool_calls': {'tool_calls': []}}"))],
+                model="openai/gpt-4o-mini",
+            )
+            adapter = dspy.JSONAdapter(use_native_function_calling=False)
+            lm = dspy.LM(model="openai/gpt-4o-mini", cache=False)
+            adapter(lm, {}, MySignature, [], {"question": "What is the weather in Tokyo?", "tools": tools})
+
+    # _get_structured_outputs_response_format is not called because, without native function calling,
+    # JSONAdapter falls back to JSON mode for stable output quality.
+    mock_structured.assert_not_called()
+    mock_completion.assert_called_once()
+    _, call_kwargs = mock_completion.call_args
+    assert call_kwargs["response_format"] == {"type": "json_object"}
diff --git a/tests/predict/test_react.py b/tests/predict/test_react.py
index 4ae1b6f52b..3ee5d2df60 100644
--- a/tests/predict/test_react.py
+++ b/tests/predict/test_react.py
@@ -5,6 +5,7 @@
 from pydantic import BaseModel

 import dspy
+from dspy.adapters.types.tool import ToolCalls
 from dspy.utils.dummies import DummyLM


@@ -30,14 +31,20 @@ class InvitationSignature(dspy.Signature):
         [
             {
                 "next_thought": "I need to write an invitation letter for Alice to the Science Fair event.",
-                "next_tool_name": "write_invitation_letter",
-                "next_tool_args": {
-                    "participant_name": "Alice",
-                    "event_info": {
-                        "name": "Science Fair",
-                        "date": "Friday",
-                        "participants": {"Alice": "female", "Bob": "male"},
-                    },
+                "next_tool_calls": {
+                    "tool_calls": [
+                        {
+                            "name": "write_invitation_letter",
+                            "args": {
+                                "participant_name": "Alice",
+                                "event_info": {
+                                    "name": "Science Fair",
+                                    "date": "Friday",
+                                    "participants": {"Alice": "female", "Bob": "male"},
+                                },
+                            },
+                        }
+                    ]
                 },
             },
             {
@@ -45,8 +52,7 @@ class InvitationSignature(dspy.Signature):
                     "I have successfully written the invitation letter for Alice to the Science Fair. Now "
                     "I can finish the task."
                 ),
-                "next_tool_name": "finish",
-                "next_tool_args": {},
+                "next_tool_calls": {"tool_calls": [{"name": "finish", "args": {}}]},
             },
             {
                 "reasoning": "This is a very rigorous reasoning process, trust me bro!",
@@ -68,20 +74,27 @@ class InvitationSignature(dspy.Signature):

     expected_trajectory = {
         "thought_0": "I need to write an invitation letter for Alice to the Science Fair event.",
-        "tool_name_0": "write_invitation_letter",
-        "tool_args_0": {
-            "participant_name": "Alice",
-            "event_info": {
-                "name": "Science Fair",
-                "date": "Friday",
-                "participants": {"Alice": "female", "Bob": "male"},
-            },
-        },
-        "observation_0": "It's my honor to invite Alice to event Science Fair on Friday",
+        "tool_calls_0": str(
+            ToolCalls.from_dict_list(
+                [
+                    {
+                        "name": "write_invitation_letter",
+                        "args": {
+                            "participant_name": "Alice",
+                            "event_info": {
+                                "name": "Science Fair",
+                                "date": "Friday",
+                                "participants": {"Alice": "female", "Bob": "male"},
+                            },
+                        },
+                    }
+                ]
+            )
+        ),
+        "observation_0": ["It's my honor to invite Alice to event Science Fair on Friday"],
         "thought_1": "I have successfully written the invitation letter for Alice to the Science Fair. Now I can finish the task.",
-        "tool_name_1": "finish",
-        "tool_args_1": {},
-        "observation_1": "Completed.",
+        "tool_calls_1": str(ToolCalls.from_dict_list([{"name": "finish", "args": {}}])),
+        "observation_1": ["Completed."],
     }
     assert outputs.trajectory == expected_trajectory

@@ -94,8 +107,14 @@ def foo(a, b):
     react = dspy.ReAct("a, b -> c:int", tools=[foo])
     lm = DummyLM(
         [
-            {"next_thought": "I need to add two numbers.", "next_tool_name": "foo", "next_tool_args": {"a": 1, "b": 2}},
-            {"next_thought": "I have the sum, now I can finish.", "next_tool_name": "finish", "next_tool_args": {}},
+            {
+                "next_thought": "I need to add two numbers.",
+                "next_tool_calls": {"tool_calls": [{"name": "foo", "args": {"a": 1, "b": 2}}]},
+            },
+            {
+                "next_thought": "I have the sum, now I can finish.",
+                "next_tool_calls": {"tool_calls": [{"name": "finish", "args": {}}]},
+            },
             {"reasoning": "I added the numbers successfully", "c": 3},
         ]
     )
@@ -104,16 +123,20 @@ def foo(a, b):

     expected_trajectory = {
         "thought_0": "I need to add two numbers.",
-        "tool_name_0": "foo",
-        "tool_args_0": {
-            "a": 1,
-            "b": 2,
-        },
-        "observation_0": 3,
+        "tool_calls_0": str(
+            ToolCalls.from_dict_list(
+                [
+                    {
+                        "name": "foo",
+                        "args": {"a": 1, "b": 2},
+                    }
+                ]
+            )
+        ),
+        "observation_0": ["3"],
         "thought_1": "I have the sum, now I can finish.",
-        "tool_name_1": "finish",
-        "tool_args_1": {},
-        "observation_1": "Completed.",
+        "tool_calls_1": str(ToolCalls.from_dict_list([{"name": "finish", "args": {}}])),
+        "observation_1": ["Completed."],
     }
     assert outputs.trajectory == expected_trajectory

@@ -137,15 +160,19 @@ def mock_react(**kwargs):
             # First 2 calls use the echo tool
             return dspy.Prediction(
                 next_thought=f"Thought {call_count}",
-                next_tool_name="echo",
-                next_tool_args={"text": f"Text {call_count}"},
+                next_tool_calls=ToolCalls(
+                    tool_calls=[ToolCalls.ToolCall(name="echo", args={"text": f"Text {call_count}"})]
+                ),
             )
         elif call_count == 3:
             # The 3rd call raises context window exceeded error
             raise litellm.ContextWindowExceededError("Context window exceeded", "dummy_model", "dummy_provider")
         else:
             # The 4th call finishes
-            return dspy.Prediction(next_thought="Final thought", next_tool_name="finish", next_tool_args={})
+            return dspy.Prediction(
+                next_thought="Final thought",
+                next_tool_calls=ToolCalls(tool_calls=[ToolCalls.ToolCall(name="finish", args={})]),
+            )

     react.react = mock_react
     react.extract = lambda **kwargs: dspy.Prediction(output_text="Final output")
@@ -170,13 +197,11 @@ def foo(a, b):
         [
             {
                 "next_thought": "I need to add two numbers.",
-                "next_tool_name": "foo",
-                "next_tool_args": {"a": 1, "b": 2},
+                "next_tool_calls": {"tool_calls": [{"name": "foo", "args": {"a": 1, "b": 2}}]},
             },
             {
                 "next_thought": "I need to add two numbers.",
-                "next_tool_name": "foo",
-                "next_tool_args": {"a": 1, "b": 2},
+                "next_tool_calls": {"tool_calls": [{"name": "foo", "args": {"a": 1, "b": 2}}]},
             },
             # (The model *would* succeed on the 3rd turn, but max_iters=2 stops earlier.)
             {"reasoning": "I added the numbers successfully", "c": 3},
@@ -190,11 +215,9 @@ def foo(a, b):
     # --- exact-match checks (thoughts + tool calls) -------------------------
     control_expected = {
         "thought_0": "I need to add two numbers.",
-        "tool_name_0": "foo",
-        "tool_args_0": {"a": 1, "b": 2},
+        "tool_calls_0": str(ToolCalls.from_dict_list([{"name": "foo", "args": {"a": 1, "b": 2}}])),
         "thought_1": "I need to add two numbers.",
-        "tool_name_1": "foo",
-        "tool_args_1": {"a": 1, "b": 2},
+        "tool_calls_1": str(ToolCalls.from_dict_list([{"name": "foo", "args": {"a": 1, "b": 2}}])),
     }
     for k, v in control_expected.items():
         assert traj[k] == v, f"{k} mismatch"

@@ -203,7 +226,7 @@ def foo(a, b):
     # We only care that each observation mentions our error string; we ignore
     # any extra traceback detail or differing prefixes.
     for i in range(2):
-        obs = traj[f"observation_{i}"]
+        obs = str(traj[f"observation_{i}"])
         assert re.search(r"\btool error\b", obs), f"unexpected observation_{i!r}: {obs}"


@@ -230,14 +253,20 @@ class InvitationSignature(dspy.Signature):
         [
             {
                 "next_thought": "I need to write an invitation letter for Alice to the Science Fair event.",
-                "next_tool_name": "write_invitation_letter",
-                "next_tool_args": {
-                    "participant_name": "Alice",
-                    "event_info": {
-                        "name": "Science Fair",
-                        "date": "Friday",
-                        "participants": {"Alice": "female", "Bob": "male"},
-                    },
+                "next_tool_calls": {
+                    "tool_calls": [
+                        {
+                            "name": "write_invitation_letter",
+                            "args": {
+                                "participant_name": "Alice",
+                                "event_info": {
+                                    "name": "Science Fair",
+                                    "date": "Friday",
+                                    "participants": {"Alice": "female", "Bob": "male"},
+                                },
+                            },
+                        }
+                    ]
                 },
             },
             {
@@ -245,8 +274,7 @@ class InvitationSignature(dspy.Signature):
                     "I have successfully written the invitation letter for Alice to the Science Fair. Now "
                     "I can finish the task."
                 ),
-                "next_tool_name": "finish",
-                "next_tool_args": {},
+                "next_tool_calls": {"tool_calls": [{"name": "finish", "args": {}}]},
             },
             {
                 "reasoning": "This is a very rigorous reasoning process, trust me bro!",
@@ -267,20 +295,27 @@ class InvitationSignature(dspy.Signature):

     expected_trajectory = {
         "thought_0": "I need to write an invitation letter for Alice to the Science Fair event.",
-        "tool_name_0": "write_invitation_letter",
-        "tool_args_0": {
-            "participant_name": "Alice",
-            "event_info": {
-                "name": "Science Fair",
-                "date": "Friday",
-                "participants": {"Alice": "female", "Bob": "male"},
-            },
-        },
-        "observation_0": "It's my honor to invite Alice to event Science Fair on Friday",
+        "tool_calls_0": str(
+            ToolCalls.from_dict_list(
+                [
+                    {
+                        "name": "write_invitation_letter",
+                        "args": {
+                            "participant_name": "Alice",
+                            "event_info": {
+                                "name": "Science Fair",
+                                "date": "Friday",
+                                "participants": {"Alice": "female", "Bob": "male"},
+                            },
+                        },
+                    }
+                ]
+            )
+        ),
+        "observation_0": ["It's my honor to invite Alice to event Science Fair on Friday"],
         "thought_1": "I have successfully written the invitation letter for Alice to the Science Fair. Now I can finish the task.",
-        "tool_name_1": "finish",
-        "tool_args_1": {},
-        "observation_1": "Completed.",
+        "tool_calls_1": str(ToolCalls.from_dict_list([{"name": "finish", "args": {}}])),
+        "observation_1": ["Completed."],
     }
     assert outputs.trajectory == expected_trajectory

@@ -296,13 +331,11 @@ async def foo(a, b):
         [
             {
                 "next_thought": "I need to add two numbers.",
-                "next_tool_name": "foo",
-                "next_tool_args": {"a": 1, "b": 2},
+                "next_tool_calls": {"tool_calls": [{"name": "foo", "args": {"a": 1, "b": 2}}]},
             },
             {
                 "next_thought": "I need to add two numbers.",
-                "next_tool_name": "foo",
-                "next_tool_args": {"a": 1, "b": 2},
+                "next_tool_calls": {"tool_calls": [{"name": "foo", "args": {"a": 1, "b": 2}}]},
             },
             # (The model *would* succeed on the 3rd turn, but max_iters=2 stops earlier.)
             {"reasoning": "I added the numbers successfully", "c": 3},
@@ -315,11 +348,9 @@ async def foo(a, b):
     # Exact-match checks (thoughts + tool calls)
     control_expected = {
         "thought_0": "I need to add two numbers.",
-        "tool_name_0": "foo",
-        "tool_args_0": {"a": 1, "b": 2},
+        "tool_calls_0": str(ToolCalls.from_dict_list([{"name": "foo", "args": {"a": 1, "b": 2}}])),
         "thought_1": "I need to add two numbers.",
-        "tool_name_1": "foo",
-        "tool_args_1": {"a": 1, "b": 2},
+        "tool_calls_1": str(ToolCalls.from_dict_list([{"name": "foo", "args": {"a": 1, "b": 2}}])),
     }
     for k, v in control_expected.items():
         assert traj[k] == v, f"{k} mismatch"

@@ -328,5 +359,5 @@ async def foo(a, b):
     # We only care that each observation mentions our error string; we ignore
     # any extra traceback detail or differing prefixes.
     for i in range(2):
-        obs = traj[f"observation_{i}"]
+        obs = str(traj[f"observation_{i}"])
         assert re.search(r"\btool error\b", obs), f"unexpected observation_{i!r}: {obs}"
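
Condensed from the tests above, a local sketch of driving the new next_tool_calls format with DummyLM (signature string, tool, and numbers are illustrative):

    import dspy
    from dspy.utils.dummies import DummyLM

    def add(a: int, b: int) -> int:
        return a + b

    lm = DummyLM(
        [
            {
                "next_thought": "Add the numbers.",
                "next_tool_calls": {"tool_calls": [{"name": "add", "args": {"a": 1, "b": 2}}]},
            },
            {
                "next_thought": "Done.",
                "next_tool_calls": {"tool_calls": [{"name": "finish", "args": {}}]},
            },
            {"reasoning": "Added.", "c": 3},
        ]
    )
    dspy.settings.configure(lm=lm, adapter=dspy.ChatAdapter())
    react = dspy.ReAct("a, b -> c: int", tools=[add])
    print(react(a=1, b=2).trajectory)
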