Commit d3abb63

support predicted output
1 parent 015f795 commit d3abb63

3 files changed: +38 -3 lines changed


dspy/clients/lm.py

Lines changed: 3 additions & 3 deletions
@@ -83,9 +83,9 @@ def __init__(
 
         if model_pattern:
             # Handle OpenAI reasoning models (o1, o3)
-            assert (
-                max_tokens >= 20_000 and temperature == 1.0
-            ), "OpenAI's reasoning models require passing temperature=1.0 and max_tokens >= 20_000 to `dspy.LM(...)`"
+            assert max_tokens >= 20_000 and temperature == 1.0, (
+                "OpenAI's reasoning models require passing temperature=1.0 and max_tokens >= 20_000 to `dspy.LM(...)`"
+            )
             self.kwargs = dict(temperature=temperature, max_completion_tokens=max_tokens, **kwargs)
         else:
             self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs)
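For reference, the assertion above (only reformatted in this commit) means the caller must opt into reasoning-model settings explicitly when constructing the client. A minimal sketch, assuming an o1-family model name such as "openai/o1" (the name here is illustrative):

import dspy

# Reasoning models must be configured with temperature=1.0 and max_tokens >= 20_000,
# otherwise the assert in dspy/clients/lm.py fires at construction time.
lm = dspy.LM("openai/o1", temperature=1.0, max_tokens=20_000)
dspy.settings.configure(lm=lm)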

dspy/predict/predict.py

Lines changed: 11 additions & 0 deletions
@@ -108,6 +108,17 @@ def _forward_preprocess(self, **kwargs):
         if (temperature is None or temperature <= 0.15) and num_generations > 1:
             config["temperature"] = 0.7
 
+        if "prediction" in kwargs:
+            if (
+                isinstance(kwargs["prediction"], dict)
+                and kwargs["prediction"].get("type") == "content"
+                and "content" in kwargs["prediction"]
+            ):
+                # If the `prediction` is in the standard predicted-outputs format
+                # (https://platform.openai.com/docs/guides/predicted-outputs), we remove it from the
+                # input kwargs and add it to the lm kwargs.
+                config["prediction"] = kwargs.pop("prediction")
+
         if not all(k in kwargs for k in signature.input_fields):
             present = [k for k in signature.input_fields if k in kwargs]
             missing = [k for k in signature.input_fields if k not in kwargs]
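In effect, the new branch above lets a predicted-outputs payload ride along with a Predict call: when `prediction` matches OpenAI's predicted-outputs shape, it is popped from the module's input kwargs and forwarded to the LM call. A minimal usage sketch, mirroring the new test below (the model name is illustrative):

import dspy

qa = dspy.Predict("question -> answer")
dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini"))

# The dict follows OpenAI's predicted-outputs format ({"type": "content", "content": ...});
# _forward_preprocess pops it from the input kwargs and adds it to the LM kwargs.
result = qa(
    question="Why did a chicken cross the kitchen?",
    prediction={"type": "content", "content": "A chicken crossing the kitchen"},
)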

tests/predict/test_predict.py

Lines changed: 24 additions & 0 deletions
@@ -572,3 +572,27 @@ async def test_async_predict():
     dspy.settings.configure(lm=DummyLM([{"answer": "Paris"}]))
     result = await program.acall(question="What is the capital of France?")
     assert result.answer == "Paris"
+
+
+def test_predicted_outputs_piped_from_predict_to_lm_call():
+    program = Predict("question -> answer")
+    dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini"))
+
+    with patch("litellm.completion") as mock_completion:
+        program(
+            question="Why did a chicken cross the kitchen?",
+            prediction={"type": "content", "content": "A chicken crossing the kitchen"},
+        )
+
+    assert mock_completion.call_args[1]["prediction"] == {
+        "type": "content",
+        "content": "A chicken crossing the kitchen",
+    }
+
+    # If the signature has prediction as an input field, and the prediction is not set as the standard
+    # predicted-outputs format, it should not be passed to the LM.
+    program = Predict("question, prediction -> judgement")
+    with patch("litellm.completion") as mock_completion:
+        program(question="Why did a chicken cross the kitchen?", prediction="To get to the other side!")
+
+    assert "prediction" not in mock_completion.call_args[1]
