Skip to content

Commit c8b1c4e

Browse files
committed
Implement CoT no-op for reasoning models
1 parent aa0a2fc commit c8b1c4e

File tree

5 files changed

+235
-42
lines changed

dspy/clients/lm.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def __init__(
3737
callbacks: Optional[List[BaseCallback]] = None,
3838
num_retries: int = 3,
3939
provider=None,
40+
reasoning_model: Optional[bool] = None,
4041
finetuning_model: Optional[str] = None,
4142
launch_kwargs: Optional[dict[str, Any]] = None,
4243
train_kwargs: Optional[dict[str, Any]] = None,
@@ -51,6 +52,7 @@ def __init__(
5152
model_type: The type of the model, either ``"chat"`` or ``"text"``.
5253
temperature: The sampling temperature to use when generating responses.
5354
max_tokens: The maximum number of tokens to generate per response.
55+
reasoning_model: Whether the model is a reasoning model.
5456
cache: Whether to cache the model responses for reuse to improve performance
5557
and reduce costs.
5658
cache_in_memory (deprecated): To enable additional caching with LRU in memory.
@@ -71,6 +73,7 @@ def __init__(
7173
self.callbacks = callbacks or []
7274
self.history = []
7375
self.num_retries = num_retries
76+
self.reasoning_model = reasoning_model if reasoning_model is not None else litellm.supports_reasoning(model)
7477
self.finetuning_model = finetuning_model
7578
self.launch_kwargs = launch_kwargs or {}
7679
self.train_kwargs = train_kwargs or {}

dspy/predict/chain_of_thought.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,15 @@ def __init__(
2727
"""
2828
super().__init__()
2929
signature = ensure_signature(signature)
30-
prefix = "Reasoning: Let's think step by step in order to"
31-
desc = "${reasoning}"
32-
rationale_field_type = rationale_field.annotation if rationale_field else rationale_field_type
33-
rationale_field = rationale_field if rationale_field else dspy.OutputField(prefix=prefix, desc=desc)
34-
extended_signature = signature.prepend(name="reasoning", field=rationale_field, type_=rationale_field_type)
35-
self.predict = dspy.Predict(extended_signature, **config)
30+
if dspy.settings.lm.reasoning_model:
31+
self.predict = dspy.Predict(signature, **config)
32+
else:
33+
prefix = "Reasoning: Let's think step by step in order to"
34+
desc = "${reasoning}"
35+
rationale_field_type = rationale_field.annotation if rationale_field else rationale_field_type
36+
rationale_field = rationale_field if rationale_field else dspy.OutputField(prefix=prefix, desc=desc)
37+
extended_signature = signature.prepend(name="reasoning", field=rationale_field, type_=rationale_field_type)
38+
self.predict = dspy.Predict(extended_signature, **config)
3639

3740
def forward(self, **kwargs):
3841
return self.predict(**kwargs)

pyproject.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ dependencies = [
3434
"optuna>=3.4.0",
3535
"pydantic>=2.0",
3636
"magicattr>=0.1.6",
37-
"litellm>=1.60.3",
37+
"litellm>=1.72.4",
3838
"diskcache>=5.6.0",
3939
"json-repair>=0.30.0",
4040
"tenacity>=8.2.3",
@@ -50,7 +50,7 @@ dependencies = [
5050
[project.optional-dependencies]
5151
anthropic = ["anthropic>=0.18.0,<1.0.0"]
5252
weaviate = ["weaviate-client~=4.5.4"]
53-
aws = ["boto3~=1.34.78"]
53+
aws = ["boto3~=1.34.34"]
5454
mcp = ["mcp; python_version >= '3.10'"]
5555
langchain = ["langchain_core"]
5656
dev = [
@@ -62,8 +62,8 @@ dev = [
6262
"pillow>=10.1.0",
6363
"datamodel_code_generator>=0.26.3",
6464
"build>=1.0.3",
65-
"litellm>=1.60.3; sys_platform == 'win32'",
66-
"litellm[proxy]>=1.60.3; sys_platform != 'win32'",
65+
"litellm>=1.72.4; sys_platform == 'win32'",
66+
"litellm[proxy]>=1.72.4; sys_platform != 'win32'",
6767
]
6868
test_extras = [
6969
"mcp; python_version >= '3.10'",

tests/predict/test_chain_of_thought.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,15 @@ async def test_async_chain_of_thought():
2323
program = ChainOfThought("question -> answer")
2424
result = await program.acall(question="What is 1+1?")
2525
assert result.answer == "2"
26+
27+
28+
def test_cot_skips_with_reasoning_model():
29+
lm = DummyLM([{"answer": "2"}])
30+
lm.reasoning_model = True
31+
dspy.settings.configure(lm=lm)
32+
signature = dspy.Signature("question -> answer")
33+
predict = ChainOfThought(signature)
34+
assert list(predict.predict.signature.output_fields.keys()) == [
35+
"answer",
36+
]
37+
assert predict(question="What is 1+1?").answer == "2"

0 commit comments

Comments (0)