Skip to content

Commit 19eb50b

Browse files
committed
Implement CoT no-op for reasoning models
1 parent 734eff2 commit 19eb50b

File tree

5 files changed

+170
-151
lines changed

5 files changed

+170
-151
lines changed

dspy/clients/lm.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def __init__(
3737
callbacks: Optional[List[BaseCallback]] = None,
3838
num_retries: int = 3,
3939
provider: Optional[Provider] = None,
40+
reasoning_model: Optional[bool] = None,
4041
finetuning_model: Optional[str] = None,
4142
launch_kwargs: Optional[dict[str, Any]] = None,
4243
train_kwargs: Optional[dict[str, Any]] = None,
@@ -51,6 +52,7 @@ def __init__(
5152
model_type: The type of the model, either ``"chat"`` or ``"text"``.
5253
temperature: The sampling temperature to use when generating responses.
5354
max_tokens: The maximum number of tokens to generate per response.
55+
reasoning_model: Whether the model is a reasoning model.
5456
cache: Whether to cache the model responses for reuse to improve performance
5557
and reduce costs.
5658
cache_in_memory (deprecated): To enable additional caching with LRU in memory.
@@ -71,6 +73,7 @@ def __init__(
7173
self.callbacks = callbacks or []
7274
self.history = []
7375
self.num_retries = num_retries
76+
self.reasoning_model = reasoning_model if reasoning_model is not None else litellm.supports_reasoning(model)
7477
self.finetuning_model = finetuning_model
7578
self.launch_kwargs = launch_kwargs or {}
7679
self.train_kwargs = train_kwargs or {}

dspy/predict/chain_of_thought.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,15 @@ def __init__(
2727
"""
2828
super().__init__()
2929
signature = ensure_signature(signature)
30-
prefix = "Reasoning: Let's think step by step in order to"
31-
desc = "${reasoning}"
32-
rationale_field_type = rationale_field.annotation if rationale_field else rationale_field_type
33-
rationale_field = rationale_field if rationale_field else dspy.OutputField(prefix=prefix, desc=desc)
34-
extended_signature = signature.prepend(name="reasoning", field=rationale_field, type_=rationale_field_type)
35-
self.predict = dspy.Predict(extended_signature, **config)
30+
if dspy.settings.lm.reasoning_model:
31+
self.predict = dspy.Predict(signature, **config)
32+
else:
33+
prefix = "Reasoning: Let's think step by step in order to"
34+
desc = "${reasoning}"
35+
rationale_field_type = rationale_field.annotation if rationale_field else rationale_field_type
36+
rationale_field = rationale_field if rationale_field else dspy.OutputField(prefix=prefix, desc=desc)
37+
extended_signature = signature.prepend(name="reasoning", field=rationale_field, type_=rationale_field_type)
38+
self.predict = dspy.Predict(extended_signature, **config)
3639

3740
def forward(self, **kwargs):
3841
return self.predict(**kwargs)

pyproject.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ dependencies = [
3232
"optuna>=3.4.0",
3333
"pydantic>=2.0",
3434
"magicattr>=0.1.6",
35-
"litellm>=1.64.0",
35+
"litellm>=1.72.4",
3636
"diskcache>=5.6.0",
3737
"json-repair>=0.30.0",
3838
"tenacity>=8.2.3",
@@ -59,8 +59,8 @@ dev = [
5959
"pillow>=10.1.0",
6060
"datamodel_code_generator>=0.26.3",
6161
"build>=1.0.3",
62-
"litellm>=1.64.0; sys_platform == 'win32'",
63-
"litellm[proxy]>=1.64.0; sys_platform != 'win32'",
62+
"litellm>=1.72.4; sys_platform == 'win32'",
63+
"litellm[proxy]>=1.72.4; sys_platform != 'win32'",
6464
]
6565
test_extras = [
6666
"mcp; python_version >= '3.10'",

tests/predict/test_chain_of_thought.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,15 @@ async def test_async_chain_of_thought():
2323
program = ChainOfThought("question -> answer")
2424
result = await program.acall(question="What is 1+1?")
2525
assert result.answer == "2"
26+
27+
28+
def test_cot_skips_with_reasoning_model():
29+
lm = DummyLM([{"answer": "2"}])
30+
lm.reasoning_model = True
31+
dspy.settings.configure(lm=lm)
32+
signature = dspy.Signature("question -> answer")
33+
predict = ChainOfThought(signature)
34+
assert list(predict.predict.signature.output_fields.keys()) == [
35+
"answer",
36+
]
37+
assert predict(question="What is 1+1?").answer == "2"

0 commit comments

Comments (0)