Skip to content

Commit 1e8d3e2

Browse files
Per module lm history (#8199)
* per module lm history * better testing * add async testing * fix test * fix tests
1 parent d53daec commit 1e8d3e2

File tree

9 files changed

+230
-79
lines changed

9 files changed

+230
-79
lines changed

dspy/clients/base_lm.py

Lines changed: 6 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from dspy.dsp.utils import settings
55
from dspy.utils.callback import with_callbacks
6+
from dspy.utils.inspect_history import pretty_print_history
67

78
MAX_HISTORY_SIZE = 10_000
89
GLOBAL_HISTORY = []
@@ -81,6 +82,9 @@ def _process_lm_response(self, response, prompt, messages, **kwargs):
8182
}
8283
self.history.append(entry)
8384
self.update_global_history(entry)
85+
caller_modules = settings.caller_modules or []
86+
for module in caller_modules:
87+
module.history.append(entry)
8488
return outputs
8589

8690
@with_callbacks
@@ -129,7 +133,7 @@ def copy(self, **kwargs):
129133
return new_instance
130134

131135
def inspect_history(self, n: int = 1):
132-
_inspect_history(self.history, n)
136+
return inspect_history(self.history, n)
133137

134138
def update_global_history(self, entry):
135139
if settings.disable_history:
@@ -141,66 +145,6 @@ def update_global_history(self, entry):
141145
GLOBAL_HISTORY.append(entry)
142146

143147

144-
def _green(text: str, end: str = "\n"):
145-
return "\x1b[32m" + str(text).lstrip() + "\x1b[0m" + end
146-
147-
148-
def _red(text: str, end: str = "\n"):
149-
return "\x1b[31m" + str(text) + "\x1b[0m" + end
150-
151-
152-
def _blue(text: str, end: str = "\n"):
153-
return "\x1b[34m" + str(text) + "\x1b[0m" + end
154-
155-
156-
def _inspect_history(history, n: int = 1):
157-
"""Prints the last n prompts and their completions."""
158-
159-
for item in history[-n:]:
160-
messages = item["messages"] or [{"role": "user", "content": item["prompt"]}]
161-
outputs = item["outputs"]
162-
timestamp = item.get("timestamp", "Unknown time")
163-
164-
print("\n\n\n")
165-
print("\x1b[34m" + f"[{timestamp}]" + "\x1b[0m" + "\n")
166-
167-
for msg in messages:
168-
print(_red(f"{msg['role'].capitalize()} message:"))
169-
if isinstance(msg["content"], str):
170-
print(msg["content"].strip())
171-
else:
172-
if isinstance(msg["content"], list):
173-
for c in msg["content"]:
174-
if c["type"] == "text":
175-
print(c["text"].strip())
176-
elif c["type"] == "image_url":
177-
image_str = ""
178-
if "base64" in c["image_url"].get("url", ""):
179-
len_base64 = len(c["image_url"]["url"].split("base64,")[1])
180-
image_str = (
181-
f"<{c['image_url']['url'].split('base64,')[0]}base64,"
182-
f"<IMAGE BASE 64 ENCODED({len_base64!s})>"
183-
)
184-
else:
185-
image_str = f"<image_url: {c['image_url']['url']}>"
186-
print(_blue(image_str.strip()))
187-
elif c["type"] == "input_audio":
188-
audio_format = c["input_audio"]["format"]
189-
len_audio = len(c["input_audio"]["data"])
190-
audio_str = f"<audio format='{audio_format}' base64-encoded, length={len_audio}>"
191-
print(_blue(audio_str.strip()))
192-
print("\n")
193-
194-
print(_red("Response:"))
195-
print(_green(outputs[0].strip()))
196-
197-
if len(outputs) > 1:
198-
choices_text = f" \t (and {len(outputs) - 1} other completions)"
199-
print(_red(choices_text, end=""))
200-
201-
print("\n\n\n")
202-
203-
204148
def inspect_history(n: int = 1):
205149
"""The global history shared across all LMs."""
206-
return _inspect_history(GLOBAL_HISTORY, n)
150+
return pretty_print_history(GLOBAL_HISTORY, n)

dspy/dsp/utils/settings.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
track_usage=False,
2323
usage_tracker=None,
2424
caller_predict=None,
25+
caller_modules=None,
2526
stream_listeners=[],
2627
provide_traceback=False, # Whether to include traceback information in error logs.
2728
num_threads=8, # Number of threads to use for parallel processing.

dspy/predict/predict.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717

1818
class Predict(Module, Parameter):
1919
def __init__(self, signature, callbacks=None, **config):
20+
super().__init__(callbacks=callbacks)
2021
self.stage = random.randbytes(8).hex()
2122
self.signature = ensure_signature(signature)
2223
self.config = config
23-
self.callbacks = callbacks or []
2424
self.reset()
2525

2626
def reset(self):

dspy/primitives/program.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from dspy.predict.parallel import Parallel
77
from dspy.primitives.module import BaseModule
88
from dspy.utils.callback import with_callbacks
9+
from dspy.utils.inspect_history import pretty_print_history
910
from dspy.utils.usage_tracker import track_usage
1011

1112

@@ -20,26 +21,38 @@ def _base_init(self):
2021
def __init__(self, callbacks=None):
2122
self.callbacks = callbacks or []
2223
self._compiled = False
24+
# LM calling history of the module.
25+
self.history = []
2326

