
Commit 1df6768

Support async path for ChatAdapter and JSONAdapter (#8419)
* init
* Add async path to chatadapter and jsonadapter call
* fix tests
1 parent dc64c67 commit 1df6768

File tree: 4 files changed (+313, -20 lines)

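
For orientation, a minimal sketch of how the new async path can be driven, modeled on the tests added below in this commit. The model id, question, and empty lm_kwargs/demos are illustrative, and running it for real requires provider credentials; this snippet is not part of the commit itself.

import asyncio

import dspy


async def main():
    # Illustrative inputs; the call shape mirrors the new tests in this commit.
    lm = dspy.LM("openai/gpt-4o-mini", cache=False)
    signature = dspy.make_signature("question->answer")

    # New in this commit: ChatAdapter.acall awaits the LM call and, if parsing the
    # chat-formatted completion fails, falls back to JSONAdapter.acall.
    result = await dspy.ChatAdapter().acall(lm, {}, signature, [], {"question": "What is the capital of France?"})
    print(result)  # expected shape: [{"answer": "..."}]


asyncio.run(main())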

dspy/adapters/chat_adapter.py

Lines changed: 20 additions & 0 deletions
@@ -50,6 +50,26 @@ def __call__(
                 raise e
             return JSONAdapter()(lm, lm_kwargs, signature, demos, inputs)
 
+    async def acall(
+        self,
+        lm: LM,
+        lm_kwargs: dict[str, Any],
+        signature: Type[Signature],
+        demos: list[dict[str, Any]],
+        inputs: dict[str, Any],
+    ) -> list[dict[str, Any]]:
+        try:
+            return await super().acall(lm, lm_kwargs, signature, demos, inputs)
+        except Exception as e:
+            # fallback to JSONAdapter
+            from dspy.adapters.json_adapter import JSONAdapter
+
+            if isinstance(e, ContextWindowExceededError) or isinstance(self, JSONAdapter):
+                # On context window exceeded error or already using JSONAdapter, we don't want to retry with a different
+                # adapter.
+                raise e
+            return await JSONAdapter().acall(lm, lm_kwargs, signature, demos, inputs)
+
     def format_field_description(self, signature: Type[Signature]) -> str:
         return (
             f"Your input fields are:\n{get_field_description_string(signature.input_fields)}\n"

dspy/adapters/json_adapter.py

Lines changed: 46 additions & 17 deletions
@@ -22,6 +22,11 @@
 
 logger = logging.getLogger(__name__)
 
+ERROR_MESSAGE_ON_JSON_ADAPTER_FAILURE = (
+    "Both structured output format and JSON mode failed. Please choose a model that supports "
+    "`response_format` argument. Original error: {}"
+)
+
 
 def _has_open_ended_mapping(signature: SignatureMeta) -> bool:
     """
@@ -37,6 +42,18 @@ def _has_open_ended_mapping(signature: SignatureMeta) -> bool:
 
 
 class JSONAdapter(ChatAdapter):
+    def _json_adapter_call_common(self, lm, lm_kwargs, signature, demos, inputs, call_fn):
+        """Common call logic to be used for both sync and async calls."""
+        provider = lm.model.split("/", 1)[0] or "openai"
+        params = litellm.get_supported_openai_params(model=lm.model, custom_llm_provider=provider)
+
+        if not params or "response_format" not in params:
+            return call_fn(lm, lm_kwargs, signature, demos, inputs)
+
+        if _has_open_ended_mapping(signature):
+            lm_kwargs["response_format"] = {"type": "json_object"}
+            return call_fn(lm, lm_kwargs, signature, demos, inputs)
+
     def __call__(
         self,
         lm: LM,
@@ -45,37 +62,49 @@ def __call__(
         demos: list[dict[str, Any]],
         inputs: dict[str, Any],
     ) -> list[dict[str, Any]]:
-        provider = lm.model.split("/", 1)[0] or "openai"
-        params = litellm.get_supported_openai_params(model=lm.model, custom_llm_provider=provider)
+        result = self._json_adapter_call_common(lm, lm_kwargs, signature, demos, inputs, super().__call__)
+        if result:
+            return result
 
-        # If response_format is not supported, use basic call
-        if not params or "response_format" not in params:
+        try:
+            structured_output_model = _get_structured_outputs_response_format(signature)
+            lm_kwargs["response_format"] = structured_output_model
             return super().__call__(lm, lm_kwargs, signature, demos, inputs)
+        except Exception:
+            logger.warning("Failed to use structured output format, falling back to JSON mode.")
+            try:
+                lm_kwargs["response_format"] = {"type": "json_object"}
+                return super().__call__(lm, lm_kwargs, signature, demos, inputs)
+            except AdapterParseError as e:
+                raise e
+            except Exception as e:
+                raise RuntimeError(ERROR_MESSAGE_ON_JSON_ADAPTER_FAILURE.format(e)) from e
 
-        # Check early for open-ended mapping types before trying structured outputs.
-        if _has_open_ended_mapping(signature):
-            lm_kwargs["response_format"] = {"type": "json_object"}
-            return super().__call__(lm, lm_kwargs, signature, demos, inputs)
+    async def acall(
+        self,
+        lm: LM,
+        lm_kwargs: dict[str, Any],
+        signature: Type[Signature],
+        demos: list[dict[str, Any]],
+        inputs: dict[str, Any],
+    ) -> list[dict[str, Any]]:
+        result = self._json_adapter_call_common(lm, lm_kwargs, signature, demos, inputs, super().acall)
+        if result:
+            return await result
 
-        # Try structured output first, fall back to basic JSON if it fails.
         try:
             structured_output_model = _get_structured_outputs_response_format(signature)
             lm_kwargs["response_format"] = structured_output_model
-            return super().__call__(lm, lm_kwargs, signature, demos, inputs)
+            return await super().acall(lm, lm_kwargs, signature, demos, inputs)
         except Exception:
             logger.warning("Failed to use structured output format, falling back to JSON mode.")
            try:
                 lm_kwargs["response_format"] = {"type": "json_object"}
-                return super().__call__(lm, lm_kwargs, signature, demos, inputs)
+                return await super().acall(lm, lm_kwargs, signature, demos, inputs)
             except AdapterParseError as e:
-                # On AdapterParseError, we raise the original error.
                 raise e
             except Exception as e:
-                # On any other error, we raise a RuntimeError with the original error message.
-                raise RuntimeError(
-                    "Both structured output format and JSON mode failed. Please choose a model that supports "
-                    f"`response_format` argument. Original error: {e}"
-                ) from e
+                raise RuntimeError(ERROR_MESSAGE_ON_JSON_ADAPTER_FAILURE.format(e)) from e
 
     def _call_preprocess(
         self,
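
The new _json_adapter_call_common helper front-loads a capability check via litellm before either __call__ or acall tries structured outputs; when neither early branch fires it implicitly returns None, which is why both callers guard with `if result:` before falling through to the structured-output/JSON-mode chain. A small standalone probe of that capability check, with an illustrative model id, might look like this:

import litellm

# Mirrors the first step of _json_adapter_call_common: ask litellm whether the
# provider advertises `response_format` support for this model.
model = "openai/gpt-4o-mini"  # example "provider/model" id
provider = model.split("/", 1)[0] or "openai"
params = litellm.get_supported_openai_params(model=model, custom_llm_provider=provider)
print(params is not None and "response_format" in params)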

tests/adapters/test_chat_adapter.py

Lines changed: 46 additions & 0 deletions
@@ -3,6 +3,7 @@
 
 import pydantic
 import pytest
+from litellm.utils import Choices, Message, ModelResponse
 
 import dspy
 
@@ -376,3 +377,48 @@ class MySignature(dspy.Signature):
     assert messages[2]["content"] == "[[ ## answer ## ]]\nParis\n\n[[ ## completed ## ]]\n"
     assert messages[3]["content"] == "[[ ## question ## ]]\nWhat is the capital of Germany?"
     assert messages[4]["content"] == "[[ ## answer ## ]]\nBerlin\n\n[[ ## completed ## ]]\n"
+
+
+def test_chat_adapter_fallback_to_json_adapter_on_exception():
+    signature = dspy.make_signature("question->answer")
+    adapter = dspy.ChatAdapter()
+
+    with mock.patch("litellm.completion") as mock_completion:
+        # Mock returning a response compatible with JSONAdapter but not ChatAdapter
+        mock_completion.return_value = ModelResponse(
+            choices=[Choices(message=Message(content="{'answer': 'Paris'}"))],
+            model="openai/gpt-4o-mini",
+        )
+
+        lm = dspy.LM("openai/gpt-4o-mini", cache=False)
+
+        with mock.patch("dspy.adapters.json_adapter.JSONAdapter.__call__") as mock_json_adapter_call:
+            adapter(lm, {}, signature, [], {"question": "What is the capital of France?"})
+            mock_json_adapter_call.assert_called_once()
+
+        # The parse should succeed
+        result = adapter(lm, {}, signature, [], {"question": "What is the capital of France?"})
+        assert result == [{"answer": "Paris"}]
+
+
+@pytest.mark.asyncio
+async def test_chat_adapter_fallback_to_json_adapter_on_exception_async():
+    signature = dspy.make_signature("question->answer")
+    adapter = dspy.ChatAdapter()
+
+    with mock.patch("litellm.acompletion") as mock_completion:
+        # Mock returning a response compatible with JSONAdapter but not ChatAdapter
+        mock_completion.return_value = ModelResponse(
+            choices=[Choices(message=Message(content="{'answer': 'Paris'}"))],
+            model="openai/gpt-4o-mini",
+        )
+
+        lm = dspy.LM("openai/gpt-4o-mini", cache=False)
+
+        with mock.patch("dspy.adapters.json_adapter.JSONAdapter.acall") as mock_json_adapter_acall:
+            await adapter.acall(lm, {}, signature, [], {"question": "What is the capital of France?"})
+            mock_json_adapter_acall.assert_called_once()
+
+        # The parse should succeed
+        result = await adapter.acall(lm, {}, signature, [], {"question": "What is the capital of France?"})
+        assert result == [{"answer": "Paris"}]
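
To run just the two new fallback tests locally, an invocation along these lines should work, assuming pytest and pytest-asyncio are installed; the -k expression below is illustrative and not part of the commit.

import pytest

# Select only the new sync and async fallback tests by substring match.
pytest.main([
    "tests/adapters/test_chat_adapter.py",
    "-k", "fallback_to_json_adapter_on_exception",
])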
