3 files changed, +38 −3 lines changed
@@ -83,9 +83,9 @@ def __init__(
 
         if model_pattern:
             # Handle OpenAI reasoning models (o1, o3)
-            assert (
-                max_tokens >= 20_000 and temperature == 1.0
-            ), "OpenAI's reasoning models require passing temperature=1.0 and max_tokens >= 20_000 to `dspy.LM(...)`"
+            assert max_tokens >= 20_000 and temperature == 1.0, (
+                "OpenAI's reasoning models require passing temperature=1.0 and max_tokens >= 20_000 to `dspy.LM(...)`"
+            )
             self.kwargs = dict(temperature=temperature, max_completion_tokens=max_tokens, **kwargs)
         else:
             self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs)
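For context, a minimal caller-side sketch of what this assertion expects (the model name is illustrative; any o1/o3-family identifier matching the reasoning-model pattern triggers the same check):

import dspy

# Reasoning models must be configured with temperature=1.0 and max_tokens >= 20_000,
# otherwise the assertion above fires at construction time.
lm = dspy.LM("openai/o1", temperature=1.0, max_tokens=20_000)
dspy.settings.configure(lm=lm)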
@@ -108,6 +108,17 @@ def _forward_preprocess(self, **kwargs):
         if (temperature is None or temperature <= 0.15) and num_generations > 1:
             config["temperature"] = 0.7
 
+        if "prediction" in kwargs:
+            if (
+                isinstance(kwargs["prediction"], dict)
+                and kwargs["prediction"].get("type") == "content"
+                and "content" in kwargs["prediction"]
+            ):
+                # If `prediction` is in the standard predicted-outputs format
+                # (https://platform.openai.com/docs/guides/predicted-outputs), we remove it from the input
+                # kwargs and add it to the lm kwargs.
+                config["prediction"] = kwargs.pop("prediction")
+
         if not all(k in kwargs for k in signature.input_fields):
             present = [k for k in signature.input_fields if k in kwargs]
             missing = [k for k in signature.input_fields if k not in kwargs]
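For reference, a hedged caller-side sketch of the behavior this hunk adds (model and inputs are illustrative): a `prediction` kwarg shaped like OpenAI's predicted-outputs payload is lifted out of the signature inputs and forwarded to the LM call instead.

import dspy

dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini"))
program = dspy.Predict("question -> answer")

# This dict matches the predicted-outputs shape checked above, so it is popped from the
# input kwargs and passed through to the underlying completion call as `prediction`.
program(
    question="Why did a chicken cross the kitchen?",
    prediction={"type": "content", "content": "A chicken crossing the kitchen"},
)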
@@ -572,3 +572,27 @@ async def test_async_predict():
     dspy.settings.configure(lm=DummyLM([{"answer": "Paris"}]))
     result = await program.acall(question="What is the capital of France?")
     assert result.answer == "Paris"
+
+
+def test_predicted_outputs_piped_from_predict_to_lm_call():
+    program = Predict("question -> answer")
+    dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini"))
+
+    with patch("litellm.completion") as mock_completion:
+        program(
+            question="Why did a chicken cross the kitchen?",
+            prediction={"type": "content", "content": "A chicken crossing the kitchen"},
+        )
+
+    assert mock_completion.call_args[1]["prediction"] == {
+        "type": "content",
+        "content": "A chicken crossing the kitchen",
+    }
+
+    # If the signature has `prediction` as an input field and its value is not in the standard
+    # predicted-outputs format, it should not be passed to the LM.
+    program = Predict("question, prediction -> judgement")
+    with patch("litellm.completion") as mock_completion:
+        program(question="Why did a chicken cross the kitchen?", prediction="To get to the other side!")
+
+    assert "prediction" not in mock_completion.call_args[1]