2427
@with_callbacks
2528
def __call__(self, *args, **kwargs):
26-
if settings.track_usage and settings.usage_tracker is None:
27-
with track_usage() as usage_tracker:
28-
output = self.forward(*args, **kwargs)
29+
caller_modules = settings.caller_modules or []
30+
caller_modules = list(caller_modules)
31+
caller_modules.append(self)
32+
33+
with settings.context(caller_modules=caller_modules):
34+
if settings.track_usage and settings.usage_tracker is None:
35+
with track_usage() as usage_tracker:
36+
output = self.forward(*args, **kwargs)
2937
output.set_lm_usage(usage_tracker.get_total_tokens())
3038
return output
3139

32-
return self.forward(*args, **kwargs)
40+
return self.forward(*args, **kwargs)
3341

3442
@with_callbacks
3543
async def acall(self, *args, **kwargs):
36-
if settings.track_usage and settings.usage_tracker is None:
37-
with track_usage() as usage_tracker:
38-
output = await self.aforward(*args, **kwargs)
39-
output.set_lm_usage(usage_tracker.get_total_tokens())
40-
return output
44+
caller_modules = settings.caller_modules or []
45+
caller_modules = list(caller_modules)
46+
caller_modules.append(self)
47+
48+
with settings.context(caller_modules=caller_modules):
49+
if settings.track_usage and settings.usage_tracker is None:
50+
with track_usage() as usage_tracker:
51+
output = await self.aforward(*args, **kwargs)
52+
output.set_lm_usage(usage_tracker.get_total_tokens())
53+
return output
4154

42-
return await self.aforward(*args, **kwargs)
55+
return await self.aforward(*args, **kwargs)
4356

4457
def named_predictors(self):
4558
from dspy.predict.predict import Predict
@@ -75,6 +88,8 @@ def map_named_predictors(self, func):
7588
set_attribute_by_name(self, name, func(predictor))
7689
return self
7790

91+
def inspect_history(self, n: int = 1):
92+
return pretty_print_history(self.history, n)
7893

7994
def batch(
8095
self,

dspy/utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from dspy.utils import exceptions
77
from dspy.utils.callback import BaseCallback, with_callbacks
88
from dspy.utils.dummies import DummyLM, DummyVectorizer, dummy_rm
9+
from dspy.utils.inspect_history import pretty_print_history
910

1011

1112
def download(url):
@@ -30,4 +31,5 @@ def download(url):
3031
"dummy_rm",
3132
"StatusMessage",
3233
"StatusMessageProvider",
34+
"pretty_print_history",
3335
]

dspy/utils/inspect_history.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
def _green(text: str, end: str = "\n"):
2+
return "\x1b[32m" + str(text).lstrip() + "\x1b[0m" + end
3+
4+
5+
def _red(text: str, end: str = "\n"):
6+
return "\x1b[31m" + str(text) + "\x1b[0m" + end
7+
8+
9+
def _blue(text: str, end: str = "\n"):
10+
return "\x1b[34m" + str(text) + "\x1b[0m" + end
11+
12+
13+
def pretty_print_history(history, n: int = 1):
14+
"""Prints the last n prompts and their completions."""
15+
16+
for item in history[-n:]:
17+
messages = item["messages"] or [{"role": "user", "content": item["prompt"]}]
18+
outputs = item["outputs"]
19+
timestamp = item.get("timestamp", "Unknown time")
20+
21+
print("\n\n\n")
22+
print("\x1b[34m" + f"[{timestamp}]" + "\x1b[0m" + "\n")
23+
24+
for msg in messages:
25+
print(_red(f"{msg['role'].capitalize()} message:"))
26+
if isinstance(msg["content"], str):
27+
print(msg["content"].strip())
28+
else:
29+
if isinstance(msg["content"], list):
30+
for c in msg["content"]:
31+
if c["type"] == "text":
32+
print(c["text"].strip())
33+
elif c["type"] == "image_url":
34+
image_str = ""
35+
if "base64" in c["image_url"].get("url", ""):
36+
len_base64 = len(c["image_url"]["url"].split("base64,")[1])
37+
image_str = (
38+
f"<{c['image_url']['url'].split('base64,')[0]}base64,"
39+
f"<IMAGE BASE 64 ENCODED({len_base64!s})>"
40+
)
41+
else:
42+
image_str = f"<image_url: {c['image_url']['url']}>"
43+
print(_blue(image_str.strip()))
44+
print("\n")
45+
46+
print(_red("Response:"))
47+
print(_green(outputs[0].strip()))
48+
49+
if len(outputs) > 1:
50+
choices_text = f" \t (and {len(outputs) - 1} other completions)"
51+
print(_red(choices_text, end=""))
52+
53+
print("\n\n\n")

tests/predict/test_parallel.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,11 +68,11 @@ def __init__(self):
6868
self.parallel = dspy.Parallel(num_threads=2)
6969

7070
def forward(self, input):
71-
dspy.settings.configure(lm=lm)
72-
res1 = self.predictor.batch([input] * 5)
71+
with dspy.context(lm=lm):
72+
res1 = self.predictor.batch([input] * 5)
7373

74-
dspy.settings.configure(lm=res_lm)
75-
res2 = self.predictor2.batch([input] * 5)
74+
with dspy.context(lm=res_lm):
75+
res2 = self.predictor2.batch([input] * 5)
7676

7777
return (res1, res2)
7878

0 commit comments

Comments
 (0)