Skip to content

Commit 734eff2

Browse files
allow error to pass through in JSONAdapter (#8445)
1 parent 0d0824e commit 734eff2

File tree

2 files changed

+20
-25
lines changed

2 files changed

+20
-25
lines changed

dspy/adapters/json_adapter.py

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,6 @@
2222

2323
logger = logging.getLogger(__name__)
2424

25-
ERROR_MESSAGE_ON_JSON_ADAPTER_FAILURE = (
26-
"Both structured output format and JSON mode failed. Please choose a model that supports "
27-
"`response_format` argument. Original error: {}"
28-
)
29-
3025

3126
def _has_open_ended_mapping(signature: SignatureMeta) -> bool:
3227
"""
@@ -72,13 +67,8 @@ def __call__(
7267
return super().__call__(lm, lm_kwargs, signature, demos, inputs)
7368
except Exception:
7469
logger.warning("Failed to use structured output format, falling back to JSON mode.")
75-
try:
76-
lm_kwargs["response_format"] = {"type": "json_object"}
77-
return super().__call__(lm, lm_kwargs, signature, demos, inputs)
78-
except AdapterParseError as e:
79-
raise e
80-
except Exception as e:
81-
raise RuntimeError(ERROR_MESSAGE_ON_JSON_ADAPTER_FAILURE.format(e)) from e
70+
lm_kwargs["response_format"] = {"type": "json_object"}
71+
return super().__call__(lm, lm_kwargs, signature, demos, inputs)
8272

8373
async def acall(
8474
self,
@@ -98,13 +88,8 @@ async def acall(
9888
return await super().acall(lm, lm_kwargs, signature, demos, inputs)
9989
except Exception:
10090
logger.warning("Failed to use structured output format, falling back to JSON mode.")
101-
try:
102-
lm_kwargs["response_format"] = {"type": "json_object"}
103-
return await super().acall(lm, lm_kwargs, signature, demos, inputs)
104-
except AdapterParseError as e:
105-
raise e
106-
except Exception as e:
107-
raise RuntimeError(ERROR_MESSAGE_ON_JSON_ADAPTER_FAILURE.format(e)) from e
91+
lm_kwargs["response_format"] = {"type": "json_object"}
92+
return await super().acall(lm, lm_kwargs, signature, demos, inputs)
10893

10994
def _call_preprocess(
11095
self,

tests/adapters/test_json_adapter.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from litellm.utils import Choices, Message, ModelResponse
66

77
import dspy
8-
from dspy.adapters.json_adapter import ERROR_MESSAGE_ON_JSON_ADAPTER_FAILURE
98

109

1110
def test_json_adapter_passes_structured_output_when_supported_by_model():
@@ -616,12 +615,18 @@ class TestSignature(dspy.Signature):
616615
dspy.configure(lm=dspy.LM(model="openai/gpt-4o-mini", cache=False), adapter=dspy.JSONAdapter())
617616

618617
with mock.patch("litellm.completion") as mock_completion:
619-
mock_completion.side_effect = RuntimeError("Failed!")
618+
mock_completion.side_effect = RuntimeError("RuntimeError!")
620619

621620
with pytest.raises(RuntimeError) as error:
622621
program(question="Dummy question!")
623622

624-
assert ERROR_MESSAGE_ON_JSON_ADAPTER_FAILURE[:50] in str(error.value)
623+
assert "RuntimeError!" in str(error.value)
624+
625+
mock_completion.side_effect = ValueError("ValueError!")
626+
with pytest.raises(ValueError) as error:
627+
program(question="Dummy question!")
628+
629+
assert "ValueError!" in str(error.value)
625630

626631

627632
@pytest.mark.asyncio
@@ -633,10 +638,15 @@ class TestSignature(dspy.Signature):
633638
program = dspy.Predict(TestSignature)
634639

635640
with mock.patch("litellm.acompletion") as mock_acompletion:
636-
mock_acompletion.side_effect = RuntimeError("Failed!")
637-
638641
with dspy.context(lm=dspy.LM(model="openai/gpt-4o-mini", cache=False), adapter=dspy.JSONAdapter()):
642+
mock_acompletion.side_effect = RuntimeError("RuntimeError!")
639643
with pytest.raises(RuntimeError) as error:
640644
await program.acall(question="Dummy question!")
641645

642-
assert ERROR_MESSAGE_ON_JSON_ADAPTER_FAILURE[:50] in str(error.value)
646+
assert "RuntimeError!" in str(error.value)
647+
648+
mock_acompletion.side_effect = ValueError("ValueError!")
649+
with pytest.raises(ValueError) as error:
650+
await program.acall(question="Dummy question!")
651+
652+
assert "ValueError!" in str(error.value)

0 commit comments

Comments
 (0)