Skip to content

Commit 5473f11

Browse files
committed
Support late-configuration of LM to fix certain tests.
1 parent 8f2d726 commit 5473f11

File tree

2 files changed

+37
-12
lines changed

2 files changed

+37
-12
lines changed

dspy/predict/chain_of_thought.py

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from pydantic.fields import FieldInfo
44

55
import dspy
6+
from dspy.clients.base_lm import BaseLM
67
from dspy.primitives.module import Module
78
from dspy.signatures.field import OutputField
89
from dspy.signatures.signature import Signature, ensure_signature
@@ -27,18 +28,42 @@ def __init__(
2728
"""
2829
super().__init__()
2930
signature = ensure_signature(signature)
30-
if dspy.settings.lm.reasoning_model:
31-
self.predict = dspy.Predict(signature, **config)
32-
else:
33-
prefix = "Reasoning: Let's think step by step in order to"
34-
desc = "${reasoning}"
35-
rationale_field_type = rationale_field.annotation if rationale_field else rationale_field_type
36-
rationale_field = rationale_field if rationale_field else dspy.OutputField(prefix=prefix, desc=desc)
37-
extended_signature = signature.prepend(name="reasoning", field=rationale_field, type_=rationale_field_type)
38-
self.predict = dspy.Predict(extended_signature, **config)
31+
self.predict_reasoning = dspy.Predict(signature, **config)
32+
prefix = "Reasoning: Let's think step by step in order to"
33+
desc = "${reasoning}"
34+
rationale_field_type = rationale_field.annotation if rationale_field else rationale_field_type
35+
rationale_field = rationale_field if rationale_field else dspy.OutputField(prefix=prefix, desc=desc)
36+
extended_signature = signature.prepend(name="reasoning", field=rationale_field, type_=rationale_field_type)
37+
self.predict = dspy.Predict(extended_signature, **config)
38+
39+
def _validate_lm(self, **kwargs):
40+
"""Helper method to validate the LM configuration."""
41+
lm = kwargs.get("lm", dspy.settings.lm)
42+
43+
if lm is None:
44+
raise ValueError(
45+
"No LM is loaded. Please configure the LM using `dspy.configure(lm=dspy.LM(...))`. e.g, "
46+
"`dspy.configure(lm=dspy.LM('openai/gpt-4o-mini'))`"
47+
)
48+
49+
if isinstance(lm, str):
50+
# Many users mistakenly use `dspy.configure(lm="openai/gpt-4o-mini")` instead of
51+
# `dspy.configure(lm=dspy.LM("openai/gpt-4o-mini"))`, so we are providing a specific error message.
52+
raise ValueError(
53+
f"LM must be an instance of `dspy.BaseLM`, not a string. Instead of using a string like "
54+
f"'dspy.configure(lm=\"{lm}\")', please configure the LM like 'dspy.configure(lm=dspy.LM(\"{lm}\"))'"
55+
)
56+
elif not isinstance(lm, BaseLM):
57+
raise ValueError(f"LM must be an instance of `dspy.BaseLM`, not {type(lm)}. Received `lm={lm}`.")
58+
59+
return lm
3960

4061
def forward(self, **kwargs):
41-
return self.predict(**kwargs)
62+
lm = self._validate_lm(**kwargs)
63+
return self.predict(**kwargs) if not lm.reasoning_model else self.predict_reasoning(**kwargs)
4264

4365
async def aforward(self, **kwargs):
44-
return await self.predict.acall(**kwargs)
66+
lm = self._validate_lm(**kwargs)
67+
return await (
68+
self.predict.acall(**kwargs) if not lm.reasoning_model else self.predict_reasoning.acall(**kwargs)
69+
)

tests/predict/test_chain_of_thought.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def test_cot_skips_with_reasoning_model():
3131
dspy.settings.configure(lm=lm)
3232
signature = dspy.Signature("question -> answer")
3333
predict = ChainOfThought(signature)
34-
assert list(predict.predict.signature.output_fields.keys()) == [
34+
assert list(predict.predict_reasoning.signature.output_fields.keys()) == [
3535
"answer",
3636
]
3737
assert predict(question="What is 1+1?").answer == "2"

0 commit comments

Comments
 (0)