Skip to content

Commit 903e7e4

Browse files
committed
Fix circular import and serialization issues.
1 parent ad9f5c1 commit 903e7e4

File tree

2 files changed

+52
-35
lines changed

2 files changed

+52
-35
lines changed

dspy/predict/chain_of_thought.py

Lines changed: 51 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,9 @@
1+
from functools import cached_property
12
from typing import Any, Optional, Type, Union
23

34
from pydantic.fields import FieldInfo
45

56
import dspy
6-
from dspy.clients.base_lm import BaseLM
77
from dspy.primitives.module import Module
88
from dspy.signatures.field import OutputField
99
from dspy.signatures.signature import Signature, ensure_signature
@@ -27,43 +27,60 @@ def __init__(
2727
**config: The configuration for the module.
2828
"""
2929
super().__init__()
30-
signature = ensure_signature(signature)
31-
self.predict_reasoning = dspy.Predict(signature, **config)
32-
prefix = "Reasoning: Let's think step by step in order to"
33-
desc = "${reasoning}"
34-
rationale_field_type = rationale_field.annotation if rationale_field else rationale_field_type
35-
rationale_field = rationale_field if rationale_field else dspy.OutputField(prefix=prefix, desc=desc)
36-
extended_signature = signature.prepend(name="reasoning", field=rationale_field, type_=rationale_field_type)
37-
self.predict = dspy.Predict(extended_signature, **config)
30+
self._signature = ensure_signature(signature)
31+
self._config = config
32+
self._rationale_field = rationale_field
33+
self._rationale_field_type = rationale_field_type
3834

39-
def _validate_lm(self, **kwargs):
40-
"""Helper method to validate the LM configuration."""
41-
lm = kwargs.get("lm", dspy.settings.lm)
42-
43-
if lm is None:
44-
raise ValueError(
45-
"No LM is loaded. Please configure the LM using `dspy.configure(lm=dspy.LM(...))`. e.g, "
46-
"`dspy.configure(lm=dspy.LM('openai/gpt-4o-mini'))`"
35+
@cached_property
36+
def predict(self):
37+
"""Returns the appropriate predict instance based on the LM's reasoning model capability."""
38+
lm = dspy.settings.lm
39+
if lm and getattr(lm, "reasoning_model", False):
40+
return dspy.Predict(self._signature, **self._config)
41+
else:
42+
prefix = "Reasoning: Let's think step by step in order to"
43+
desc = "${reasoning}"
44+
rationale_field_type = (
45+
self._rationale_field.annotation if self._rationale_field else self._rationale_field_type
4746
)
48-
49-
if isinstance(lm, str):
50-
# Many users mistakenly use `dspy.configure(lm="openai/gpt-4o-mini")` instead of
51-
# `dspy.configure(lm=dspy.LM("openai/gpt-4o-mini"))`, so we are providing a specific error message.
52-
raise ValueError(
53-
f"LM must be an instance of `dspy.BaseLM`, not a string. Instead of using a string like "
54-
f"'dspy.configure(lm=\"{lm}\")', please configure the LM like 'dspy.configure(lm=dspy.LM(\"{lm}\"))'"
47+
rationale_field = (
48+
self._rationale_field if self._rationale_field else dspy.OutputField(prefix=prefix, desc=desc)
5549
)
56-
elif not isinstance(lm, BaseLM):
57-
raise ValueError(f"LM must be an instance of `dspy.BaseLM`, not {type(lm)}. Received `lm={lm}`.")
58-
59-
return lm
50+
extended_signature = self._signature.prepend(
51+
name="reasoning", field=rationale_field, type_=rationale_field_type
52+
)
53+
return dspy.Predict(extended_signature, **self._config)
6054

6155
def forward(self, **kwargs):
62-
lm = self._validate_lm(**kwargs)
63-
return self.predict(**kwargs) if not lm.reasoning_model else self.predict_reasoning(**kwargs)
56+
return self.predict(**kwargs)
6457

6558
async def aforward(self, **kwargs):
66-
lm = self._validate_lm(**kwargs)
67-
return await (
68-
self.predict.acall(**kwargs) if not lm.reasoning_model else self.predict_reasoning.acall(**kwargs)
69-
)
59+
return await self.predict.acall(**kwargs)
60+
61+
def load_state(self, state):
62+
"""Override to ensure predict parameter is created before loading state."""
63+
# If predict state exists but predict hasn't been accessed yet, access it first
64+
if "predict" in state and "predict" not in self.__dict__:
65+
_ = self.predict # This creates the predict instance
66+
67+
# Now call the base load_state which will load into all named_parameters
68+
return super().load_state(state)
69+
70+
def __setstate__(self, state):
71+
"""Custom deserialization for cloudpickle to preserve predict instance."""
72+
# Restore the state normally
73+
self.__dict__.update(state)
74+
75+
# If predict was cached and serialized, we don't need to do anything special
76+
# since cloudpickle should have preserved it correctly
77+
78+
def __getstate__(self):
79+
"""Custom serialization for cloudpickle to ensure predict instance is preserved."""
80+
state = self.__dict__.copy()
81+
# Force evaluation of cached property if not already done
82+
if "predict" not in state:
83+
# Access the predict property to cache it before serialization
84+
_ = self.predict
85+
state = self.__dict__.copy()
86+
return state

tests/predict/test_chain_of_thought.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,7 @@ def test_cot_skips_with_reasoning_model():
3131
dspy.settings.configure(lm=lm)
3232
signature = dspy.Signature("question -> answer")
3333
predict = ChainOfThought(signature)
34-
assert list(predict.predict_reasoning.signature.output_fields.keys()) == [
34+
assert list(predict.predict.signature.output_fields.keys()) == [
3535
"answer",
3636
]
3737
assert predict(question="What is 1+1?").answer == "2"

0 commit comments

Comments (0)