Skip to content

Commit e7e0039

Browse files
committed
fix set_lm
1 parent 1e56d8a commit e7e0039

File tree

5 files changed

+47
-42
lines changed

5 files changed

+47
-42
lines changed

docs/docs/index.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -403,11 +403,12 @@ Given a few tens or hundreds of representative _inputs_ of your task and a _metr
403403

404404
```python linenums="1"
405405
import dspy
406-
dspy.configure(lm=dspy.LM("openai/gpt-4o-mini-2024-07-18"))
406+
lm=dspy.LM('openai/gpt-4o-mini-2024-07-18')
407407

408408
# Define the DSPy module for classification. It will use the hint at training time, if available.
409409
signature = dspy.Signature("text, hint -> label").with_updated_fields("label", type_=Literal[tuple(CLASSES)])
410410
classify = dspy.ChainOfThought(signature)
411+
classify.set_lm(lm)
411412

412413
# Optimize via BootstrapFinetune.
413414
optimizer = dspy.BootstrapFinetune(metric=(lambda x, y, trace=None: x.label == y.label), num_threads=24)

docs/docs/learn/optimization/optimizers.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,11 +176,12 @@ optimized_program = teleprompter.compile(YOUR_PROGRAM_HERE, trainset=YOUR_TRAINS
176176

177177
```python linenums="1"
178178
import dspy
179-
dspy.configure(lm=dspy.LM('openai/gpt-4o-mini-2024-07-18'))
179+
lm=dspy.LM('openai/gpt-4o-mini-2024-07-18')
180180

181181
# Define the DSPy module for classification. It will use the hint at training time, if available.
182182
signature = dspy.Signature("text, hint -> label").with_updated_fields('label', type_=Literal[tuple(CLASSES)])
183183
classify = dspy.ChainOfThought(signature)
184+
classify.set_lm(lm)
184185

185186
# Optimize via BootstrapFinetune.
186187
optimizer = dspy.BootstrapFinetune(metric=(lambda x, y, trace=None: x.label == y.label), num_threads=24)

