1
+ from functools import cached_property
1
2
from typing import Any , Optional , Type , Union
2
3
3
4
from pydantic .fields import FieldInfo
4
5
5
6
import dspy
6
- from dspy .clients .base_lm import BaseLM
7
7
from dspy .primitives .module import Module
8
8
from dspy .signatures .field import OutputField
9
9
from dspy .signatures .signature import Signature , ensure_signature
@@ -27,43 +27,60 @@ def __init__(
27
27
**config: The configuration for the module.
28
28
"""
29
29
super ().__init__ ()
30
- signature = ensure_signature (signature )
31
- self .predict_reasoning = dspy .Predict (signature , ** config )
32
- prefix = "Reasoning: Let's think step by step in order to"
33
- desc = "${reasoning}"
34
- rationale_field_type = rationale_field .annotation if rationale_field else rationale_field_type
35
- rationale_field = rationale_field if rationale_field else dspy .OutputField (prefix = prefix , desc = desc )
36
- extended_signature = signature .prepend (name = "reasoning" , field = rationale_field , type_ = rationale_field_type )
37
- self .predict = dspy .Predict (extended_signature , ** config )
30
+ self ._signature = ensure_signature (signature )
31
+ self ._config = config
32
+ self ._rationale_field = rationale_field
33
+ self ._rationale_field_type = rationale_field_type
38
34
39
- def _validate_lm (self , ** kwargs ):
40
- """Helper method to validate the LM configuration."""
41
- lm = kwargs .get ("lm" , dspy .settings .lm )
42
-
43
- if lm is None :
44
- raise ValueError (
45
- "No LM is loaded. Please configure the LM using `dspy.configure(lm=dspy.LM(...))`. e.g, "
46
- "`dspy.configure(lm=dspy.LM('openai/gpt-4o-mini'))`"
35
+ @cached_property
36
+ def predict (self ):
37
+ """Returns the appropriate predict instance based on the LM's reasoning model capability."""
38
+ lm = dspy .settings .lm
39
+ if lm and getattr (lm , "reasoning_model" , False ):
40
+ return dspy .Predict (self ._signature , ** self ._config )
41
+ else :
42
+ prefix = "Reasoning: Let's think step by step in order to"
43
+ desc = "${reasoning}"
44
+ rationale_field_type = (
45
+ self ._rationale_field .annotation if self ._rationale_field else self ._rationale_field_type
47
46
)
48
-
49
- if isinstance (lm , str ):
50
- # Many users mistakenly use `dspy.configure(lm="openai/gpt-4o-mini")` instead of
51
- # `dspy.configure(lm=dspy.LM("openai/gpt-4o-mini"))`, so we are providing a specific error message.
52
- raise ValueError (
53
- f"LM must be an instance of `dspy.BaseLM`, not a string. Instead of using a string like "
54
- f"'dspy.configure(lm=\" { lm } \" )', please configure the LM like 'dspy.configure(lm=dspy.LM(\" { lm } \" ))'"
47
+ rationale_field = (
48
+ self ._rationale_field if self ._rationale_field else dspy .OutputField (prefix = prefix , desc = desc )
55
49
)
56
- elif not isinstance ( lm , BaseLM ):
57
- raise ValueError ( f"LM must be an instance of `dspy.BaseLM`, not { type ( lm ) } . Received `lm= { lm } `." )
58
-
59
- return lm
50
+ extended_signature = self . _signature . prepend (
51
+ name = "reasoning" , field = rationale_field , type_ = rationale_field_type
52
+ )
53
+ return dspy . Predict ( extended_signature , ** self . _config )
60
54
61
55
def forward (self , ** kwargs ):
62
- lm = self ._validate_lm (** kwargs )
63
- return self .predict (** kwargs ) if not lm .reasoning_model else self .predict_reasoning (** kwargs )
56
+ return self .predict (** kwargs )
64
57
65
58
async def aforward (self , ** kwargs ):
66
- lm = self ._validate_lm (** kwargs )
67
- return await (
68
- self .predict .acall (** kwargs ) if not lm .reasoning_model else self .predict_reasoning .acall (** kwargs )
69
- )
59
+ return await self .predict .acall (** kwargs )
60
+
61
+ def load_state (self , state ):
62
+ """Override to ensure predict parameter is created before loading state."""
63
+ # If predict state exists but predict hasn't been accessed yet, access it first
64
+ if "predict" in state and "predict" not in self .__dict__ :
65
+ _ = self .predict # This creates the predict instance
66
+
67
+ # Now call the base load_state which will load into all named_parameters
68
+ return super ().load_state (state )
69
+
70
+ def __setstate__ (self , state ):
71
+ """Custom deserialization for cloudpickle to preserve predict instance."""
72
+ # Restore the state normally
73
+ self .__dict__ .update (state )
74
+
75
+ # If predict was cached and serialized, we don't need to do anything special
76
+ # since cloudpickle should have preserved it correctly
77
+
78
+ def __getstate__ (self ):
79
+ """Custom serialization for cloudpickle to ensure predict instance is preserved."""
80
+ state = self .__dict__ .copy ()
81
+ # Force evaluation of cached property if not already done
82
+ if "predict" not in state :
83
+ # Access the predict property to cache it before serialization
84
+ _ = self .predict
85
+ state = self .__dict__ .copy ()
86
+ return state
0 commit comments