From f3b72661d2d9fa16768e5a300897fb52178a8e3d Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Tue, 24 Jun 2025 13:59:12 -0700 Subject: [PATCH 01/19] init --- dspy/adapters/base.py | 7 +++++++ dspy/predict/react.py | 35 +++++++++++++++++++---------------- 2 files changed, 26 insertions(+), 16 deletions(-) diff --git a/dspy/adapters/base.py b/dspy/adapters/base.py index a8cced8827..0f726df049 100644 --- a/dspy/adapters/base.py +++ b/dspy/adapters/base.py @@ -56,6 +56,10 @@ def _call_preprocess( lm_kwargs["tools"] = litellm_tools + import pdb + + pdb.set_trace() + signature_for_native_function_calling = signature.delete(tool_call_output_field_name) return signature_for_native_function_calling @@ -88,6 +92,9 @@ def _call_postprocess( for field_name in signature.output_fields.keys(): value[field_name] = None + import pdb + + pdb.set_trace() if tool_calls and tool_call_output_field_name: tool_calls = [ { diff --git a/dspy/predict/react.py b/dspy/predict/react.py index 212f7b2161..8887f03d8d 100644 --- a/dspy/predict/react.py +++ b/dspy/predict/react.py @@ -1,10 +1,10 @@ import logging -from typing import TYPE_CHECKING, Any, Callable, Literal, Type +from typing import TYPE_CHECKING, Any, Callable, Type from litellm import ContextWindowExceededError import dspy -from dspy.adapters.types.tool import Tool +from dspy.adapters.types.tool import Tool, ToolCalls from dspy.primitives.module import Module from dspy.signatures.signature import ensure_signature @@ -53,10 +53,10 @@ def get_weather(city: str) -> str: [ f"You are an Agent. In each episode, you will be given the fields {inputs} as input. And you can see your past trajectory so far.", f"Your goal is to use one or more of the supplied tools to collect any necessary information for producing {outputs}.\n", - "To do this, you will interleave next_thought, next_tool_name, and next_tool_args in each turn, and also when finishing the task.", + "To do this, you will interleave next_thought and tool_calls in each turn, and also when finishing the task.", "After each tool call, you receive a resulting observation, which gets appended to your trajectory.\n", "When writing next_thought, you may reason about the current situation and plan for future steps.", - "When selecting the next_tool_name and its next_tool_args, the tool must be one of:\n", + "When selecting the tool_calls, the tool must be one of the tool provided in tools\n", ] ) @@ -67,16 +67,16 @@ def get_weather(city: str) -> str: args={}, ) - for idx, tool in enumerate(tools.values()): - instr.append(f"({idx + 1}) {tool}") - instr.append("When providing `next_tool_args`, the value inside the field must be in JSON format") + # for idx, tool in enumerate(tools.values()): + # instr.append(f"({idx + 1}) {tool}") + # instr.append("When providing `next_tool_args`, the value inside the field must be in JSON format") react_signature = ( dspy.Signature({**signature.input_fields}, "\n".join(instr)) .append("trajectory", dspy.InputField(), type_=str) + .append("tools", dspy.InputField(), type_=list[Tool]) .append("next_thought", dspy.OutputField(), type_=str) - .append("next_tool_name", dspy.OutputField(), type_=Literal[tuple(tools.keys())]) - .append("next_tool_args", dspy.OutputField(), type_=dict[str, Any]) + .append("tool_calls", dspy.OutputField(), type_=ToolCalls) ) fallback_signature = dspy.Signature( @@ -104,15 +104,17 @@ def forward(self, **input_args): break trajectory[f"thought_{idx}"] = pred.next_thought - trajectory[f"tool_name_{idx}"] = pred.next_tool_name - trajectory[f"tool_args_{idx}"] = pred.next_tool_args + trajectory[f"tool_calls_{idx}"] = pred.tool_calls + trajectory[f"observation_{idx}"] = [] - try: - trajectory[f"observation_{idx}"] = self.tools[pred.next_tool_name](**pred.next_tool_args) - except Exception as err: - trajectory[f"observation_{idx}"] = f"Execution error in {pred.next_tool_name}: {_fmt_exc(err)}" + tool_names = [tool_call.name for tool_call in pred.tool_calls.tool_calls] + for tool_call in pred.tool_calls.tool_calls: + try: + trajectory[f"observation_{idx}"].append(self.tools[tool_call.name](**tool_call.args)) + except Exception as err: + trajectory[f"observation_{idx}"].append(f"Execution error in {tool_call.name}: {_fmt_exc(err)}") - if pred.next_tool_name == "finish": + if "finish" in tool_names or len(tool_names) == 0: break extract = self._call_with_potential_trajectory_truncation(self.extract, trajectory, **input_args) @@ -149,6 +151,7 @@ def _call_with_potential_trajectory_truncation(self, module, trajectory, **input return module( **input_args, trajectory=self._format_trajectory(trajectory), + tools=list(self.tools.values()), ) except ContextWindowExceededError: logger.warning("Trajectory exceeded the context window, truncating the oldest tool call information.") From 4c01078fc3cec0b13fefc90e3e7474f4b4af6008 Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Tue, 24 Jun 2025 16:23:13 -0700 Subject: [PATCH 02/19] increment --- dspy/adapters/base.py | 11 +++-------- dspy/clients/lm.py | 9 +++++---- dspy/predict/react.py | 5 +++-- 3 files changed, 11 insertions(+), 14 deletions(-) diff --git a/dspy/adapters/base.py b/dspy/adapters/base.py index 0f726df049..5bf8b4a146 100644 --- a/dspy/adapters/base.py +++ b/dspy/adapters/base.py @@ -56,11 +56,10 @@ def _call_preprocess( lm_kwargs["tools"] = litellm_tools - import pdb - - pdb.set_trace() - signature_for_native_function_calling = signature.delete(tool_call_output_field_name) + signature_for_native_function_calling = signature_for_native_function_calling.delete( + tool_call_input_field_name + ) return signature_for_native_function_calling @@ -91,10 +90,6 @@ def _call_postprocess( value = {} for field_name in signature.output_fields.keys(): value[field_name] = None - - import pdb - - pdb.set_trace() if tool_calls and tool_call_output_field_name: tool_calls = [ { diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py index 228d78b63a..a29359bc29 100644 --- a/dspy/clients/lm.py +++ b/dspy/clients/lm.py @@ -14,7 +14,7 @@ from dspy.clients.provider import Provider, ReinforceJob, TrainingJob from dspy.clients.utils_finetune import TrainDataFormat from dspy.dsp.utils.settings import settings -from dspy.utils.callback import BaseCallback +from dspy.utils.callback import BaseCallback, with_callbacks from .base_lm import BaseLM @@ -83,9 +83,9 @@ def __init__( if model_pattern: # Handle OpenAI reasoning models (o1, o3) - assert ( - max_tokens >= 20_000 and temperature == 1.0 - ), "OpenAI's reasoning models require passing temperature=1.0 and max_tokens >= 20_000 to `dspy.LM(...)`" + assert max_tokens >= 20_000 and temperature == 1.0, ( + "OpenAI's reasoning models require passing temperature=1.0 and max_tokens >= 20_000 to `dspy.LM(...)`" + ) self.kwargs = dict(temperature=temperature, max_completion_tokens=max_tokens, **kwargs) else: self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs) @@ -113,6 +113,7 @@ def _get_cached_completion_fn(self, completion_fn, cache, enable_memory_cache): return completion_fn, litellm_cache_args + @with_callbacks def forward(self, prompt=None, messages=None, **kwargs): # Build the request. cache = kwargs.pop("cache", self.cache) diff --git a/dspy/predict/react.py b/dspy/predict/react.py index 8887f03d8d..852589b7dc 100644 --- a/dspy/predict/react.py +++ b/dspy/predict/react.py @@ -53,10 +53,11 @@ def get_weather(city: str) -> str: [ f"You are an Agent. In each episode, you will be given the fields {inputs} as input. And you can see your past trajectory so far.", f"Your goal is to use one or more of the supplied tools to collect any necessary information for producing {outputs}.\n", - "To do this, you will interleave next_thought and tool_calls in each turn, and also when finishing the task.", + "To do this, you will reason about the current situation to decide which tool to call next along with the arguments to pass to the tool.", + "Or finish the task.\n", "After each tool call, you receive a resulting observation, which gets appended to your trajectory.\n", "When writing next_thought, you may reason about the current situation and plan for future steps.", - "When selecting the tool_calls, the tool must be one of the tool provided in tools\n", + "When selecting the tools to call, the tool must be one of the tools provided to you.\n", ] ) From 63cb5f97d5c70caa2c40071902e2991f69f24e64 Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Tue, 24 Jun 2025 21:19:32 -0700 Subject: [PATCH 03/19] increment --- dspy/adapters/base.py | 22 ++++++++++++++++------ dspy/adapters/json_adapter.py | 28 +++++++++++++++++++++++++--- dspy/predict/react.py | 17 ++++++++++------- 3 files changed, 51 insertions(+), 16 deletions(-) diff --git a/dspy/adapters/base.py b/dspy/adapters/base.py index 5bf8b4a146..e60c81adbe 100644 --- a/dspy/adapters/base.py +++ b/dspy/adapters/base.py @@ -67,12 +67,13 @@ def _call_preprocess( def _call_postprocess( self, - signature: Type[Signature], + processed_signature: Type[Signature], + original_signature: Type[Signature], outputs: list[dict[str, Any]], ) -> list[dict[str, Any]]: values = [] - tool_call_output_field_name = self._get_tool_call_output_field_name(signature) + tool_call_output_field_name = self._get_tool_call_output_field_name(original_signature) for output in outputs: output_logprobs = None @@ -85,11 +86,16 @@ def _call_postprocess( tool_calls = output.get("tool_calls") if text: - value = self.parse(signature, text) + value = self.parse(processed_signature, text) + for field_name in original_signature.output_fields.keys(): + if field_name not in value: + # We need to set the field not present in the processed signature to None. + value[field_name] = None else: value = {} - for field_name in signature.output_fields.keys(): + for field_name in original_signature.output_fields.keys(): value[field_name] = None + if tool_calls and tool_call_output_field_name: tool_calls = [ { @@ -119,7 +125,7 @@ def __call__( inputs = self.format(processed_signature, demos, inputs) outputs = lm(messages=inputs, **lm_kwargs) - return self._call_postprocess(signature, outputs) + return self._call_postprocess(processed_signature, signature, outputs) async def acall( self, @@ -133,7 +139,11 @@ async def acall( inputs = self.format(processed_signature, demos, inputs) outputs = await lm.acall(messages=inputs, **lm_kwargs) - return self._call_postprocess(signature, outputs) + return self._call_postprocess( + processed_signature=processed_signature, + original_signature=signature, + outputs=outputs, + ) def format( self, diff --git a/dspy/adapters/json_adapter.py b/dspy/adapters/json_adapter.py index 9b14691c7c..e7d1ceefc4 100644 --- a/dspy/adapters/json_adapter.py +++ b/dspy/adapters/json_adapter.py @@ -9,6 +9,7 @@ from pydantic.fields import FieldInfo from dspy.adapters.chat_adapter import ChatAdapter, FieldInfoWithName +from dspy.adapters.types.tool import ToolCalls from dspy.adapters.utils import ( format_field_value, get_annotation_name, @@ -37,6 +38,10 @@ def _has_open_ended_mapping(signature: SignatureMeta) -> bool: class JSONAdapter(ChatAdapter): + def __init__(self, enable_structured_output: bool = True, use_native_function_calling: bool = False, **kwargs): + super().__init__(**kwargs) + self.use_native_function_calling = use_native_function_calling + def _json_adapter_call_common(self, lm, lm_kwargs, signature, demos, inputs, call_fn): """Common call logic to be used for both sync and async calls.""" provider = lm.model.split("/", 1)[0] or "openai" @@ -62,7 +67,9 @@ def __call__( return result try: - structured_output_model = _get_structured_outputs_response_format(signature) + structured_output_model = _get_structured_outputs_response_format( + signature, self.use_native_function_calling + ) lm_kwargs["response_format"] = structured_output_model return super().__call__(lm, lm_kwargs, signature, demos, inputs) except Exception: @@ -99,7 +106,13 @@ def _call_preprocess( inputs: dict[str, Any], use_native_function_calling: bool = True, ) -> dict[str, Any]: - return super()._call_preprocess(lm, lm_kwargs, signature, inputs, use_native_function_calling) + return super()._call_preprocess( + lm, + lm_kwargs, + signature, + inputs, + use_native_function_calling=self.use_native_function_calling, + ) def format_field_structure(self, signature: Type[Signature]) -> str: parts = [] @@ -128,6 +141,9 @@ def type_info(v): else "" ) + if self.use_native_function_calling: + return None + message = "Respond with a JSON object in the following order of fields: " message += ", then ".join(f"`{f}`{type_info(v)}" for f, v in signature.output_fields.items()) message += "." @@ -206,7 +222,10 @@ def format_finetune_data( raise NotImplementedError -def _get_structured_outputs_response_format(signature: SignatureMeta) -> type[pydantic.BaseModel]: +def _get_structured_outputs_response_format( + signature: SignatureMeta, + enable_native_function_calling: bool = True, +) -> type[pydantic.BaseModel]: """ Builds a Pydantic model from a DSPy signature's output_fields and ensures the generated JSON schema is compatible with OpenAI Structured Outputs (all objects have a "required" key listing every property, @@ -227,6 +246,9 @@ def _get_structured_outputs_response_format(signature: SignatureMeta) -> type[py fields = {} for name, field in signature.output_fields.items(): annotation = field.annotation + if enable_native_function_calling and annotation == ToolCalls: + # Skip ToolCalls field if native function calling is enabled. + continue default = field.default if hasattr(field, "default") else ... fields[name] = (annotation, default) diff --git a/dspy/predict/react.py b/dspy/predict/react.py index 852589b7dc..063a94ca74 100644 --- a/dspy/predict/react.py +++ b/dspy/predict/react.py @@ -54,10 +54,12 @@ def get_weather(city: str) -> str: f"You are an Agent. In each episode, you will be given the fields {inputs} as input. And you can see your past trajectory so far.", f"Your goal is to use one or more of the supplied tools to collect any necessary information for producing {outputs}.\n", "To do this, you will reason about the current situation to decide which tool to call next along with the arguments to pass to the tool.", - "Or finish the task.\n", + "Or finish the task.\n\n", "After each tool call, you receive a resulting observation, which gets appended to your trajectory.\n", "When writing next_thought, you may reason about the current situation and plan for future steps.", - "When selecting the tools to call, the tool must be one of the tools provided to you.\n", + "When selecting the tools to call, the tool must be one of the tools provided to you.\n\n", + "!!!If you decide a tool call is required, then you MUST IGNORE the output fields requirements, and just return the tool", + "call information.\n\n", ] ) @@ -68,8 +70,8 @@ def get_weather(city: str) -> str: args={}, ) - # for idx, tool in enumerate(tools.values()): - # instr.append(f"({idx + 1}) {tool}") + for idx, tool in enumerate(tools.values()): + instr.append(f"({idx + 1}) {tool}") # instr.append("When providing `next_tool_args`, the value inside the field must be in JSON format") react_signature = ( @@ -108,10 +110,11 @@ def forward(self, **input_args): trajectory[f"tool_calls_{idx}"] = pred.tool_calls trajectory[f"observation_{idx}"] = [] - tool_names = [tool_call.name for tool_call in pred.tool_calls.tool_calls] - for tool_call in pred.tool_calls.tool_calls: + tool_calls = [] if pred.tool_calls is None else pred.tool_calls.tool_calls + tool_names = [tool_call.name for tool_call in tool_calls] + for tool_call in tool_calls: try: - trajectory[f"observation_{idx}"].append(self.tools[tool_call.name](**tool_call.args)) + trajectory[f"observation_{idx}"].append(f"{self.tools[tool_call.name](**tool_call.args)}") except Exception as err: trajectory[f"observation_{idx}"].append(f"Execution error in {tool_call.name}: {_fmt_exc(err)}") From 4779b18340db6312fbbc9710a73d77d35975e927 Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Wed, 25 Jun 2025 10:48:17 -0700 Subject: [PATCH 04/19] give it a pause --- dspy/adapters/json_adapter.py | 3 ++- dspy/predict/react.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/dspy/adapters/json_adapter.py b/dspy/adapters/json_adapter.py index e7d1ceefc4..d95deb3096 100644 --- a/dspy/adapters/json_adapter.py +++ b/dspy/adapters/json_adapter.py @@ -50,7 +50,8 @@ def _json_adapter_call_common(self, lm, lm_kwargs, signature, demos, inputs, cal if not params or "response_format" not in params: return call_fn(lm, lm_kwargs, signature, demos, inputs) - if _has_open_ended_mapping(signature): + has_tool_calls = any(field.annotation == ToolCalls for field in signature.output_fields.values()) + if _has_open_ended_mapping(signature) or (self.use_native_function_calling and has_tool_calls): lm_kwargs["response_format"] = {"type": "json_object"} return call_fn(lm, lm_kwargs, signature, demos, inputs) diff --git a/dspy/predict/react.py b/dspy/predict/react.py index 063a94ca74..1f41904c6a 100644 --- a/dspy/predict/react.py +++ b/dspy/predict/react.py @@ -106,8 +106,8 @@ def forward(self, **input_args): logger.warning(f"Ending the trajectory: Agent failed to select a valid tool: {_fmt_exc(err)}") break - trajectory[f"thought_{idx}"] = pred.next_thought - trajectory[f"tool_calls_{idx}"] = pred.tool_calls + # trajectory[f"thought_{idx}"] = pred.next_thought + trajectory[f"tool_calls_{idx}"] = str(pred.tool_calls) trajectory[f"observation_{idx}"] = [] tool_calls = [] if pred.tool_calls is None else pred.tool_calls.tool_calls From f56afeb19aa7d43fa9819fd37f11601410b07b8b Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Fri, 27 Jun 2025 19:48:04 -0700 Subject: [PATCH 05/19] increment --- dspy/predict/react.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/dspy/predict/react.py b/dspy/predict/react.py index 1f41904c6a..d74266ab70 100644 --- a/dspy/predict/react.py +++ b/dspy/predict/react.py @@ -79,7 +79,7 @@ def get_weather(city: str) -> str: .append("trajectory", dspy.InputField(), type_=str) .append("tools", dspy.InputField(), type_=list[Tool]) .append("next_thought", dspy.OutputField(), type_=str) - .append("tool_calls", dspy.OutputField(), type_=ToolCalls) + .append("next_tool_calls", dspy.OutputField(), type_=ToolCalls) ) fallback_signature = dspy.Signature( @@ -106,11 +106,11 @@ def forward(self, **input_args): logger.warning(f"Ending the trajectory: Agent failed to select a valid tool: {_fmt_exc(err)}") break - # trajectory[f"thought_{idx}"] = pred.next_thought - trajectory[f"tool_calls_{idx}"] = str(pred.tool_calls) + trajectory[f"thought_{idx}"] = pred.next_thought + trajectory[f"tool_calls_{idx}"] = str(pred.next_tool_calls) trajectory[f"observation_{idx}"] = [] - tool_calls = [] if pred.tool_calls is None else pred.tool_calls.tool_calls + tool_calls = [] if pred.next_tool_calls is None else pred.next_tool_calls.tool_calls tool_names = [tool_call.name for tool_call in tool_calls] for tool_call in tool_calls: try: From d00e7c835ac3b63a9eb358e1c64186b13e704ad2 Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Mon, 30 Jun 2025 14:14:03 -0700 Subject: [PATCH 06/19] increment --- dspy/adapters/types/tool.py | 10 ++++++++-- dspy/clients/lm.py | 37 ++++++++++++++++++------------------- dspy/predict/react.py | 9 +++------ 3 files changed, 29 insertions(+), 27 deletions(-) diff --git a/dspy/adapters/types/tool.py b/dspy/adapters/types/tool.py index 6bddbee9a8..0256363a5c 100644 --- a/dspy/adapters/types/tool.py +++ b/dspy/adapters/types/tool.py @@ -2,6 +2,7 @@ import inspect from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Type, get_origin, get_type_hints +import pydantic from jsonschema import ValidationError, validate from pydantic import BaseModel, TypeAdapter, create_model @@ -288,7 +289,7 @@ def from_dict_list(cls, tool_calls_dicts: list[dict[str, Any]]) -> "ToolCalls": def description(cls) -> str: return ( "Tool calls information, including the name of the tools and the arguments to be passed to it. " - "Arguments must be provided in JSON format." + '`args` must be provided in JSON format, e.g., `{"arg1": "value1", "arg2": "value2"}`' ) def format(self) -> list[dict[str, Any]]: @@ -303,11 +304,16 @@ def format(self) -> list[dict[str, Any]]: "name": tool_call.name, "arguments": tool_call.args, }, - } for tool_call in self.tool_calls + } + for tool_call in self.tool_calls ], } ] + @pydantic.model_serializer() + def serialize_model(self): + return self.format() + def _resolve_json_schema_reference(schema: dict) -> dict: """Recursively resolve json model schema, expanding all references.""" diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py index a29359bc29..e79a731681 100644 --- a/dspy/clients/lm.py +++ b/dspy/clients/lm.py @@ -2,7 +2,7 @@ import os import re import threading -from typing import Any, Dict, List, Literal, Optional, cast +from typing import Any, Dict, List, Literal, cast import litellm from anyio.streams.memory import MemoryObjectSendStream @@ -14,7 +14,7 @@ from dspy.clients.provider import Provider, ReinforceJob, TrainingJob from dspy.clients.utils_finetune import TrainDataFormat from dspy.dsp.utils.settings import settings -from dspy.utils.callback import BaseCallback, with_callbacks +from dspy.utils.callback import BaseCallback from .base_lm import BaseLM @@ -34,12 +34,12 @@ def __init__( max_tokens: int = 4000, cache: bool = True, cache_in_memory: bool = True, - callbacks: Optional[List[BaseCallback]] = None, + callbacks: List[BaseCallback] | None = None, num_retries: int = 3, - provider: Optional[Provider] = None, - finetuning_model: Optional[str] = None, - launch_kwargs: Optional[dict[str, Any]] = None, - train_kwargs: Optional[dict[str, Any]] = None, + provider: Provider | None = None, + finetuning_model: str | None = None, + launch_kwargs: dict[str, Any] | None = None, + train_kwargs: dict[str, Any] | None = None, **kwargs, ): """ @@ -83,9 +83,9 @@ def __init__( if model_pattern: # Handle OpenAI reasoning models (o1, o3) - assert max_tokens >= 20_000 and temperature == 1.0, ( - "OpenAI's reasoning models require passing temperature=1.0 and max_tokens >= 20_000 to `dspy.LM(...)`" - ) + assert ( + max_tokens >= 20_000 and temperature == 1.0 + ), "OpenAI's reasoning models require passing temperature=1.0 and max_tokens >= 20_000 to `dspy.LM(...)`" self.kwargs = dict(temperature=temperature, max_completion_tokens=max_tokens, **kwargs) else: self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs) @@ -113,7 +113,6 @@ def _get_cached_completion_fn(self, completion_fn, cache, enable_memory_cache): return completion_fn, litellm_cache_args - @with_callbacks def forward(self, prompt=None, messages=None, **kwargs): # Build the request. cache = kwargs.pop("cache", self.cache) @@ -174,17 +173,17 @@ async def aforward(self, prompt=None, messages=None, **kwargs): settings.usage_tracker.add_usage(self.model, dict(results.usage)) return results - def launch(self, launch_kwargs: Optional[Dict[str, Any]] = None): + def launch(self, launch_kwargs: Dict[str, Any] | None = None): self.provider.launch(self, launch_kwargs) - def kill(self, launch_kwargs: Optional[Dict[str, Any]] = None): + def kill(self, launch_kwargs: Dict[str, Any] | None = None): self.provider.kill(self, launch_kwargs) def finetune( self, train_data: List[Dict[str, Any]], - train_data_format: Optional[TrainDataFormat], - train_kwargs: Optional[Dict[str, Any]] = None, + train_data_format: TrainDataFormat | None, + train_kwargs: Dict[str, Any] | None = None, ) -> TrainingJob: from dspy import settings as settings @@ -301,7 +300,7 @@ async def async_stream_completion(): return async_stream_completion -def litellm_completion(request: Dict[str, Any], num_retries: int, cache: Optional[Dict[str, Any]] = None): +def litellm_completion(request: Dict[str, Any], num_retries: int, cache: Dict[str, Any] | None = None): cache = cache or {"no-cache": True, "no-store": True} stream_completion = _get_stream_completion_fn(request, cache, sync=True) if stream_completion is None: @@ -315,7 +314,7 @@ def litellm_completion(request: Dict[str, Any], num_retries: int, cache: Optiona return stream_completion() -def litellm_text_completion(request: Dict[str, Any], num_retries: int, cache: Optional[Dict[str, Any]] = None): +def litellm_text_completion(request: Dict[str, Any], num_retries: int, cache: Dict[str, Any] | None = None): cache = cache or {"no-cache": True, "no-store": True} # Extract the provider and model from the model string. # TODO: Not all the models are in the format of "provider/model" @@ -341,7 +340,7 @@ def litellm_text_completion(request: Dict[str, Any], num_retries: int, cache: Op ) -async def alitellm_completion(request: Dict[str, Any], num_retries: int, cache: Optional[Dict[str, Any]] = None): +async def alitellm_completion(request: Dict[str, Any], num_retries: int, cache: Dict[str, Any] | None = None): cache = cache or {"no-cache": True, "no-store": True} stream_completion = _get_stream_completion_fn(request, cache, sync=False) if stream_completion is None: @@ -355,7 +354,7 @@ async def alitellm_completion(request: Dict[str, Any], num_retries: int, cache: return await stream_completion() -async def alitellm_text_completion(request: Dict[str, Any], num_retries: int, cache: Optional[Dict[str, Any]] = None): +async def alitellm_text_completion(request: Dict[str, Any], num_retries: int, cache: Dict[str, Any] | None = None): cache = cache or {"no-cache": True, "no-store": True} model = request.pop("model").split("/", 1) provider, model = model[0] if len(model) > 1 else "openai", model[-1] diff --git a/dspy/predict/react.py b/dspy/predict/react.py index d74266ab70..703fec2779 100644 --- a/dspy/predict/react.py +++ b/dspy/predict/react.py @@ -53,13 +53,10 @@ def get_weather(city: str) -> str: [ f"You are an Agent. In each episode, you will be given the fields {inputs} as input. And you can see your past trajectory so far.", f"Your goal is to use one or more of the supplied tools to collect any necessary information for producing {outputs}.\n", - "To do this, you will reason about the current situation to decide which tool to call next along with the arguments to pass to the tool.", - "Or finish the task.\n\n", + "To do this, you will interleave next_thought and next_tool_calls in each turn, and also when finishing the task.", "After each tool call, you receive a resulting observation, which gets appended to your trajectory.\n", "When writing next_thought, you may reason about the current situation and plan for future steps.", - "When selecting the tools to call, the tool must be one of the tools provided to you.\n\n", - "!!!If you decide a tool call is required, then you MUST IGNORE the output fields requirements, and just return the tool", - "call information.\n\n", + "When selecting the tools to call, the tool must be one of:\n", ] ) @@ -72,7 +69,7 @@ def get_weather(city: str) -> str: for idx, tool in enumerate(tools.values()): instr.append(f"({idx + 1}) {tool}") - # instr.append("When providing `next_tool_args`, the value inside the field must be in JSON format") + instr.append("When providing `next_tool_calls`, the args value inside the field must be in JSON format") react_signature = ( dspy.Signature({**signature.input_fields}, "\n".join(instr)) From fb0a4fa109c74d717ebc23a3c0836603dcfa9dc8 Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Tue, 1 Jul 2025 14:34:33 -0700 Subject: [PATCH 07/19] update test --- dspy/adapters/chat_adapter.py | 1 + dspy/predict/react.py | 18 ++-- tests/predict/test_react.py | 187 ++++++++++++++++++++-------------- 3 files changed, 121 insertions(+), 85 deletions(-) diff --git a/dspy/adapters/chat_adapter.py b/dspy/adapters/chat_adapter.py index 7e7859978a..c6375aa311 100644 --- a/dspy/adapters/chat_adapter.py +++ b/dspy/adapters/chat_adapter.py @@ -41,6 +41,7 @@ def __call__( try: return super().__call__(lm, lm_kwargs, signature, demos, inputs) except Exception as e: + raise e # fallback to JSONAdapter from dspy.adapters.json_adapter import JSONAdapter diff --git a/dspy/predict/react.py b/dspy/predict/react.py index 703fec2779..12b1c71a32 100644 --- a/dspy/predict/react.py +++ b/dspy/predict/react.py @@ -132,15 +132,19 @@ async def aforward(self, **input_args): break trajectory[f"thought_{idx}"] = pred.next_thought - trajectory[f"tool_name_{idx}"] = pred.next_tool_name - trajectory[f"tool_args_{idx}"] = pred.next_tool_args + trajectory[f"tool_calls_{idx}"] = str(pred.next_tool_calls) + trajectory[f"observation_{idx}"] = [] - try: - trajectory[f"observation_{idx}"] = await self.tools[pred.next_tool_name].acall(**pred.next_tool_args) - except Exception as err: - trajectory[f"observation_{idx}"] = f"Execution error in {pred.next_tool_name}: {_fmt_exc(err)}" + tool_calls = [] if pred.next_tool_calls is None else pred.next_tool_calls.tool_calls + tool_names = [tool_call.name for tool_call in tool_calls] + for tool_call in tool_calls: + try: + result = await self.tools[tool_call.name].acall(**tool_call.args) + trajectory[f"observation_{idx}"].append(result) + except Exception as err: + trajectory[f"observation_{idx}"].append(f"Execution error in {tool_call.name}: {_fmt_exc(err)}") - if pred.next_tool_name == "finish": + if "finish" in tool_names or len(tool_names) == 0: break extract = await self._async_call_with_potential_trajectory_truncation(self.extract, trajectory, **input_args) diff --git a/tests/predict/test_react.py b/tests/predict/test_react.py index 4ae1b6f52b..3ee5d2df60 100644 --- a/tests/predict/test_react.py +++ b/tests/predict/test_react.py @@ -5,6 +5,7 @@ from pydantic import BaseModel import dspy +from dspy.adapters.types.tool import ToolCalls from dspy.utils.dummies import DummyLM @@ -30,14 +31,20 @@ class InvitationSignature(dspy.Signature): [ { "next_thought": "I need to write an invitation letter for Alice to the Science Fair event.", - "next_tool_name": "write_invitation_letter", - "next_tool_args": { - "participant_name": "Alice", - "event_info": { - "name": "Science Fair", - "date": "Friday", - "participants": {"Alice": "female", "Bob": "male"}, - }, + "next_tool_calls": { + "tool_calls": [ + { + "name": "write_invitation_letter", + "args": { + "participant_name": "Alice", + "event_info": { + "name": "Science Fair", + "date": "Friday", + "participants": {"Alice": "female", "Bob": "male"}, + }, + }, + } + ] }, }, { @@ -45,8 +52,7 @@ class InvitationSignature(dspy.Signature): "I have successfully written the invitation letter for Alice to the Science Fair. Now " "I can finish the task." ), - "next_tool_name": "finish", - "next_tool_args": {}, + "next_tool_calls": {"tool_calls": [{"name": "finish", "args": {}}]}, }, { "reasoning": "This is a very rigorous reasoning process, trust me bro!", @@ -68,20 +74,27 @@ class InvitationSignature(dspy.Signature): expected_trajectory = { "thought_0": "I need to write an invitation letter for Alice to the Science Fair event.", - "tool_name_0": "write_invitation_letter", - "tool_args_0": { - "participant_name": "Alice", - "event_info": { - "name": "Science Fair", - "date": "Friday", - "participants": {"Alice": "female", "Bob": "male"}, - }, - }, - "observation_0": "It's my honor to invite Alice to event Science Fair on Friday", + "tool_calls_0": str( + ToolCalls.from_dict_list( + [ + { + "name": "write_invitation_letter", + "args": { + "participant_name": "Alice", + "event_info": { + "name": "Science Fair", + "date": "Friday", + "participants": {"Alice": "female", "Bob": "male"}, + }, + }, + } + ] + ) + ), + "observation_0": ["It's my honor to invite Alice to event Science Fair on Friday"], "thought_1": "I have successfully written the invitation letter for Alice to the Science Fair. Now I can finish the task.", - "tool_name_1": "finish", - "tool_args_1": {}, - "observation_1": "Completed.", + "tool_calls_1": str(ToolCalls.from_dict_list([{"name": "finish", "args": {}}])), + "observation_1": ["Completed."], } assert outputs.trajectory == expected_trajectory @@ -94,8 +107,14 @@ def foo(a, b): react = dspy.ReAct("a, b -> c:int", tools=[foo]) lm = DummyLM( [ - {"next_thought": "I need to add two numbers.", "next_tool_name": "foo", "next_tool_args": {"a": 1, "b": 2}}, - {"next_thought": "I have the sum, now I can finish.", "next_tool_name": "finish", "next_tool_args": {}}, + { + "next_thought": "I need to add two numbers.", + "next_tool_calls": {"tool_calls": [{"name": "foo", "args": {"a": 1, "b": 2}}]}, + }, + { + "next_thought": "I have the sum, now I can finish.", + "next_tool_calls": {"tool_calls": [{"name": "finish", "args": {}}]}, + }, {"reasoning": "I added the numbers successfully", "c": 3}, ] ) @@ -104,16 +123,20 @@ def foo(a, b): expected_trajectory = { "thought_0": "I need to add two numbers.", - "tool_name_0": "foo", - "tool_args_0": { - "a": 1, - "b": 2, - }, - "observation_0": 3, + "tool_calls_0": str( + ToolCalls.from_dict_list( + [ + { + "name": "foo", + "args": {"a": 1, "b": 2}, + } + ] + ) + ), + "observation_0": ["3"], "thought_1": "I have the sum, now I can finish.", - "tool_name_1": "finish", - "tool_args_1": {}, - "observation_1": "Completed.", + "tool_calls_1": str(ToolCalls.from_dict_list([{"name": "finish", "args": {}}])), + "observation_1": ["Completed."], } assert outputs.trajectory == expected_trajectory @@ -137,15 +160,19 @@ def mock_react(**kwargs): # First 2 calls use the echo tool return dspy.Prediction( next_thought=f"Thought {call_count}", - next_tool_name="echo", - next_tool_args={"text": f"Text {call_count}"}, + next_tool_calls=ToolCalls( + tool_calls=[ToolCalls.ToolCall(name="echo", args={"text": f"Text {call_count}"})] + ), ) elif call_count == 3: # The 3rd call raises context window exceeded error raise litellm.ContextWindowExceededError("Context window exceeded", "dummy_model", "dummy_provider") else: # The 4th call finishes - return dspy.Prediction(next_thought="Final thought", next_tool_name="finish", next_tool_args={}) + return dspy.Prediction( + next_thought="Final thought", + next_tool_calls=ToolCalls(tool_calls=[ToolCalls.ToolCall(name="finish", args={})]), + ) react.react = mock_react react.extract = lambda **kwargs: dspy.Prediction(output_text="Final output") @@ -170,13 +197,11 @@ def foo(a, b): [ { "next_thought": "I need to add two numbers.", - "next_tool_name": "foo", - "next_tool_args": {"a": 1, "b": 2}, + "next_tool_calls": {"tool_calls": [{"name": "foo", "args": {"a": 1, "b": 2}}]}, }, { "next_thought": "I need to add two numbers.", - "next_tool_name": "foo", - "next_tool_args": {"a": 1, "b": 2}, + "next_tool_calls": {"tool_calls": [{"name": "foo", "args": {"a": 1, "b": 2}}]}, }, # (The model *would* succeed on the 3rd turn, but max_iters=2 stops earlier.) {"reasoning": "I added the numbers successfully", "c": 3}, @@ -190,11 +215,9 @@ def foo(a, b): # --- exact-match checks (thoughts + tool calls) ------------------------- control_expected = { "thought_0": "I need to add two numbers.", - "tool_name_0": "foo", - "tool_args_0": {"a": 1, "b": 2}, + "tool_calls_0": str(ToolCalls.from_dict_list([{"name": "foo", "args": {"a": 1, "b": 2}}])), "thought_1": "I need to add two numbers.", - "tool_name_1": "foo", - "tool_args_1": {"a": 1, "b": 2}, + "tool_calls_1": str(ToolCalls.from_dict_list([{"name": "foo", "args": {"a": 1, "b": 2}}])), } for k, v in control_expected.items(): assert traj[k] == v, f"{k} mismatch" @@ -203,7 +226,7 @@ def foo(a, b): # We only care that each observation mentions our error string; we ignore # any extra traceback detail or differing prefixes. for i in range(2): - obs = traj[f"observation_{i}"] + obs = str(traj[f"observation_{i}"]) assert re.search(r"\btool error\b", obs), f"unexpected observation_{i!r}: {obs}" @@ -230,14 +253,20 @@ class InvitationSignature(dspy.Signature): [ { "next_thought": "I need to write an invitation letter for Alice to the Science Fair event.", - "next_tool_name": "write_invitation_letter", - "next_tool_args": { - "participant_name": "Alice", - "event_info": { - "name": "Science Fair", - "date": "Friday", - "participants": {"Alice": "female", "Bob": "male"}, - }, + "next_tool_calls": { + "tool_calls": [ + { + "name": "write_invitation_letter", + "args": { + "participant_name": "Alice", + "event_info": { + "name": "Science Fair", + "date": "Friday", + "participants": {"Alice": "female", "Bob": "male"}, + }, + }, + } + ] }, }, { @@ -245,8 +274,7 @@ class InvitationSignature(dspy.Signature): "I have successfully written the invitation letter for Alice to the Science Fair. Now " "I can finish the task." ), - "next_tool_name": "finish", - "next_tool_args": {}, + "next_tool_calls": {"tool_calls": [{"name": "finish", "args": {}}]}, }, { "reasoning": "This is a very rigorous reasoning process, trust me bro!", @@ -267,20 +295,27 @@ class InvitationSignature(dspy.Signature): expected_trajectory = { "thought_0": "I need to write an invitation letter for Alice to the Science Fair event.", - "tool_name_0": "write_invitation_letter", - "tool_args_0": { - "participant_name": "Alice", - "event_info": { - "name": "Science Fair", - "date": "Friday", - "participants": {"Alice": "female", "Bob": "male"}, - }, - }, - "observation_0": "It's my honor to invite Alice to event Science Fair on Friday", + "tool_calls_0": str( + ToolCalls.from_dict_list( + [ + { + "name": "write_invitation_letter", + "args": { + "participant_name": "Alice", + "event_info": { + "name": "Science Fair", + "date": "Friday", + "participants": {"Alice": "female", "Bob": "male"}, + }, + }, + } + ] + ) + ), + "observation_0": ["It's my honor to invite Alice to event Science Fair on Friday"], "thought_1": "I have successfully written the invitation letter for Alice to the Science Fair. Now I can finish the task.", - "tool_name_1": "finish", - "tool_args_1": {}, - "observation_1": "Completed.", + "tool_calls_1": str(ToolCalls.from_dict_list([{"name": "finish", "args": {}}])), + "observation_1": ["Completed."], } assert outputs.trajectory == expected_trajectory @@ -296,13 +331,11 @@ async def foo(a, b): [ { "next_thought": "I need to add two numbers.", - "next_tool_name": "foo", - "next_tool_args": {"a": 1, "b": 2}, + "next_tool_calls": {"tool_calls": [{"name": "foo", "args": {"a": 1, "b": 2}}]}, }, { "next_thought": "I need to add two numbers.", - "next_tool_name": "foo", - "next_tool_args": {"a": 1, "b": 2}, + "next_tool_calls": {"tool_calls": [{"name": "foo", "args": {"a": 1, "b": 2}}]}, }, # (The model *would* succeed on the 3rd turn, but max_iters=2 stops earlier.) {"reasoning": "I added the numbers successfully", "c": 3}, @@ -315,11 +348,9 @@ async def foo(a, b): # Exact-match checks (thoughts + tool calls) control_expected = { "thought_0": "I need to add two numbers.", - "tool_name_0": "foo", - "tool_args_0": {"a": 1, "b": 2}, + "tool_calls_0": str(ToolCalls.from_dict_list([{"name": "foo", "args": {"a": 1, "b": 2}}])), "thought_1": "I need to add two numbers.", - "tool_name_1": "foo", - "tool_args_1": {"a": 1, "b": 2}, + "tool_calls_1": str(ToolCalls.from_dict_list([{"name": "foo", "args": {"a": 1, "b": 2}}])), } for k, v in control_expected.items(): assert traj[k] == v, f"{k} mismatch" @@ -328,5 +359,5 @@ async def foo(a, b): # We only care that each observation mentions our error string; we ignore # any extra traceback detail or differing prefixes. for i in range(2): - obs = traj[f"observation_{i}"] + obs = str(traj[f"observation_{i}"]) assert re.search(r"\btool error\b", obs), f"unexpected observation_{i!r}: {obs}" From 839d78992c38acbc7e8c3837d3ab773c7dacd924 Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Tue, 1 Jul 2025 14:38:16 -0700 Subject: [PATCH 08/19] fix --- dspy/adapters/base.py | 24 ++++++++++++++++++------ dspy/adapters/json_adapter.py | 31 +++++++++++++++++++++++++++---- 2 files changed, 45 insertions(+), 10 deletions(-) diff --git a/dspy/adapters/base.py b/dspy/adapters/base.py index 1cdc1edfd9..a1a2582237 100644 --- a/dspy/adapters/base.py +++ b/dspy/adapters/base.py @@ -57,6 +57,9 @@ def _call_preprocess( lm_kwargs["tools"] = litellm_tools signature_for_native_function_calling = signature.delete(tool_call_output_field_name) + signature_for_native_function_calling = signature_for_native_function_calling.delete( + tool_call_input_field_name + ) return signature_for_native_function_calling @@ -64,12 +67,13 @@ def _call_preprocess( def _call_postprocess( self, - signature: Type[Signature], + processed_signature: Type[Signature], + original_signature: Type[Signature], outputs: list[dict[str, Any]], ) -> list[dict[str, Any]]: values = [] - tool_call_output_field_name = self._get_tool_call_output_field_name(signature) + tool_call_output_field_name = self._get_tool_call_output_field_name(original_signature) for output in outputs: output_logprobs = None @@ -82,10 +86,14 @@ def _call_postprocess( tool_calls = output.get("tool_calls") if text: - value = self.parse(signature, text) + value = self.parse(processed_signature, text) + for field_name in original_signature.output_fields.keys(): + if field_name not in value: + # We need to set the field not present in the processed signature to None. + value[field_name] = None else: value = {} - for field_name in signature.output_fields.keys(): + for field_name in original_signature.output_fields.keys(): value[field_name] = None if tool_calls and tool_call_output_field_name: @@ -117,7 +125,7 @@ def __call__( inputs = self.format(processed_signature, demos, inputs) outputs = lm(messages=inputs, **lm_kwargs) - return self._call_postprocess(signature, outputs) + return self._call_postprocess(processed_signature, signature, outputs) async def acall( self, @@ -131,7 +139,11 @@ async def acall( inputs = self.format(processed_signature, demos, inputs) outputs = await lm.acall(messages=inputs, **lm_kwargs) - return self._call_postprocess(signature, outputs) + return self._call_postprocess( + processed_signature=processed_signature, + original_signature=signature, + outputs=outputs, + ) def format( self, diff --git a/dspy/adapters/json_adapter.py b/dspy/adapters/json_adapter.py index 9b14691c7c..d95deb3096 100644 --- a/dspy/adapters/json_adapter.py +++ b/dspy/adapters/json_adapter.py @@ -9,6 +9,7 @@ from pydantic.fields import FieldInfo from dspy.adapters.chat_adapter import ChatAdapter, FieldInfoWithName +from dspy.adapters.types.tool import ToolCalls from dspy.adapters.utils import ( format_field_value, get_annotation_name, @@ -37,6 +38,10 @@ def _has_open_ended_mapping(signature: SignatureMeta) -> bool: class JSONAdapter(ChatAdapter): + def __init__(self, enable_structured_output: bool = True, use_native_function_calling: bool = False, **kwargs): + super().__init__(**kwargs) + self.use_native_function_calling = use_native_function_calling + def _json_adapter_call_common(self, lm, lm_kwargs, signature, demos, inputs, call_fn): """Common call logic to be used for both sync and async calls.""" provider = lm.model.split("/", 1)[0] or "openai" @@ -45,7 +50,8 @@ def _json_adapter_call_common(self, lm, lm_kwargs, signature, demos, inputs, cal if not params or "response_format" not in params: return call_fn(lm, lm_kwargs, signature, demos, inputs) - if _has_open_ended_mapping(signature): + has_tool_calls = any(field.annotation == ToolCalls for field in signature.output_fields.values()) + if _has_open_ended_mapping(signature) or (self.use_native_function_calling and has_tool_calls): lm_kwargs["response_format"] = {"type": "json_object"} return call_fn(lm, lm_kwargs, signature, demos, inputs) @@ -62,7 +68,9 @@ def __call__( return result try: - structured_output_model = _get_structured_outputs_response_format(signature) + structured_output_model = _get_structured_outputs_response_format( + signature, self.use_native_function_calling + ) lm_kwargs["response_format"] = structured_output_model return super().__call__(lm, lm_kwargs, signature, demos, inputs) except Exception: @@ -99,7 +107,13 @@ def _call_preprocess( inputs: dict[str, Any], use_native_function_calling: bool = True, ) -> dict[str, Any]: - return super()._call_preprocess(lm, lm_kwargs, signature, inputs, use_native_function_calling) + return super()._call_preprocess( + lm, + lm_kwargs, + signature, + inputs, + use_native_function_calling=self.use_native_function_calling, + ) def format_field_structure(self, signature: Type[Signature]) -> str: parts = [] @@ -128,6 +142,9 @@ def type_info(v): else "" ) + if self.use_native_function_calling: + return None + message = "Respond with a JSON object in the following order of fields: " message += ", then ".join(f"`{f}`{type_info(v)}" for f, v in signature.output_fields.items()) message += "." @@ -206,7 +223,10 @@ def format_finetune_data( raise NotImplementedError -def _get_structured_outputs_response_format(signature: SignatureMeta) -> type[pydantic.BaseModel]: +def _get_structured_outputs_response_format( + signature: SignatureMeta, + enable_native_function_calling: bool = True, +) -> type[pydantic.BaseModel]: """ Builds a Pydantic model from a DSPy signature's output_fields and ensures the generated JSON schema is compatible with OpenAI Structured Outputs (all objects have a "required" key listing every property, @@ -227,6 +247,9 @@ def _get_structured_outputs_response_format(signature: SignatureMeta) -> type[py fields = {} for name, field in signature.output_fields.items(): annotation = field.annotation + if enable_native_function_calling and annotation == ToolCalls: + # Skip ToolCalls field if native function calling is enabled. + continue default = field.default if hasattr(field, "default") else ... fields[name] = (annotation, default) From a3a664ec770fb388b8c054aa1e428c99fa1ce0e6 Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Tue, 1 Jul 2025 14:40:08 -0700 Subject: [PATCH 09/19] cleanup --- dspy/adapters/base.py | 30 +++++++++--------------------- dspy/adapters/chat_adapter.py | 5 ++--- dspy/adapters/json_adapter.py | 31 ++++--------------------------- 3 files changed, 15 insertions(+), 51 deletions(-) diff --git a/dspy/adapters/base.py b/dspy/adapters/base.py index e60c81adbe..1cdc1edfd9 100644 --- a/dspy/adapters/base.py +++ b/dspy/adapters/base.py @@ -1,5 +1,5 @@ import logging -from typing import TYPE_CHECKING, Any, Optional, Type, get_origin +from typing import TYPE_CHECKING, Any, Type, get_origin import json_repair import litellm @@ -17,7 +17,7 @@ class Adapter: - def __init__(self, callbacks: Optional[list[BaseCallback]] = None): + def __init__(self, callbacks: list[BaseCallback] | None = None): self.callbacks = callbacks or [] def __init_subclass__(cls, **kwargs) -> None: @@ -57,9 +57,6 @@ def _call_preprocess( lm_kwargs["tools"] = litellm_tools signature_for_native_function_calling = signature.delete(tool_call_output_field_name) - signature_for_native_function_calling = signature_for_native_function_calling.delete( - tool_call_input_field_name - ) return signature_for_native_function_calling @@ -67,13 +64,12 @@ def _call_preprocess( def _call_postprocess( self, - processed_signature: Type[Signature], - original_signature: Type[Signature], + signature: Type[Signature], outputs: list[dict[str, Any]], ) -> list[dict[str, Any]]: values = [] - tool_call_output_field_name = self._get_tool_call_output_field_name(original_signature) + tool_call_output_field_name = self._get_tool_call_output_field_name(signature) for output in outputs: output_logprobs = None @@ -86,14 +82,10 @@ def _call_postprocess( tool_calls = output.get("tool_calls") if text: - value = self.parse(processed_signature, text) - for field_name in original_signature.output_fields.keys(): - if field_name not in value: - # We need to set the field not present in the processed signature to None. - value[field_name] = None + value = self.parse(signature, text) else: value = {} - for field_name in original_signature.output_fields.keys(): + for field_name in signature.output_fields.keys(): value[field_name] = None if tool_calls and tool_call_output_field_name: @@ -125,7 +117,7 @@ def __call__( inputs = self.format(processed_signature, demos, inputs) outputs = lm(messages=inputs, **lm_kwargs) - return self._call_postprocess(processed_signature, signature, outputs) + return self._call_postprocess(signature, outputs) async def acall( self, @@ -139,11 +131,7 @@ async def acall( inputs = self.format(processed_signature, demos, inputs) outputs = await lm.acall(messages=inputs, **lm_kwargs) - return self._call_postprocess( - processed_signature=processed_signature, - original_signature=signature, - outputs=outputs, - ) + return self._call_postprocess(signature, outputs) def format( self, @@ -293,7 +281,7 @@ def format_assistant_message_content( self, signature: Type[Signature], outputs: dict[str, Any], - missing_field_message: Optional[str] = None, + missing_field_message: str | None = None, ) -> str: """Format the assistant message content. diff --git a/dspy/adapters/chat_adapter.py b/dspy/adapters/chat_adapter.py index c6375aa311..ac407e4755 100644 --- a/dspy/adapters/chat_adapter.py +++ b/dspy/adapters/chat_adapter.py @@ -1,6 +1,6 @@ import re import textwrap -from typing import Any, Dict, NamedTuple, Optional, Type +from typing import Any, Dict, NamedTuple, Type from litellm import ContextWindowExceededError from pydantic.fields import FieldInfo @@ -27,7 +27,7 @@ class FieldInfoWithName(NamedTuple): class ChatAdapter(Adapter): - def __init__(self, callbacks: Optional[list[BaseCallback]] = None): + def __init__(self, callbacks: list[BaseCallback] | None = None): super().__init__(callbacks) def __call__( @@ -41,7 +41,6 @@ def __call__( try: return super().__call__(lm, lm_kwargs, signature, demos, inputs) except Exception as e: - raise e # fallback to JSONAdapter from dspy.adapters.json_adapter import JSONAdapter diff --git a/dspy/adapters/json_adapter.py b/dspy/adapters/json_adapter.py index d95deb3096..9b14691c7c 100644 --- a/dspy/adapters/json_adapter.py +++ b/dspy/adapters/json_adapter.py @@ -9,7 +9,6 @@ from pydantic.fields import FieldInfo from dspy.adapters.chat_adapter import ChatAdapter, FieldInfoWithName -from dspy.adapters.types.tool import ToolCalls from dspy.adapters.utils import ( format_field_value, get_annotation_name, @@ -38,10 +37,6 @@ def _has_open_ended_mapping(signature: SignatureMeta) -> bool: class JSONAdapter(ChatAdapter): - def __init__(self, enable_structured_output: bool = True, use_native_function_calling: bool = False, **kwargs): - super().__init__(**kwargs) - self.use_native_function_calling = use_native_function_calling - def _json_adapter_call_common(self, lm, lm_kwargs, signature, demos, inputs, call_fn): """Common call logic to be used for both sync and async calls.""" provider = lm.model.split("/", 1)[0] or "openai" @@ -50,8 +45,7 @@ def _json_adapter_call_common(self, lm, lm_kwargs, signature, demos, inputs, cal if not params or "response_format" not in params: return call_fn(lm, lm_kwargs, signature, demos, inputs) - has_tool_calls = any(field.annotation == ToolCalls for field in signature.output_fields.values()) - if _has_open_ended_mapping(signature) or (self.use_native_function_calling and has_tool_calls): + if _has_open_ended_mapping(signature): lm_kwargs["response_format"] = {"type": "json_object"} return call_fn(lm, lm_kwargs, signature, demos, inputs) @@ -68,9 +62,7 @@ def __call__( return result try: - structured_output_model = _get_structured_outputs_response_format( - signature, self.use_native_function_calling - ) + structured_output_model = _get_structured_outputs_response_format(signature) lm_kwargs["response_format"] = structured_output_model return super().__call__(lm, lm_kwargs, signature, demos, inputs) except Exception: @@ -107,13 +99,7 @@ def _call_preprocess( inputs: dict[str, Any], use_native_function_calling: bool = True, ) -> dict[str, Any]: - return super()._call_preprocess( - lm, - lm_kwargs, - signature, - inputs, - use_native_function_calling=self.use_native_function_calling, - ) + return super()._call_preprocess(lm, lm_kwargs, signature, inputs, use_native_function_calling) def format_field_structure(self, signature: Type[Signature]) -> str: parts = [] @@ -142,9 +128,6 @@ def type_info(v): else "" ) - if self.use_native_function_calling: - return None - message = "Respond with a JSON object in the following order of fields: " message += ", then ".join(f"`{f}`{type_info(v)}" for f, v in signature.output_fields.items()) message += "." @@ -223,10 +206,7 @@ def format_finetune_data( raise NotImplementedError -def _get_structured_outputs_response_format( - signature: SignatureMeta, - enable_native_function_calling: bool = True, -) -> type[pydantic.BaseModel]: +def _get_structured_outputs_response_format(signature: SignatureMeta) -> type[pydantic.BaseModel]: """ Builds a Pydantic model from a DSPy signature's output_fields and ensures the generated JSON schema is compatible with OpenAI Structured Outputs (all objects have a "required" key listing every property, @@ -247,9 +227,6 @@ def _get_structured_outputs_response_format( fields = {} for name, field in signature.output_fields.items(): annotation = field.annotation - if enable_native_function_calling and annotation == ToolCalls: - # Skip ToolCalls field if native function calling is enabled. - continue default = field.default if hasattr(field, "default") else ... fields[name] = (annotation, default) From 1b06c69289f3d49a94c6e2b29daa683713141e71 Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Tue, 1 Jul 2025 14:43:44 -0700 Subject: [PATCH 10/19] add docstring --- dspy/adapters/types/tool.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dspy/adapters/types/tool.py b/dspy/adapters/types/tool.py index d0046e3745..14004ad749 100644 --- a/dspy/adapters/types/tool.py +++ b/dspy/adapters/types/tool.py @@ -312,6 +312,7 @@ def format(self) -> list[dict[str, Any]]: @pydantic.model_serializer() def serialize_model(self): + """Override so that the ToolCalls are not formatted as a message array when using as inputs.""" return self.format() From 973f61370acc69651331a970ae53e630d656340b Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Tue, 1 Jul 2025 14:45:53 -0700 Subject: [PATCH 11/19] direct jsonadapter to chatadapter --- dspy/adapters/json_adapter.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/dspy/adapters/json_adapter.py b/dspy/adapters/json_adapter.py index 9b14691c7c..338e392ba9 100644 --- a/dspy/adapters/json_adapter.py +++ b/dspy/adapters/json_adapter.py @@ -9,6 +9,7 @@ from pydantic.fields import FieldInfo from dspy.adapters.chat_adapter import ChatAdapter, FieldInfoWithName +from dspy.adapters.types.tool import ToolCalls from dspy.adapters.utils import ( format_field_value, get_annotation_name, @@ -45,7 +46,8 @@ def _json_adapter_call_common(self, lm, lm_kwargs, signature, demos, inputs, cal if not params or "response_format" not in params: return call_fn(lm, lm_kwargs, signature, demos, inputs) - if _has_open_ended_mapping(signature): + has_tool_calls = any(field.annotation == ToolCalls for field in signature.output_fields.values()) + if _has_open_ended_mapping(signature) or has_tool_calls: lm_kwargs["response_format"] = {"type": "json_object"} return call_fn(lm, lm_kwargs, signature, demos, inputs) From a8e0e47c7239ea4ca0dda7476421138b5c05af9a Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Tue, 1 Jul 2025 18:24:55 -0700 Subject: [PATCH 12/19] fix test --- dspy/adapters/base.py | 6 +++--- dspy/adapters/chat_adapter.py | 4 ---- dspy/adapters/json_adapter.py | 29 +++++++------------------- tests/adapters/test_json_adapter.py | 32 +++++++++++++++++++++++++++++ 4 files changed, 42 insertions(+), 29 deletions(-) diff --git a/dspy/adapters/base.py b/dspy/adapters/base.py index a1a2582237..fa6b3ceb14 100644 --- a/dspy/adapters/base.py +++ b/dspy/adapters/base.py @@ -17,8 +17,9 @@ class Adapter: - def __init__(self, callbacks: list[BaseCallback] | None = None): + def __init__(self, callbacks: list[BaseCallback] | None = None, use_native_function_calling: bool = False): self.callbacks = callbacks or [] + self.use_native_function_calling = use_native_function_calling def __init_subclass__(cls, **kwargs) -> None: super().__init_subclass__(**kwargs) @@ -33,9 +34,8 @@ def _call_preprocess( lm_kwargs: dict[str, Any], signature: Type[Signature], inputs: dict[str, Any], - use_native_function_calling: bool = False, ) -> dict[str, Any]: - if use_native_function_calling: + if self.use_native_function_calling: tool_call_input_field_name = self._get_tool_call_input_field_name(signature) tool_call_output_field_name = self._get_tool_call_output_field_name(signature) diff --git a/dspy/adapters/chat_adapter.py b/dspy/adapters/chat_adapter.py index ac407e4755..4ace4214de 100644 --- a/dspy/adapters/chat_adapter.py +++ b/dspy/adapters/chat_adapter.py @@ -15,7 +15,6 @@ ) from dspy.clients.lm import LM from dspy.signatures.signature import Signature -from dspy.utils.callback import BaseCallback from dspy.utils.exceptions import AdapterParseError field_header_pattern = re.compile(r"\[\[ ## (\w+) ## \]\]") @@ -27,9 +26,6 @@ class FieldInfoWithName(NamedTuple): class ChatAdapter(Adapter): - def __init__(self, callbacks: list[BaseCallback] | None = None): - super().__init__(callbacks) - def __call__( self, lm: LM, diff --git a/dspy/adapters/json_adapter.py b/dspy/adapters/json_adapter.py index d95deb3096..21642a2c38 100644 --- a/dspy/adapters/json_adapter.py +++ b/dspy/adapters/json_adapter.py @@ -19,6 +19,7 @@ ) from dspy.clients.lm import LM from dspy.signatures.signature import Signature, SignatureMeta +from dspy.utils.callback import BaseCallback from dspy.utils.exceptions import AdapterParseError logger = logging.getLogger(__name__) @@ -38,9 +39,9 @@ def _has_open_ended_mapping(signature: SignatureMeta) -> bool: class JSONAdapter(ChatAdapter): - def __init__(self, enable_structured_output: bool = True, use_native_function_calling: bool = False, **kwargs): - super().__init__(**kwargs) - self.use_native_function_calling = use_native_function_calling + def __init__(self, callbacks: list[BaseCallback] | None = None, use_native_function_calling: bool = True): + """JSONAdapter uses native function calling by default.""" + super().__init__(callbacks=callbacks, use_native_function_calling=use_native_function_calling) def _json_adapter_call_common(self, lm, lm_kwargs, signature, demos, inputs, call_fn): """Common call logic to be used for both sync and async calls.""" @@ -51,7 +52,7 @@ def _json_adapter_call_common(self, lm, lm_kwargs, signature, demos, inputs, cal return call_fn(lm, lm_kwargs, signature, demos, inputs) has_tool_calls = any(field.annotation == ToolCalls for field in signature.output_fields.values()) - if _has_open_ended_mapping(signature) or (self.use_native_function_calling and has_tool_calls): + if _has_open_ended_mapping(signature) or (not self.use_native_function_calling and has_tool_calls): lm_kwargs["response_format"] = {"type": "json_object"} return call_fn(lm, lm_kwargs, signature, demos, inputs) @@ -99,22 +100,6 @@ async def acall( lm_kwargs["response_format"] = {"type": "json_object"} return await super().acall(lm, lm_kwargs, signature, demos, inputs) - def _call_preprocess( - self, - lm: "LM", - lm_kwargs: dict[str, Any], - signature: Type[Signature], - inputs: dict[str, Any], - use_native_function_calling: bool = True, - ) -> dict[str, Any]: - return super()._call_preprocess( - lm, - lm_kwargs, - signature, - inputs, - use_native_function_calling=self.use_native_function_calling, - ) - def format_field_structure(self, signature: Type[Signature]) -> str: parts = [] parts.append("All interactions will be structured in the following way, with the appropriate values filled in.") @@ -225,7 +210,7 @@ def format_finetune_data( def _get_structured_outputs_response_format( signature: SignatureMeta, - enable_native_function_calling: bool = True, + use_native_function_calling: bool = True, ) -> type[pydantic.BaseModel]: """ Builds a Pydantic model from a DSPy signature's output_fields and ensures the generated JSON schema @@ -247,7 +232,7 @@ def _get_structured_outputs_response_format( fields = {} for name, field in signature.output_fields.items(): annotation = field.annotation - if enable_native_function_calling and annotation == ToolCalls: + if use_native_function_calling and annotation == ToolCalls: # Skip ToolCalls field if native function calling is enabled. continue default = field.default if hasattr(field, "default") else ... diff --git a/tests/adapters/test_json_adapter.py b/tests/adapters/test_json_adapter.py index eb90a754d1..f35ae22553 100644 --- a/tests/adapters/test_json_adapter.py +++ b/tests/adapters/test_json_adapter.py @@ -650,3 +650,35 @@ class TestSignature(dspy.Signature): await program.acall(question="Dummy question!") assert "ValueError!" in str(error.value) + + +def test_json_adapter_toolcalls_no_native_function_calling(): + class MySignature(dspy.Signature): + question: str = dspy.InputField() + tools: list[dspy.Tool] = dspy.InputField() + answer: str = dspy.OutputField() + tool_calls: dspy.ToolCalls = dspy.OutputField() + + def get_weather(city: str) -> str: + return f"The weather in {city} is sunny" + + tools = [dspy.Tool(get_weather)] + + # Patch _get_structured_outputs_response_format to track calls + with mock.patch("dspy.adapters.json_adapter._get_structured_outputs_response_format") as mock_structured: + # Patch litellm.completion to return a dummy response + with mock.patch("litellm.completion") as mock_completion: + mock_completion.return_value = ModelResponse( + choices=[Choices(message=Message(content="{'answer': 'sunny', 'tool_calls': {'tool_calls': []}}"))], + model="openai/gpt-4o-mini", + ) + adapter = dspy.JSONAdapter(use_native_function_calling=False) + lm = dspy.LM(model="openai/gpt-4o-mini", cache=False) + adapter(lm, {}, MySignature, [], {"question": "What is the weather in Tokyo?", "tools": tools}) + + # _get_structured_outputs_response_format is not called because without using native function calling, + # JSONAdapter falls back to json mode for stable quality. + mock_structured.assert_not_called() + mock_completion.assert_called_once() + _, call_kwargs = mock_completion.call_args + assert call_kwargs["response_format"] == {"type": "json_object"} From 4fb6239685994d3ffff014ea3c579b16c2f68ae9 Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Tue, 1 Jul 2025 18:42:29 -0700 Subject: [PATCH 13/19] better test --- dspy/clients/lm.py | 7 ++-- tests/adapters/test_chat_adapter.py | 51 ++++++++++++++++++++++++++++- tests/adapters/test_json_adapter.py | 51 ++++++++++++++++++++++++++++- 3 files changed, 104 insertions(+), 5 deletions(-) diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py index e79a731681..f73ebf04dc 100644 --- a/dspy/clients/lm.py +++ b/dspy/clients/lm.py @@ -83,9 +83,9 @@ def __init__( if model_pattern: # Handle OpenAI reasoning models (o1, o3) - assert ( - max_tokens >= 20_000 and temperature == 1.0 - ), "OpenAI's reasoning models require passing temperature=1.0 and max_tokens >= 20_000 to `dspy.LM(...)`" + assert max_tokens >= 20_000 and temperature == 1.0, ( + "OpenAI's reasoning models require passing temperature=1.0 and max_tokens >= 20_000 to `dspy.LM(...)`" + ) self.kwargs = dict(temperature=temperature, max_completion_tokens=max_tokens, **kwargs) else: self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs) @@ -141,6 +141,7 @@ def forward(self, prompt=None, messages=None, **kwargs): if not getattr(results, "cache_hit", False) and dspy.settings.usage_tracker and hasattr(results, "usage"): settings.usage_tracker.add_usage(self.model, dict(results.usage)) + return results async def aforward(self, prompt=None, messages=None, **kwargs): diff --git a/tests/adapters/test_chat_adapter.py b/tests/adapters/test_chat_adapter.py index e599da5858..1321b2abd5 100644 --- a/tests/adapters/test_chat_adapter.py +++ b/tests/adapters/test_chat_adapter.py @@ -3,7 +3,7 @@ import pydantic import pytest -from litellm.utils import Choices, Message, ModelResponse +from litellm.utils import ChatCompletionMessageToolCall, Choices, Function, Message, ModelResponse import dspy @@ -422,3 +422,52 @@ async def test_chat_adapter_fallback_to_json_adapter_on_exception_async(): # The parse should succeed result = await adapter.acall(lm, {}, signature, [], {"question": "What is the capital of France?"}) assert result == [{"answer": "Paris"}] + + +def test_chat_adapter_toolcalls_native_function_calling(): + class MySignature(dspy.Signature): + question: str = dspy.InputField() + tools: list[dspy.Tool] = dspy.InputField() + answer: str = dspy.OutputField() + tool_calls: dspy.ToolCalls = dspy.OutputField() + + def get_weather(city: str) -> str: + return f"The weather in {city} is sunny" + + tools = [dspy.Tool(get_weather)] + + adapter = dspy.JSONAdapter(use_native_function_calling=True) + with mock.patch("litellm.completion") as mock_completion: + mock_completion.return_value = ModelResponse( + choices=[ + Choices( + finish_reason="tool_calls", + index=0, + message=Message( + content=None, + role="assistant", + tool_calls=[ + ChatCompletionMessageToolCall( + function=Function(arguments='{"city":"Paris"}', name="get_weather"), + id="call_pQm8ajtSMxgA0nrzK2ivFmxG", + type="function", + ) + ], + ), + ), + ], + model="openai/gpt-4o-mini", + ) + result = adapter( + dspy.LM(model="openai/gpt-4o-mini", cache=False), + {}, + MySignature, + [], + {"question": "What is the weather in Paris?", "tools": tools}, + ) + + assert result[0]["tool_calls"] == dspy.ToolCalls( + tool_calls=[dspy.ToolCalls.ToolCall(name="get_weather", args={"city": "Paris"})] + ) + # `answer` is not present, so we set it to None + assert result[0]["answer"] is None diff --git a/tests/adapters/test_json_adapter.py b/tests/adapters/test_json_adapter.py index f35ae22553..cb654e4342 100644 --- a/tests/adapters/test_json_adapter.py +++ b/tests/adapters/test_json_adapter.py @@ -2,7 +2,7 @@ import pydantic import pytest -from litellm.utils import Choices, Message, ModelResponse +from litellm.utils import ChatCompletionMessageToolCall, Choices, Function, Message, ModelResponse import dspy @@ -652,6 +652,55 @@ class TestSignature(dspy.Signature): assert "ValueError!" in str(error.value) +def test_json_adapter_toolcalls_native_function_calling(): + class MySignature(dspy.Signature): + question: str = dspy.InputField() + tools: list[dspy.Tool] = dspy.InputField() + answer: str = dspy.OutputField() + tool_calls: dspy.ToolCalls = dspy.OutputField() + + def get_weather(city: str) -> str: + return f"The weather in {city} is sunny" + + tools = [dspy.Tool(get_weather)] + + adapter = dspy.JSONAdapter(use_native_function_calling=True) + with mock.patch("litellm.completion") as mock_completion: + mock_completion.return_value = ModelResponse( + choices=[ + Choices( + finish_reason="tool_calls", + index=0, + message=Message( + content=None, + role="assistant", + tool_calls=[ + ChatCompletionMessageToolCall( + function=Function(arguments='{"city":"Paris"}', name="get_weather"), + id="call_pQm8ajtSMxgA0nrzK2ivFmxG", + type="function", + ) + ], + ), + ), + ], + model="openai/gpt-4o-mini", + ) + result = adapter( + dspy.LM(model="openai/gpt-4o-mini", cache=False), + {}, + MySignature, + [], + {"question": "What is the weather in Paris?", "tools": tools}, + ) + + assert result[0]["tool_calls"] == dspy.ToolCalls( + tool_calls=[dspy.ToolCalls.ToolCall(name="get_weather", args={"city": "Paris"})] + ) + # `answer` is not present, so we set it to None + assert result[0]["answer"] is None + + def test_json_adapter_toolcalls_no_native_function_calling(): class MySignature(dspy.Signature): question: str = dspy.InputField() From 8e5b6ae60ff2e338fd6ff8e14f96e9da926ea7a7 Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Tue, 1 Jul 2025 18:55:16 -0700 Subject: [PATCH 14/19] fix --- dspy/adapters/base.py | 8 ++------ dspy/adapters/json_adapter.py | 3 --- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/dspy/adapters/base.py b/dspy/adapters/base.py index fa6b3ceb14..3e0372c0eb 100644 --- a/dspy/adapters/base.py +++ b/dspy/adapters/base.py @@ -89,7 +89,7 @@ def _call_postprocess( value = self.parse(processed_signature, text) for field_name in original_signature.output_fields.keys(): if field_name not in value: - # We need to set the field not present in the processed signature to None. + # We need to set the field not present in the processed signature to None for consistency. value[field_name] = None else: value = {} @@ -139,11 +139,7 @@ async def acall( inputs = self.format(processed_signature, demos, inputs) outputs = await lm.acall(messages=inputs, **lm_kwargs) - return self._call_postprocess( - processed_signature=processed_signature, - original_signature=signature, - outputs=outputs, - ) + return self._call_postprocess(processed_signature, signature, outputs) def format( self, diff --git a/dspy/adapters/json_adapter.py b/dspy/adapters/json_adapter.py index 21642a2c38..4fb44510cc 100644 --- a/dspy/adapters/json_adapter.py +++ b/dspy/adapters/json_adapter.py @@ -127,9 +127,6 @@ def type_info(v): else "" ) - if self.use_native_function_calling: - return None - message = "Respond with a JSON object in the following order of fields: " message += ", then ".join(f"`{f}`{type_info(v)}" for f, v in signature.output_fields.items()) message += "." From ee7da5408a10b515e5e73965e611fa4f43a8b6ca Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Wed, 2 Jul 2025 12:03:30 -0700 Subject: [PATCH 15/19] fix tests --- dspy/adapters/two_step_adapter.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dspy/adapters/two_step_adapter.py b/dspy/adapters/two_step_adapter.py index 405cc99e29..27b8d5e72b 100644 --- a/dspy/adapters/two_step_adapter.py +++ b/dspy/adapters/two_step_adapter.py @@ -39,7 +39,8 @@ class TwoStepAdapter(Adapter): ``` """ - def __init__(self, extraction_model: LM): + def __init__(self, extraction_model: LM, **kwargs): + super().__init__(**kwargs) if not isinstance(extraction_model, LM): raise ValueError("extraction_model must be an instance of LM") self.extraction_model = extraction_model From a7a5a4e3776d83da91ec93212257a78e8d66eec9 Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Wed, 2 Jul 2025 12:35:51 -0700 Subject: [PATCH 16/19] lint fix --- dspy/adapters/json_adapter.py | 4 +++- dspy/clients/lm.py | 7 +++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/dspy/adapters/json_adapter.py b/dspy/adapters/json_adapter.py index 4fb44510cc..51470e2ecd 100644 --- a/dspy/adapters/json_adapter.py +++ b/dspy/adapters/json_adapter.py @@ -40,7 +40,7 @@ def _has_open_ended_mapping(signature: SignatureMeta) -> bool: class JSONAdapter(ChatAdapter): def __init__(self, callbacks: list[BaseCallback] | None = None, use_native_function_calling: bool = True): - """JSONAdapter uses native function calling by default.""" + # JSONAdapter uses native function calling by default. super().__init__(callbacks=callbacks, use_native_function_calling=use_native_function_calling) def _json_adapter_call_common(self, lm, lm_kwargs, signature, demos, inputs, call_fn): @@ -53,6 +53,8 @@ def _json_adapter_call_common(self, lm, lm_kwargs, signature, demos, inputs, cal has_tool_calls = any(field.annotation == ToolCalls for field in signature.output_fields.values()) if _has_open_ended_mapping(signature) or (not self.use_native_function_calling and has_tool_calls): + # We found that structured output mode doesn't work well with dspy.ToolCalls as output field. + # So we fall back to json mode if native function calling is disabled and ToolCalls is present. lm_kwargs["response_format"] = {"type": "json_object"} return call_fn(lm, lm_kwargs, signature, demos, inputs) diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py index f73ebf04dc..e79a731681 100644 --- a/dspy/clients/lm.py +++ b/dspy/clients/lm.py @@ -83,9 +83,9 @@ def __init__( if model_pattern: # Handle OpenAI reasoning models (o1, o3) - assert max_tokens >= 20_000 and temperature == 1.0, ( - "OpenAI's reasoning models require passing temperature=1.0 and max_tokens >= 20_000 to `dspy.LM(...)`" - ) + assert ( + max_tokens >= 20_000 and temperature == 1.0 + ), "OpenAI's reasoning models require passing temperature=1.0 and max_tokens >= 20_000 to `dspy.LM(...)`" self.kwargs = dict(temperature=temperature, max_completion_tokens=max_tokens, **kwargs) else: self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs) @@ -141,7 +141,6 @@ def forward(self, prompt=None, messages=None, **kwargs): if not getattr(results, "cache_hit", False) and dspy.settings.usage_tracker and hasattr(results, "usage"): settings.usage_tracker.add_usage(self.model, dict(results.usage)) - return results async def aforward(self, prompt=None, messages=None, **kwargs): From dbbaea21329c636da2eb07056ed5a8be035c5c15 Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Wed, 2 Jul 2025 12:39:13 -0700 Subject: [PATCH 17/19] better test --- tests/adapters/test_chat_adapter.py | 18 ++++++++++++++++++ tests/adapters/test_json_adapter.py | 18 ++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/tests/adapters/test_chat_adapter.py b/tests/adapters/test_chat_adapter.py index 1321b2abd5..457ba0b0ea 100644 --- a/tests/adapters/test_chat_adapter.py +++ b/tests/adapters/test_chat_adapter.py @@ -437,6 +437,8 @@ def get_weather(city: str) -> str: tools = [dspy.Tool(get_weather)] adapter = dspy.JSONAdapter(use_native_function_calling=True) + + # Case 1: Tool calls are present in the response, while content is None. with mock.patch("litellm.completion") as mock_completion: mock_completion.return_value = ModelResponse( choices=[ @@ -471,3 +473,19 @@ def get_weather(city: str) -> str: ) # `answer` is not present, so we set it to None assert result[0]["answer"] is None + + # Case 2: Tool calls are not present in the response, while content is present. + with mock.patch("litellm.completion") as mock_completion: + mock_completion.return_value = ModelResponse( + choices=[Choices(message=Message(content="{'answer': 'Paris'}"))], + model="openai/gpt-4o-mini", + ) + result = adapter( + dspy.LM(model="openai/gpt-4o-mini", cache=False), + {}, + MySignature, + [], + {"question": "What is the weather in Paris?", "tools": tools}, + ) + assert result[0]["answer"] == "Paris" + assert result[0]["tool_calls"] is None diff --git a/tests/adapters/test_json_adapter.py b/tests/adapters/test_json_adapter.py index cb654e4342..d58e5f9efc 100644 --- a/tests/adapters/test_json_adapter.py +++ b/tests/adapters/test_json_adapter.py @@ -665,6 +665,8 @@ def get_weather(city: str) -> str: tools = [dspy.Tool(get_weather)] adapter = dspy.JSONAdapter(use_native_function_calling=True) + + # Case 1: Tool calls are present in the response, while content is None. with mock.patch("litellm.completion") as mock_completion: mock_completion.return_value = ModelResponse( choices=[ @@ -700,6 +702,22 @@ def get_weather(city: str) -> str: # `answer` is not present, so we set it to None assert result[0]["answer"] is None + # Case 2: Tool calls are not present in the response, while content is present. + with mock.patch("litellm.completion") as mock_completion: + mock_completion.return_value = ModelResponse( + choices=[Choices(message=Message(content="{'answer': 'Paris'}"))], + model="openai/gpt-4o-mini", + ) + result = adapter( + dspy.LM(model="openai/gpt-4o-mini", cache=False), + {}, + MySignature, + [], + {"question": "What is the weather in Paris?", "tools": tools}, + ) + assert result[0]["answer"] == "Paris" + assert result[0]["tool_calls"] is None + def test_json_adapter_toolcalls_no_native_function_calling(): class MySignature(dspy.Signature): From 01b3bc1519be8eaf43c644397948a90224cf5611 Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Wed, 2 Jul 2025 13:44:21 -0700 Subject: [PATCH 18/19] increment --- dspy/predict/react.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dspy/predict/react.py b/dspy/predict/react.py index 12b1c71a32..c13a8eb9b9 100644 --- a/dspy/predict/react.py +++ b/dspy/predict/react.py @@ -96,6 +96,7 @@ def _format_trajectory(self, trajectory: dict[str, Any]): def forward(self, **input_args): trajectory = {} max_iters = input_args.pop("max_iters", self.max_iters) + dspy.settings.adapter.use_native_function_calling = False for idx in range(max_iters): try: pred = self._call_with_potential_trajectory_truncation(self.react, trajectory, **input_args) @@ -124,6 +125,7 @@ def forward(self, **input_args): async def aforward(self, **input_args): trajectory = {} max_iters = input_args.pop("max_iters", self.max_iters) + dspy.settings.adapter.use_native_function_calling = False for idx in range(max_iters): try: pred = await self._async_call_with_potential_trajectory_truncation(self.react, trajectory, **input_args) From c8752b48321477c022f109c97a9101876e02c6a0 Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Wed, 2 Jul 2025 17:21:46 -0700 Subject: [PATCH 19/19] some updates --- dspy/predict/react.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/dspy/predict/react.py b/dspy/predict/react.py index c13a8eb9b9..57a87ca669 100644 --- a/dspy/predict/react.py +++ b/dspy/predict/react.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, Any, Callable, Type from litellm import ContextWindowExceededError +from pydantic import ValidationError import dspy from dspy.adapters.types.tool import Tool, ToolCalls @@ -100,6 +101,13 @@ def forward(self, **input_args): for idx in range(max_iters): try: pred = self._call_with_potential_trajectory_truncation(self.react, trajectory, **input_args) + except ValidationError as err: + trajectory[f"thought_{idx}"] = ( + "Encounter value when parsing the LM response, please try fixing based on the error message." + ) + trajectory[f"tool_calls_{idx}"] = None + trajectory[f"observation_{idx}"] = [f"Error: {err}"] + continue except ValueError as err: logger.warning(f"Ending the trajectory: Agent failed to select a valid tool: {_fmt_exc(err)}") break @@ -129,6 +137,13 @@ async def aforward(self, **input_args): for idx in range(max_iters): try: pred = await self._async_call_with_potential_trajectory_truncation(self.react, trajectory, **input_args) + except ValidationError as err: + trajectory[f"thought_{idx}"] = ( + "Encounter value when parsing the LM response, please try fixing based on the error message." + ) + trajectory[f"tool_calls_{idx}"] = None + trajectory[f"observation_{idx}"] = [f"Error: {err}"] + continue except ValueError as err: logger.warning(f"Ending the trajectory: Agent failed to select a valid tool: {_fmt_exc(err)}") break