Skip to content

Commit dd8cf1c

Browse files
Merge async settings change into main (#8361)
* Make dspy.settings and dspy.context safe in async setup (#8203) (#8324) * Make dspy.settings and dspy.context safe in async setup (#8203) * init * increment * test * track usage async test * better-test * revert wrong changes * fix tests * fix tests * fix * allow ipython to run configure * remove Ipython imports * relax the constraint * remove TaskGroup usage, which is only available after 3.10 --------- Co-authored-by: Omar Khattab <okhat@users.noreply.github.com> * more robust usage tracker in async (#8329) * merge from main and fix lint --------- Co-authored-by: Chen Qian <chen.qian@databricks.com> * fix databricks notebook case (#8359) --------- Co-authored-by: Omar Khattab <okhat@users.noreply.github.com>
1 parent 716e82c commit dd8cf1c

File tree

14 files changed

+438
-135
lines changed

14 files changed

+438
-135
lines changed

dspy/dsp/utils/settings.py

Lines changed: 54 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import asyncio
2+
import contextvars
13
import copy
24
import threading
35
from contextlib import contextmanager
@@ -38,13 +40,7 @@
3840
# Global lock for settings configuration
3941
global_lock = threading.Lock()
4042

41-
42-
class ThreadLocalOverrides(threading.local):
43-
def __init__(self):
44-
self.overrides = dotdict()
45-
46-
47-
thread_local_overrides = ThreadLocalOverrides()
43+
thread_local_overrides = contextvars.ContextVar("context_overrides", default=dotdict())
4844

4945

5046
class Settings:
@@ -75,7 +71,7 @@ def lock(self):
7571
return global_lock
7672

7773
def __getattr__(self, name):
78-
overrides = getattr(thread_local_overrides, "overrides", dotdict())
74+
overrides = thread_local_overrides.get()
7975
if name in overrides:
8076
return overrides[name]
8177
elif name in main_thread_config:
@@ -96,7 +92,7 @@ def __setitem__(self, key, value):
9692
self.__setattr__(key, value)
9793

9894
def __contains__(self, key):
99-
overrides = getattr(thread_local_overrides, "overrides", dotdict())
95+
overrides = thread_local_overrides.get()
10096
return key in overrides or key in main_thread_config
10197

10298
def get(self, key, default=None):
@@ -106,23 +102,60 @@ def get(self, key, default=None):
106102
return default
107103

108104
def copy(self):
109-
overrides = getattr(thread_local_overrides, "overrides", dotdict())
105+
overrides = thread_local_overrides.get()
110106
return dotdict({**main_thread_config, **overrides})
111107

112108
@property
113109
def config(self):
114110
return self.copy()
115111

116-
def configure(self, **kwargs):
112+
def _ensure_configure_allowed(self):
117113
global main_thread_config, config_owner_thread_id
118114
current_thread_id = threading.get_ident()
119115

120-
with self.lock:
121-
# First configuration: establish ownership. If ownership established, only that thread can configure.
122-
if config_owner_thread_id in [None, current_thread_id]:
123-
config_owner_thread_id = current_thread_id
124-
else:
125-
raise RuntimeError("dspy.settings can only be changed by the thread that initially configured it.")
116+
if config_owner_thread_id is None:
117+
# First `configure` call is always allowed.
118+
config_owner_thread_id = current_thread_id
119+
return
120+
121+
if config_owner_thread_id != current_thread_id:
122+
# Disallow a second `configure` call from other threads.
123+
raise RuntimeError("dspy.settings can only be changed by the thread that initially configured it.")
124+
125+
# Async task doesn't allow a second `configure` call, must use dspy.context(...) instead.
126+
is_async_task = False
127+
try:
128+
if asyncio.current_task() is not None:
129+
is_async_task = True
130+
except RuntimeError:
131+
# This exception (e.g., "no current task") means we are not in an async loop/task,
132+
# or asyncio module itself is not fully functional in this specific sub-thread context.
133+
is_async_task = False
134+
135+
if not is_async_task:
136+
return
137+
138+
# We are in an async task. Now check for IPython and allow calling `configure` from IPython.
139+
in_ipython = False
140+
try:
141+
from IPython import get_ipython
142+
143+
# get_ipython is a global injected by IPython environments.
144+
# We check its existence and type to be more robust.
145+
in_ipython = get_ipython() is not None
146+
except Exception:
147+
# If `IPython` is not installed or `get_ipython` failed, we are not in an IPython environment.
148+
in_ipython = False
149+
150+
if not in_ipython:
151+
raise RuntimeError(
152+
"dspy.settings.configure(...) cannot be called a second time from an async task. Use "
153+
"`dspy.context(...)` instead."
154+
)
155+
156+
def configure(self, **kwargs):
157+
# If no exception is raised, the `configure` call is allowed.
158+
self._ensure_configure_allowed()
126159

127160
# Update global config
128161
for k, v in kwargs.items():
@@ -136,17 +169,17 @@ def context(self, **kwargs):
136169
If threads are spawned inside this block using ParallelExecutor, they will inherit these overrides.
137170
"""
138171

139-
original_overrides = getattr(thread_local_overrides, "overrides", dotdict()).copy()
172+
original_overrides = thread_local_overrides.get().copy()
140173
new_overrides = dotdict({**main_thread_config, **original_overrides, **kwargs})
141-
thread_local_overrides.overrides = new_overrides
174+
token = thread_local_overrides.set(new_overrides)
142175

143176
try:
144177
yield
145178
finally:
146-
thread_local_overrides.overrides = original_overrides
179+
thread_local_overrides.reset(token)
147180

148181
def __repr__(self):
149-
overrides = getattr(thread_local_overrides, "overrides", dotdict())
182+
overrides = thread_local_overrides.get()
150183
combined_config = {**main_thread_config, **overrides}
151184
return repr(combined_config)
152185

dspy/primitives/program.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import magicattr
44

5-
from dspy.dsp.utils.settings import settings
5+
from dspy.dsp.utils.settings import settings, thread_local_overrides
66
from dspy.predict.parallel import Parallel
77
from dspy.primitives.module import BaseModule
88
from dspy.utils.callback import with_callbacks
@@ -51,7 +51,7 @@ def __call__(self, *args, **kwargs):
5151
caller_modules.append(self)
5252

5353
with settings.context(caller_modules=caller_modules):
54-
if settings.track_usage and settings.usage_tracker is None:
54+
if settings.track_usage and thread_local_overrides.get().get("usage_tracker") is None:
5555
with track_usage() as usage_tracker:
5656
output = self.forward(*args, **kwargs)
5757
output.set_lm_usage(usage_tracker.get_total_tokens())
@@ -66,7 +66,7 @@ async def acall(self, *args, **kwargs):
6666
caller_modules.append(self)
6767

6868
with settings.context(caller_modules=caller_modules):
69-
if settings.track_usage and settings.usage_tracker is None:
69+
if settings.track_usage and thread_local_overrides.get().get("usage_tracker") is None:
7070
with track_usage() as usage_tracker:
7171
output = await self.aforward(*args, **kwargs)
7272
output.set_lm_usage(usage_tracker.get_total_tokens())

dspy/streaming/streamify.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -222,16 +222,10 @@ def apply_sync_streaming(async_generator: AsyncGenerator) -> Generator:
222222

223223
# To propagate prediction request ID context to the child thread
224224
context = contextvars.copy_context()
225-
from dspy.dsp.utils.settings import thread_local_overrides
226-
227-
parent_overrides = thread_local_overrides.overrides.copy()
228225

229226
def producer():
230227
"""Runs in a background thread to fetch items asynchronously."""
231228

232-
original_overrides = thread_local_overrides.overrides
233-
thread_local_overrides.overrides = parent_overrides.copy()
234-
235229
async def runner():
236230
try:
237231
async for item in async_generator:
@@ -241,7 +235,6 @@ async def runner():
241235
queue.put(stop_sentinel)
242236

243237
context.run(asyncio.run, runner())
244-
thread_local_overrides.overrides = original_overrides
245238

246239
# Start the producer in a background thread
247240
thread = threading.Thread(target=producer, daemon=True)

dspy/utils/asyncify.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,17 +46,17 @@ async def async_program(*args, **kwargs) -> Any:
4646
# Capture the current overrides at call-time.
4747
from dspy.dsp.utils.settings import thread_local_overrides
4848

49-
parent_overrides = thread_local_overrides.overrides.copy()
49+
parent_overrides = thread_local_overrides.get().copy()
5050

5151
def wrapped_program(*a, **kw):
5252
from dspy.dsp.utils.settings import thread_local_overrides
5353

54-
original_overrides = thread_local_overrides.overrides
55-
thread_local_overrides.overrides = parent_overrides.copy()
54+
original_overrides = thread_local_overrides.get()
55+
token = thread_local_overrides.set({**original_overrides, **parent_overrides.copy()})
5656
try:
5757
return program(*a, **kw)
5858
finally:
59-
thread_local_overrides.overrides = original_overrides
59+
thread_local_overrides.reset(token)
6060

6161
# Create a fresh asyncified callable each time, ensuring the latest context is used.
6262
call_async = asyncer.asyncify(wrapped_program, abandon_on_cancel=True, limiter=get_limiter())

dspy/utils/parallelizer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,16 +86,16 @@ def worker(parent_overrides, submission_id, index, item):
8686
# Apply parent's thread-local overrides
8787
from dspy.dsp.utils.settings import thread_local_overrides
8888

89-
original = thread_local_overrides.overrides
90-
thread_local_overrides.overrides = parent_overrides.copy()
89+
original = thread_local_overrides.get()
90+
token = thread_local_overrides.set({**original, **parent_overrides.copy()})
9191
if parent_overrides.get("usage_tracker"):
9292
# Usage tracker needs to be deep copied across threads so that each thread tracks its own usage
9393
thread_local_overrides.overrides["usage_tracker"] = copy.deepcopy(parent_overrides["usage_tracker"])
9494

9595
try:
9696
return index, function(item)
9797
finally:
98-
thread_local_overrides.overrides = original
98+
thread_local_overrides.reset(token)
9999

100100
# Handle Ctrl-C in the main thread
101101
@contextlib.contextmanager
@@ -121,7 +121,7 @@ def handler(sig, frame):
121121
with interrupt_manager():
122122
from dspy.dsp.utils.settings import thread_local_overrides
123123

124-
parent_overrides = thread_local_overrides.overrides.copy()
124+
parent_overrides = thread_local_overrides.get().copy()
125125

126126
futures_map = {}
127127
futures_set = set()

tests/adapters/test_two_step_adapter.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,8 @@ class TestSignature(dspy.Signature):
9494
mock_extraction_lm.kwargs = {"temperature": 1.0}
9595
mock_extraction_lm.model = "openai/gpt-4o"
9696

97-
dspy.configure(lm=mock_main_lm, adapter=dspy.TwoStepAdapter(extraction_model=mock_extraction_lm))
98-
99-
result = await program.acall(question="What is 5 + 7?")
97+
with dspy.context(lm=mock_main_lm, adapter=dspy.TwoStepAdapter(extraction_model=mock_extraction_lm)):
98+
result = await program.acall(question="What is 5 + 7?")
10099

101100
assert result.answer == 12
102101

tests/callback/test_callback.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -189,13 +189,12 @@ def test_callback_complex_module():
189189
@pytest.mark.asyncio
190190
async def test_callback_async_module():
191191
callback = MyCallback()
192-
dspy.settings.configure(
192+
with dspy.context(
193193
lm=DummyLM({"How are you?": {"answer": "test output", "reasoning": "No more responses"}}),
194194
callbacks=[callback],
195-
)
196-
197-
cot = dspy.ChainOfThought("question -> answer", n=3)
198-
result = await cot.acall(question="How are you?")
195+
):
196+
cot = dspy.ChainOfThought("question -> answer", n=3)
197+
result = await cot.acall(question="How are you?")
199198
assert result["answer"] == "test output"
200199
assert result["reasoning"] == "No more responses"
201200

tests/predict/test_chain_of_thought.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def test_initialization_with_string_signature():
1919
@pytest.mark.asyncio
2020
async def test_async_chain_of_thought():
2121
lm = DummyLM([{"reasoning": "find the number after 1", "answer": "2"}])
22-
dspy.settings.configure(lm=lm)
23-
program = ChainOfThought("question -> answer")
24-
result = await program.acall(question="What is 1+1?")
25-
assert result.answer == "2"
22+
with dspy.context(lm=lm):
23+
program = ChainOfThought("question -> answer")
24+
result = await program.acall(question="What is 1+1?")
25+
assert result.answer == "2"

tests/predict/test_predict.py

Lines changed: 69 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
import asyncio
12
import copy
23
import enum
4+
import time
5+
import types
36
from datetime import datetime
47
from unittest.mock import patch
58

@@ -506,6 +509,69 @@ def test_lm_usage():
506509
assert result.get_lm_usage()["openai/gpt-4o-mini"]["total_tokens"] == 10
507510

508511

512+
def test_lm_usage_with_parallel():
513+
program = Predict("question -> answer")
514+
515+
def program_wrapper(question):
516+
# Sleep to make it possible to cause a race condition
517+
time.sleep(0.5)
518+
return program(question=question)
519+
520+
dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini", cache=False), track_usage=True)
521+
with patch(
522+
"dspy.clients.lm.litellm_completion",
523+
return_value=ModelResponse(
524+
choices=[{"message": {"content": "[[ ## answer ## ]]\nParis"}}],
525+
usage={"total_tokens": 10},
526+
),
527+
):
528+
parallelizer = dspy.Parallel()
529+
input_pairs = [
530+
(program_wrapper, {"question": "What is the capital of France?"}),
531+
(program_wrapper, {"question": "What is the capital of France?"}),
532+
]
533+
results = parallelizer(input_pairs)
534+
assert results[0].answer == "Paris"
535+
assert results[1].answer == "Paris"
536+
assert results[0].get_lm_usage()["openai/gpt-4o-mini"]["total_tokens"] == 10
537+
assert results[1].get_lm_usage()["openai/gpt-4o-mini"]["total_tokens"] == 10
538+
539+
540+
@pytest.mark.asyncio
541+
async def test_lm_usage_with_async():
542+
program = Predict("question -> answer")
543+
544+
original_aforward = program.aforward
545+
546+
async def patched_aforward(self, **kwargs):
547+
await asyncio.sleep(1)
548+
return await original_aforward(**kwargs)
549+
550+
program.aforward = types.MethodType(patched_aforward, program)
551+
552+
with dspy.context(lm=dspy.LM("openai/gpt-4o-mini", cache=False), track_usage=True):
553+
with patch(
554+
"litellm.acompletion",
555+
return_value=ModelResponse(
556+
choices=[{"message": {"content": "[[ ## answer ## ]]\nParis"}}],
557+
usage={"total_tokens": 10},
558+
),
559+
):
560+
coroutines = [
561+
program.acall(question="What is the capital of France?"),
562+
program.acall(question="What is the capital of France?"),
563+
program.acall(question="What is the capital of France?"),
564+
program.acall(question="What is the capital of France?"),
565+
]
566+
results = await asyncio.gather(*coroutines)
567+
assert results[0].answer == "Paris"
568+
assert results[1].answer == "Paris"
569+
assert results[0].get_lm_usage()["openai/gpt-4o-mini"]["total_tokens"] == 10
570+
assert results[1].get_lm_usage()["openai/gpt-4o-mini"]["total_tokens"] == 10
571+
assert results[2].get_lm_usage()["openai/gpt-4o-mini"]["total_tokens"] == 10
572+
assert results[3].get_lm_usage()["openai/gpt-4o-mini"]["total_tokens"] == 10
573+
574+
509575
def test_positional_arguments():
510576
program = Predict("question -> answer")
511577
with pytest.raises(ValueError) as e:
@@ -569,9 +635,9 @@ class ConstrainedSignature(dspy.Signature):
569635
@pytest.mark.asyncio
570636
async def test_async_predict():
571637
program = Predict("question -> answer")
572-
dspy.settings.configure(lm=DummyLM([{"answer": "Paris"}]))
573-
result = await program.acall(question="What is the capital of France?")
574-
assert result.answer == "Paris"
638+
with dspy.context(lm=DummyLM([{"answer": "Paris"}])):
639+
result = await program.acall(question="What is the capital of France?")
640+
assert result.answer == "Paris"
575641

576642

577643
def test_predicted_outputs_piped_from_predict_to_lm_call():

0 commit comments

Comments
 (0)