Skip to content

[WIP] React use dspy.ToolCalls #8472

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 21 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 17 additions & 9 deletions dspy/adapters/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@


class Adapter:
def __init__(self, callbacks: list[BaseCallback] | None = None):
def __init__(self, callbacks: list[BaseCallback] | None = None, use_native_function_calling: bool = False):
self.callbacks = callbacks or []
self.use_native_function_calling = use_native_function_calling

def __init_subclass__(cls, **kwargs) -> None:
super().__init_subclass__(**kwargs)
Expand All @@ -33,9 +34,8 @@ def _call_preprocess(
lm_kwargs: dict[str, Any],
signature: Type[Signature],
inputs: dict[str, Any],
use_native_function_calling: bool = False,
) -> dict[str, Any]:
if use_native_function_calling:
if self.use_native_function_calling:
tool_call_input_field_name = self._get_tool_call_input_field_name(signature)
tool_call_output_field_name = self._get_tool_call_output_field_name(signature)

Expand All @@ -57,19 +57,23 @@ def _call_preprocess(
lm_kwargs["tools"] = litellm_tools

signature_for_native_function_calling = signature.delete(tool_call_output_field_name)
signature_for_native_function_calling = signature_for_native_function_calling.delete(
tool_call_input_field_name
)

return signature_for_native_function_calling

return signature

def _call_postprocess(
self,
signature: Type[Signature],
processed_signature: Type[Signature],
original_signature: Type[Signature],
outputs: list[dict[str, Any]],
) -> list[dict[str, Any]]:
values = []

tool_call_output_field_name = self._get_tool_call_output_field_name(signature)
tool_call_output_field_name = self._get_tool_call_output_field_name(original_signature)

for output in outputs:
output_logprobs = None
Expand All @@ -82,10 +86,14 @@ def _call_postprocess(
tool_calls = output.get("tool_calls")

if text:
value = self.parse(signature, text)
value = self.parse(processed_signature, text)
for field_name in original_signature.output_fields.keys():
if field_name not in value:
# Output fields that are absent from the processed signature are set to None so the result covers every field of the original signature.
value[field_name] = None
else:
value = {}
for field_name in signature.output_fields.keys():
for field_name in original_signature.output_fields.keys():
value[field_name] = None

if tool_calls and tool_call_output_field_name:
Expand Down Expand Up @@ -117,7 +125,7 @@ def __call__(
inputs = self.format(processed_signature, demos, inputs)

outputs = lm(messages=inputs, **lm_kwargs)
return self._call_postprocess(signature, outputs)
return self._call_postprocess(processed_signature, signature, outputs)

async def acall(
self,
Expand All @@ -131,7 +139,7 @@ async def acall(
inputs = self.format(processed_signature, demos, inputs)

outputs = await lm.acall(messages=inputs, **lm_kwargs)
return self._call_postprocess(signature, outputs)
return self._call_postprocess(processed_signature, signature, outputs)

def format(
self,
Expand Down
4 changes: 0 additions & 4 deletions dspy/adapters/chat_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
)
from dspy.clients.lm import LM
from dspy.signatures.signature import Signature
from dspy.utils.callback import BaseCallback
from dspy.utils.exceptions import AdapterParseError

field_header_pattern = re.compile(r"\[\[ ## (\w+) ## \]\]")
Expand All @@ -27,9 +26,6 @@ class FieldInfoWithName(NamedTuple):


class ChatAdapter(Adapter):
def __init__(self, callbacks: list[BaseCallback] | None = None):
super().__init__(callbacks)

def __call__(
self,
lm: LM,
Expand Down
33 changes: 20 additions & 13 deletions dspy/adapters/json_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pydantic.fields import FieldInfo

from dspy.adapters.chat_adapter import ChatAdapter, FieldInfoWithName
from dspy.adapters.types.tool import ToolCalls
from dspy.adapters.utils import (
format_field_value,
get_annotation_name,
Expand All @@ -18,6 +19,7 @@
)
from dspy.clients.lm import LM
from dspy.signatures.signature import Signature, SignatureMeta
from dspy.utils.callback import BaseCallback
from dspy.utils.exceptions import AdapterParseError

logger = logging.getLogger(__name__)
Expand All @@ -37,6 +39,10 @@ def _has_open_ended_mapping(signature: SignatureMeta) -> bool:


class JSONAdapter(ChatAdapter):
def __init__(self, callbacks: list[BaseCallback] | None = None, use_native_function_calling: bool = True):
# JSONAdapter uses native function calling by default.
super().__init__(callbacks=callbacks, use_native_function_calling=use_native_function_calling)

def _json_adapter_call_common(self, lm, lm_kwargs, signature, demos, inputs, call_fn):
"""Common call logic to be used for both sync and async calls."""
provider = lm.model.split("/", 1)[0] or "openai"
Expand All @@ -45,7 +51,10 @@ def _json_adapter_call_common(self, lm, lm_kwargs, signature, demos, inputs, cal
if not params or "response_format" not in params:
return call_fn(lm, lm_kwargs, signature, demos, inputs)

if _has_open_ended_mapping(signature):
has_tool_calls = any(field.annotation == ToolCalls for field in signature.output_fields.values())
if _has_open_ended_mapping(signature) or (not self.use_native_function_calling and has_tool_calls):
# We found that structured output mode doesn't work well with dspy.ToolCalls as an output field,
# so we fall back to JSON mode if native function calling is disabled and a ToolCalls field is present.
lm_kwargs["response_format"] = {"type": "json_object"}
return call_fn(lm, lm_kwargs, signature, demos, inputs)

Expand All @@ -62,7 +71,9 @@ def __call__(
return result

try:
structured_output_model = _get_structured_outputs_response_format(signature)
structured_output_model = _get_structured_outputs_response_format(
signature, self.use_native_function_calling
)
lm_kwargs["response_format"] = structured_output_model
return super().__call__(lm, lm_kwargs, signature, demos, inputs)
except Exception:
Expand Down Expand Up @@ -91,16 +102,6 @@ async def acall(
lm_kwargs["response_format"] = {"type": "json_object"}
return await super().acall(lm, lm_kwargs, signature, demos, inputs)

def _call_preprocess(
self,
lm: "LM",
lm_kwargs: dict[str, Any],
signature: Type[Signature],
inputs: dict[str, Any],
use_native_function_calling: bool = True,
) -> dict[str, Any]:
return super()._call_preprocess(lm, lm_kwargs, signature, inputs, use_native_function_calling)

def format_field_structure(self, signature: Type[Signature]) -> str:
parts = []
parts.append("All interactions will be structured in the following way, with the appropriate values filled in.")
Expand Down Expand Up @@ -206,7 +207,10 @@ def format_finetune_data(
raise NotImplementedError


def _get_structured_outputs_response_format(signature: SignatureMeta) -> type[pydantic.BaseModel]:
def _get_structured_outputs_response_format(
signature: SignatureMeta,
use_native_function_calling: bool = True,
) -> type[pydantic.BaseModel]:
"""
Builds a Pydantic model from a DSPy signature's output_fields and ensures the generated JSON schema
is compatible with OpenAI Structured Outputs (all objects have a "required" key listing every property,
Expand All @@ -227,6 +231,9 @@ def _get_structured_outputs_response_format(signature: SignatureMeta) -> type[py
fields = {}
for name, field in signature.output_fields.items():
annotation = field.annotation
if use_native_function_calling and annotation == ToolCalls:
# Skip ToolCalls field if native function calling is enabled.
continue
default = field.default if hasattr(field, "default") else ...
fields[name] = (annotation, default)

Expand Down
3 changes: 2 additions & 1 deletion dspy/adapters/two_step_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ class TwoStepAdapter(Adapter):
```
"""

def __init__(self, extraction_model: LM):
def __init__(self, extraction_model: LM, **kwargs):
super().__init__(**kwargs)
if not isinstance(extraction_model, LM):
raise ValueError("extraction_model must be an instance of LM")
self.extraction_model = extraction_model
Expand Down
11 changes: 9 additions & 2 deletions dspy/adapters/types/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import inspect
from typing import TYPE_CHECKING, Any, Callable, Tuple, Type, get_origin, get_type_hints

import pydantic
from jsonschema import ValidationError, validate
from pydantic import BaseModel, TypeAdapter, create_model

Expand Down Expand Up @@ -288,7 +289,7 @@ def from_dict_list(cls, tool_calls_dicts: list[dict[str, Any]]) -> "ToolCalls":
def description(cls) -> str:
return (
"Tool calls information, including the name of the tools and the arguments to be passed to it. "
"Arguments must be provided in JSON format."
'`args` must be provided in JSON format, e.g., `{"arg1": "value1", "arg2": "value2"}`'
)

def format(self) -> list[dict[str, Any]]:
Expand All @@ -303,11 +304,17 @@ def format(self) -> list[dict[str, Any]]:
"name": tool_call.name,
"arguments": tool_call.args,
},
} for tool_call in self.tool_calls
}
for tool_call in self.tool_calls
],
}
]

@pydantic.model_serializer()
def serialize_model(self):
"""Override so that the ToolCalls are not formatted as a message array when using as inputs."""
return self.format()


def _resolve_json_schema_reference(schema: dict) -> dict:
"""Recursively resolve json model schema, expanding all references."""
Expand Down
75 changes: 50 additions & 25 deletions dspy/predict/react.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import logging
from typing import TYPE_CHECKING, Any, Callable, Literal, Type
from typing import TYPE_CHECKING, Any, Callable, Type

from litellm import ContextWindowExceededError
from pydantic import ValidationError

import dspy
from dspy.adapters.types.tool import Tool
from dspy.adapters.types.tool import Tool, ToolCalls
from dspy.primitives.module import Module
from dspy.signatures.signature import ensure_signature

Expand Down Expand Up @@ -53,10 +54,10 @@ def get_weather(city: str) -> str:
[
f"You are an Agent. In each episode, you will be given the fields {inputs} as input. And you can see your past trajectory so far.",
f"Your goal is to use one or more of the supplied tools to collect any necessary information for producing {outputs}.\n",
"To do this, you will interleave next_thought, next_tool_name, and next_tool_args in each turn, and also when finishing the task.",
"To do this, you will interleave next_thought and next_tool_calls in each turn, and also when finishing the task.",
"After each tool call, you receive a resulting observation, which gets appended to your trajectory.\n",
"When writing next_thought, you may reason about the current situation and plan for future steps.",
"When selecting the next_tool_name and its next_tool_args, the tool must be one of:\n",
"When selecting the tools to call, the tool must be one of:\n",
]
)

Expand All @@ -69,14 +70,14 @@ def get_weather(city: str) -> str:

for idx, tool in enumerate(tools.values()):
instr.append(f"({idx + 1}) {tool}")
instr.append("When providing `next_tool_args`, the value inside the field must be in JSON format")
instr.append("When providing `next_tool_calls`, the args value inside the field must be in JSON format")

react_signature = (
dspy.Signature({**signature.input_fields}, "\n".join(instr))
.append("trajectory", dspy.InputField(), type_=str)
.append("tools", dspy.InputField(), type_=list[Tool])
.append("next_thought", dspy.OutputField(), type_=str)
.append("next_tool_name", dspy.OutputField(), type_=Literal[tuple(tools.keys())])
.append("next_tool_args", dspy.OutputField(), type_=dict[str, Any])
.append("next_tool_calls", dspy.OutputField(), type_=ToolCalls)
)

fallback_signature = dspy.Signature(
Expand All @@ -96,23 +97,34 @@ def _format_trajectory(self, trajectory: dict[str, Any]):
def forward(self, **input_args):
trajectory = {}
max_iters = input_args.pop("max_iters", self.max_iters)
dspy.settings.adapter.use_native_function_calling = False
for idx in range(max_iters):
try:
pred = self._call_with_potential_trajectory_truncation(self.react, trajectory, **input_args)
except ValidationError as err:
trajectory[f"thought_{idx}"] = (
"Encounter value when parsing the LM response, please try fixing based on the error message."
)
trajectory[f"tool_calls_{idx}"] = None
trajectory[f"observation_{idx}"] = [f"Error: {err}"]
continue
except ValueError as err:
logger.warning(f"Ending the trajectory: Agent failed to select a valid tool: {_fmt_exc(err)}")
break

trajectory[f"thought_{idx}"] = pred.next_thought
trajectory[f"tool_name_{idx}"] = pred.next_tool_name
trajectory[f"tool_args_{idx}"] = pred.next_tool_args

try:
trajectory[f"observation_{idx}"] = self.tools[pred.next_tool_name](**pred.next_tool_args)
except Exception as err:
trajectory[f"observation_{idx}"] = f"Execution error in {pred.next_tool_name}: {_fmt_exc(err)}"

if pred.next_tool_name == "finish":
trajectory[f"tool_calls_{idx}"] = str(pred.next_tool_calls)
trajectory[f"observation_{idx}"] = []

tool_calls = [] if pred.next_tool_calls is None else pred.next_tool_calls.tool_calls
tool_names = [tool_call.name for tool_call in tool_calls]
for tool_call in tool_calls:
try:
trajectory[f"observation_{idx}"].append(f"{self.tools[tool_call.name](**tool_call.args)}")
except Exception as err:
trajectory[f"observation_{idx}"].append(f"Execution error in {tool_call.name}: {_fmt_exc(err)}")

if "finish" in tool_names or len(tool_names) == 0:
break

extract = self._call_with_potential_trajectory_truncation(self.extract, trajectory, **input_args)
Expand All @@ -121,23 +133,35 @@ def forward(self, **input_args):
async def aforward(self, **input_args):
trajectory = {}
max_iters = input_args.pop("max_iters", self.max_iters)
dspy.settings.adapter.use_native_function_calling = False
for idx in range(max_iters):
try:
pred = await self._async_call_with_potential_trajectory_truncation(self.react, trajectory, **input_args)
except ValidationError as err:
trajectory[f"thought_{idx}"] = (
"Encounter value when parsing the LM response, please try fixing based on the error message."
)
trajectory[f"tool_calls_{idx}"] = None
trajectory[f"observation_{idx}"] = [f"Error: {err}"]
continue
except ValueError as err:
logger.warning(f"Ending the trajectory: Agent failed to select a valid tool: {_fmt_exc(err)}")
break

trajectory[f"thought_{idx}"] = pred.next_thought
trajectory[f"tool_name_{idx}"] = pred.next_tool_name
trajectory[f"tool_args_{idx}"] = pred.next_tool_args

try:
trajectory[f"observation_{idx}"] = await self.tools[pred.next_tool_name].acall(**pred.next_tool_args)
except Exception as err:
trajectory[f"observation_{idx}"] = f"Execution error in {pred.next_tool_name}: {_fmt_exc(err)}"

if pred.next_tool_name == "finish":
trajectory[f"tool_calls_{idx}"] = str(pred.next_tool_calls)
trajectory[f"observation_{idx}"] = []

tool_calls = [] if pred.next_tool_calls is None else pred.next_tool_calls.tool_calls
tool_names = [tool_call.name for tool_call in tool_calls]
for tool_call in tool_calls:
try:
result = await self.tools[tool_call.name].acall(**tool_call.args)
trajectory[f"observation_{idx}"].append(result)
except Exception as err:
trajectory[f"observation_{idx}"].append(f"Execution error in {tool_call.name}: {_fmt_exc(err)}")

if "finish" in tool_names or len(tool_names) == 0:
break

extract = await self._async_call_with_potential_trajectory_truncation(self.extract, trajectory, **input_args)
Expand All @@ -149,6 +173,7 @@ def _call_with_potential_trajectory_truncation(self, module, trajectory, **input
return module(
**input_args,
trajectory=self._format_trajectory(trajectory),
tools=list(self.tools.values()),
)
except ContextWindowExceededError:
logger.warning("Trajectory exceeded the context window, truncating the oldest tool call information.")
Expand Down
Loading
Loading