
Commit 554262b

Commit message: "done"

1 parent: c396217

File tree: 2 files changed (+34 −5 lines)

dspy/predict/chain_of_thought_with_hint.py (new file)

Lines changed: 28 additions & 0 deletions

@@ -0,0 +1,28 @@
+import dspy
+
+from .predict import Module
+
+
+class ChainOfThoughtWithHint(Module):
+    def __init__(self, signature, rationale_field_type=None, **config):
+        self.signature = dspy.ensure_signature(signature)
+        self.module = dspy.ChainOfThought(signature, rationale_field_type=rationale_field_type, **config)
+
+    def forward(self, **kwargs):
+        if kwargs.get("hint"):
+            hint = f"\n\t\t(secret hint: {kwargs.pop('hint')})"
+            original_kwargs = kwargs.copy()
+
+            # Convert the last input field's value to a string and append the hint.
+            last_key = list(self.signature.input_fields.keys())[-1]
+            kwargs[last_key] = str(kwargs[last_key]) + hint
+
+            # Run CoT, then update the trace with the original kwargs, i.e. without the hint.
+            with dspy.context(trace=[]):
+                pred = self.module(**kwargs)
+                this_trace = dspy.settings.trace[-1]
+
+            dspy.settings.trace.append((this_trace[0], original_kwargs, this_trace[2]))
+            return pred
+
+        return self.module(**kwargs)
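
For orientation, a minimal usage sketch of the new module follows. This is not part of the commit: the import path is inferred from the file location, and the model name, signature string, and questions are illustrative assumptions.

import dspy
from dspy.predict.chain_of_thought_with_hint import ChainOfThoughtWithHint

# Illustrative model choice; any dspy-supported LM works here.
dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini"))

cot = ChainOfThoughtWithHint("question -> answer")

# Without a hint, this behaves like plain dspy.ChainOfThought.
print(cot(question="What is 7 * 6?").answer)

# With a hint, forward() appends the hint to the last input field before the
# call, then records the original, hint-free kwargs in the trace.
print(cot(question="What is 7 * 6?", hint="Recall that 7 * 6 = 7 * 5 + 7.").answer)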

dspy/teleprompt/bootstrap_finetune.py

Lines changed: 6 additions & 5 deletions

@@ -81,17 +81,18 @@ def compile(
         key_to_data = {}
         for pred_ind, pred in enumerate(student.predictors()):
             data_pred_ind = None if self.multitask else pred_ind
-            training_key = (pred.lm, data_pred_ind)
+            lm = pred.lm or settings.lm
+            training_key = (lm, data_pred_ind)
             if training_key not in key_to_data:
                 train_data, data_format = self._prepare_finetune_data(
-                    trace_data=trace_data, lm=pred.lm, pred_ind=data_pred_ind
+                    trace_data=trace_data, lm=lm, pred_ind=data_pred_ind
                 )
-                logger.info(f"Using {len(train_data)} data points for fine-tuning the model: {pred.lm.model}")
+                logger.info(f"Using {len(train_data)} data points for fine-tuning the model: {lm.model}")
                 finetune_kwargs = {
-                    "lm": pred.lm,
+                    "lm": lm,
                     "train_data": train_data,
                     "train_data_format": data_format,
-                    "train_kwargs": self.train_kwargs[pred.lm],
+                    "train_kwargs": self.train_kwargs[lm],
                 }
                 key_to_data[training_key] = finetune_kwargs
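
The substantive change is the fallback `lm = pred.lm or settings.lm`: previously, a predictor that never had an LM attached left `pred.lm` as None, so `pred.lm.model` would presumably raise AttributeError and all such predictors grouped under a None training key. A minimal sketch of the setup this enables follows; the model name and toy trainset are illustrative assumptions, not from the commit, and real finetuning needs far more data.

import dspy
from dspy.teleprompt import BootstrapFinetune

# Only a global LM is configured; the program's predictors carry no LM of
# their own, so pred.lm is None inside compile(). With this commit, the
# optimizer falls back to settings.lm instead of failing on None.
dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini"))

program = dspy.ChainOfThought("question -> answer")

# Toy, illustrative trainset for the sketch.
trainset = [dspy.Example(question="What is 2 + 2?", answer="4").with_inputs("question")]

optimizer = BootstrapFinetune()
finetuned_program = optimizer.compile(program, trainset=trainset)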