docs/docs/tutorials/classification_finetuning/index.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@
8686
},
8787
{
8888
"cell_type": "code",
89-
"execution_count": 1,
89+
"execution_count": null,
9090
"metadata": {},
9191
"outputs": [],
9292
"source": [
@@ -97,7 +97,7 @@
9797
"\n",
9898
"# Load the Banking77 dataset.\n",
9999
"CLASSES = load_dataset(\"PolyAI/banking77\", split=\"train\", trust_remote_code=True).features['label'].names\n",
100-
"kwargs = dict(fields=(\"text\", \"label\"), input_keys=(\"text\",), split=\"train\", trust_remote_code=True)\n",
100+
"kwargs = dict(fields=(\"text\", \"label\"), input_keys=(\"text\",\"hint\"), split=\"train\", trust_remote_code=True)\n",
101101
"\n",
102102
"# Load the first 2000 examples from the dataset, and assign a hint to each *training* example.\n",
103103
"raw_data = [\n",

dspy/teleprompt/bootstrap_finetune.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -81,19 +81,24 @@ def compile(
8181
key_to_data = {}
8282
for pred_ind, pred in enumerate(student.predictors()):
8383
data_pred_ind = None if self.multitask else pred_ind
84-
lm = pred.lm or settings.lm
85-
training_key = (lm, data_pred_ind)
84+
if pred.lm is None:
85+
raise ValueError(
86+
f"Predictor {pred_ind} does not have an LM assigned. "
87+
f"Please ensure the module's predictors have their LM set before fine-tuning. "
88+
f"You can set it using: your_module.set_lm(your_lm)"
89+
)
90+
training_key = (pred.lm, data_pred_ind)
8691

8792
if training_key not in key_to_data:
8893
train_data, data_format = self._prepare_finetune_data(
89-
trace_data=trace_data, lm=lm, pred_ind=data_pred_ind
94+
trace_data=trace_data, lm=pred.lm, pred_ind=data_pred_ind
9095
)
91-
logger.info(f"Using {len(train_data)} data points for fine-tuning the model: {lm.model}")
96+
logger.info(f"Using {len(train_data)} data points for fine-tuning the model: {pred.lm.model}")
9297
finetune_kwargs = {
93-
"lm": lm,
98+
"lm": pred.lm,
9499
"train_data": train_data,
95100
"train_data_format": data_format,
96-
"train_kwargs": self.train_kwargs[lm],
101+
"train_kwargs": self.train_kwargs[pred.lm],
97102
}
98103
key_to_data[training_key] = finetune_kwargs
99104

Lines changed: 30 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import pytest
2+
13
import dspy
24
from dspy import Example
35
from dspy.predict import Predict
@@ -12,57 +14,53 @@ def simple_metric(example, prediction, trace=None):
1214

1315
examples = [
1416
Example(input="What is the color of the sky?", output="blue").with_inputs("input"),
15-
Example(input="What does the fox say?", output="Ring-ding-ding-ding-dingeringeding!"),
17+
Example(input="What does the fox say?", output="Ring-ding-ding-ding-dingeringeding!").with_inputs("input"),
1618
]
1719
trainset = [examples[0]]
1820

1921

2022
def test_bootstrap_finetune_initialization():
21-
# Initialize BootstrapFinetune with a dummy metric and minimal setup
23+
"""Test that BootstrapFinetune stores the given metric and defaults to multitask mode."""
2224
bootstrap = BootstrapFinetune(metric=simple_metric)
2325
assert bootstrap.metric == simple_metric, "Metric not correctly initialized"
24-
assert bootstrap.multitask, "Multitask should default to True"
26+
assert bootstrap.multitask == True, "Multitask should default to True"
2527

2628

2729
class SimpleModule(dspy.Module):
28-
def __init__(self, signature, lm=None):
30+
def __init__(self, signature):
2931
super().__init__()
3032
self.predictor = Predict(signature)
31-
if lm:
32-
self.predictor.lm = lm
3333

3434
def forward(self, **kwargs):
3535
return self.predictor(**kwargs)
3636

3737

38-
def test_compile_with_predict_instances_no_explicit_lm():
39-
"""Test BootstrapFinetune compile with predictors that don't have explicit LMs."""
40-
from unittest.mock import patch
38+
def test_error_handling_during_bootstrap():
39+
"""Test error handling during the bootstrapping process."""
40+
41+
class BuggyModule(dspy.Module):
42+
def __init__(self, signature):
43+
super().__init__()
44+
self.predictor = Predict(signature)
4145

42-
# Create student and teacher modules without explicit LMs in predictors
43-
student = SimpleModule("input -> output")
44-
teacher = SimpleModule("input -> output")
46+
def forward(self, **kwargs):
47+
raise RuntimeError("Simulated error")
4548

46-
lm = DummyLM(["Initial thoughts", "Finish[blue]"])
49+
student = SimpleModule("input -> output")
50+
teacher = BuggyModule("input -> output")
51+
52+
# Set up DummyLM to simulate an error scenario
53+
lm = DummyLM(
54+
[
55+
{"output": "Initial thoughts"}, # Simulate initial teacher's prediction
56+
]
57+
)
4758
dspy.settings.configure(lm=lm)
4859

49-
# Verify that the predictor doesn't have an explicit LM
50-
assert student.predictor.lm is None
51-
bootstrap = BootstrapFinetune(metric=simple_metric)
52-
53-
# Mock all the components that would fail without proper setup
54-
with patch("dspy.teleprompt.bootstrap_finetune.all_predictors_have_lms"), \
55-
patch("dspy.teleprompt.bootstrap_finetune.prepare_teacher", return_value=teacher), \
56-
patch("dspy.teleprompt.bootstrap_finetune.bootstrap_trace_data", return_value=[]), \
57-
patch.object(bootstrap, "_prepare_finetune_data", return_value=([], "openai")), \
58-
patch.object(bootstrap, "finetune_lms") as mock_finetune_lms:
59-
60-
mock_finetune_lms.return_value = {(lm, None): lm}
61-
62-
# This should not raise AttributeError due to the fix
63-
compiled_student = bootstrap.compile(student, teacher=teacher, trainset=trainset)
64-
65-
assert compiled_student is not None, "Failed to compile student"
66-
mock_finetune_lms.assert_called_once()
67-
60+
bootstrap = BootstrapFinetune(
61+
metric=simple_metric,
62+
max_errors=1,
63+
)
6864

65+
with pytest.raises(RuntimeError, match="Simulated error"):
66+
bootstrap.compile(student, teacher=teacher, trainset=trainset)

0 commit comments

Comments
 (0)