Skip to content

Commit ecb9c5e

Browse files
Provide a dspy.syncify so that users can run optimizer on async dspy programs (#8509)
* Add sync wrapper so that async program can be used with optimizer * fix tests
1 parent 538fbb3 commit ecb9c5e

File tree

5 files changed

+121
-1
lines changed

5 files changed

+121
-1
lines changed

dspy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from dspy.adapters import Adapter, ChatAdapter, JSONAdapter, XMLAdapter, TwoStepAdapter, Image, Audio, History, Type, Tool, ToolCalls # isort: skip
1010
from dspy.utils.logging_utils import configure_dspy_loggers, disable_logging, enable_logging
1111
from dspy.utils.asyncify import asyncify
12+
from dspy.utils.syncify import syncify
1213
from dspy.utils.saving import load
1314
from dspy.streaming.streamify import streamify
1415
from dspy.utils.usage_tracker import track_usage

dspy/teleprompt/bootstrap.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,11 +245,12 @@ def _bootstrap_one_example(self, example, round_idx=0):
245245

246246
# Update the traces
247247
for name, demos in name2traces.items():
248-
from datasets.fingerprint import Hasher
249248

250249
# If there are multiple traces for the same predictor in the sample example,
251250
# sample 50/50 from the first N-1 traces or the last trace.
252251
if len(demos) > 1:
252+
from datasets.fingerprint import Hasher
253+
253254
rng = random.Random(Hasher.hash(tuple(demos)))
254255
demos = [rng.choice(demos[:-1]) if rng.random() < 0.5 else demos[-1]]
255256
self.name2traces[name].extend(demos)

dspy/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from dspy.utils.callback import BaseCallback, with_callbacks
88
from dspy.utils.dummies import DummyLM, DummyVectorizer, dummy_rm
99
from dspy.utils.inspect_history import pretty_print_history
10+
from dspy.utils.syncify import syncify
1011

1112

1213
def download(url):

dspy/utils/syncify.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import asyncio
2+
from types import MethodType
3+
from typing import TYPE_CHECKING
4+
5+
if TYPE_CHECKING:
6+
from dspy.primitives.module import Module
7+
8+
9+
def run_async(coro):
10+
"""Run an async coroutine from a synchronous context."""
11+
try:
12+
loop = asyncio.get_running_loop()
13+
except RuntimeError:
14+
loop = None
15+
16+
if loop and loop.is_running():
17+
# If we're in a running event loop (e.g., Jupyter), use asyncio.create_task and run until done
18+
import nest_asyncio
19+
20+
nest_asyncio.apply()
21+
return asyncio.get_event_loop().run_until_complete(coro)
22+
else:
23+
return asyncio.run(coro)
24+
25+
26+
def syncify(program: "Module", in_place: bool = True) -> "Module":
27+
"""Convert an async DSPy module to a sync program.
28+
29+
There are two modes of this function:
30+
31+
- `in_place=True` (recommended): Modify the module in place. But this may not work if you already have a `forward`
32+
method which does different things from `aforward`.
33+
- `in_place=False`: Return a wrapper module. This changes the module's architecture, but it's more robust.
34+
35+
Args:
36+
program: The async program to convert, must have an `aforward` method implemented.
37+
in_place: If True, modify the module in place. Otherwise, return a wrapper module.
38+
39+
Returns:
40+
The sync program, which has a `forward` method that can be called from a synchronous context.
41+
"""
42+
if in_place:
43+
44+
def forward(self, *args, **kwargs):
45+
return run_async(self.aforward(*args, **kwargs))
46+
47+
# Create the `forward` method in place.
48+
program.forward = MethodType(forward, program)
49+
return program
50+
else:
51+
from dspy.primitives.module import Module
52+
53+
class SyncWrapper(Module):
54+
def __init__(self, program: "Module"):
55+
self.program = program
56+
57+
def forward(self, *args, **kwargs):
58+
return run_async(self.program.aforward(*args, **kwargs))
59+
60+
return SyncWrapper(program)

tests/utils/test_syncify.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import asyncio
2+
3+
import dspy
4+
5+
6+
def test_syncify_in_place():
7+
class MyProgram(dspy.Module):
8+
async def aforward(self, x: int) -> int:
9+
await asyncio.sleep(0.01)
10+
return x + 1
11+
12+
sync_program = dspy.syncify(MyProgram())
13+
assert sync_program(1) == 2
14+
assert sync_program(2) == 3
15+
16+
17+
def test_syncify_with_wrapper():
18+
class MyProgram(dspy.Module):
19+
async def aforward(self, x: int) -> int:
20+
await asyncio.sleep(0.01)
21+
return x + 1
22+
23+
sync_program = dspy.syncify(MyProgram(), in_place=False)
24+
assert sync_program(1) == 2
25+
assert sync_program(2) == 3
26+
27+
28+
def test_syncify_works_with_optimizers():
29+
class MyProgram(dspy.Module):
30+
def __init__(self):
31+
self.predict = dspy.Predict("question->answer")
32+
33+
async def aforward(self, question: str):
34+
return await self.predict.acall(question=question)
35+
36+
async_program = MyProgram()
37+
38+
def dummy_metric(gold, pred, traces=None):
39+
return True
40+
41+
# We only test the optimizer completes without errors, so the LM response doesn't matter.
42+
lm = dspy.utils.DummyLM([{"answer": "dummy"} for _ in range(100)])
43+
dspy.configure(lm=lm)
44+
45+
dataset = [dspy.Example(question="question", answer="answer").with_inputs("question") for _ in range(10)]
46+
47+
optimizer = dspy.BootstrapFewShot(metric=dummy_metric, max_bootstrapped_demos=2, max_labeled_demos=0)
48+
49+
# Test syncify in place
50+
sync_program = dspy.syncify(async_program, in_place=True)
51+
optimized_program = optimizer.compile(sync_program, trainset=dataset)
52+
assert len(optimized_program.predictors()[0].demos) == 2
53+
54+
# Test syncify with wrapper
55+
sync_program = dspy.syncify(async_program, in_place=False)
56+
optimized_program = optimizer.compile(sync_program, trainset=dataset)
57+
assert len(optimized_program.predictors()[0].demos) == 2

0 commit comments

Comments
 (0)