Skip to content

Commit a97437b

Browse files
Fix an error where status message streaming was blocking (#8432)
* Fix the status message streaming error * Fix tests
1 parent 94299c7 commit a97437b

File tree

2 files changed

+75
-21
lines changed

2 files changed

+75
-21
lines changed

dspy/streaming/messages.py

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
import asyncio
2+
import concurrent.futures
23
from dataclasses import dataclass
34
from typing import Any, Dict, Optional
45

6+
from asyncer import syncify
7+
58
from dspy.dsp.utils.settings import settings
69
from dspy.utils.callback import BaseCallback
710

@@ -21,36 +24,29 @@ class StatusMessage:
2124

2225

2326
def sync_send_to_stream(stream, message):
    """Send `message` to `stream` from a sync context, regardless of event loop state.

    If the calling thread is already running an event loop, the send is
    offloaded to a short-lived worker thread that owns a fresh event loop, so
    the caller's loop is never blocked or re-entered. Otherwise the coroutine
    is executed directly in this thread via `syncify`.

    Args:
        stream: An object exposing an async ``send(message)`` method
            (e.g. an anyio/asyncer memory stream).
        message: The payload to deliver to the stream.

    Returns:
        Whatever ``stream.send(message)`` returns (typically ``None``).
    """

    async def _send():
        await stream.send(message)

    # Keep the try body minimal: only asyncio.get_running_loop() should be able
    # to raise RuntimeError here. Wrapping the executor work in the same try
    # (as before) would misinterpret a RuntimeError raised *inside* the send as
    # "no running loop" and incorrectly retry it on the syncify path.
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # Not inside an event loop: safe to run a new loop in this thread.
        return syncify(_send)()

    def run_in_new_loop():
        # Executed in the worker thread; create and own a fresh event loop there.
        new_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(new_loop)
        try:
            return new_loop.run_until_complete(_send())
        finally:
            new_loop.close()

    # We are inside a running loop: offload to a dedicated thread so the send
    # completes without blocking the caller's loop.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
        return executor.submit(run_in_new_loop).result()
5450

5551

5652
class StatusMessageProvider:

tests/streaming/test_streaming.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import asyncio
12
import os
3+
import time
24
from unittest import mock
35
from unittest.mock import AsyncMock
46

@@ -649,3 +651,59 @@ async def gemini_stream_2(*args, **kwargs):
649651
assert all_chunks[1].predict_name == "predict2"
650652
assert all_chunks[1].signature_field_name == "judgement"
651653
assert all_chunks[1].chunk == "The answer provides a humorous and relevant punchline to the classic joke setup."
654+
655+
656+
@pytest.mark.anyio
async def test_status_message_non_blocking():
    """Verify that status-message streaming does not block a sync tool call.

    The tool sleeps for 1 second; if the stream send blocked the event loop,
    the start/end status messages would not straddle that delay.
    """

    def dummy_tool():
        # Simulate slow, blocking tool work.
        time.sleep(1)
        return "dummy_tool_output"

    class MyProgram(dspy.Module):
        def forward(self, question, **kwargs):
            # Sync tool invocation inside a streamed (async-consumed) program.
            dspy.Tool(dummy_tool)()
            return dspy.Prediction(answer="dummy_tool_output")

    program = dspy.streamify(MyProgram(), status_message_provider=StatusMessageProvider())

    with mock.patch("litellm.acompletion", new_callable=AsyncMock, side_effect=[dummy_tool]):
        with dspy.context(lm=dspy.LM("openai/gpt-4o-mini", cache=False)):
            output = program(question="why did a chicken cross the kitchen?")
            timestamps = []
            async for value in output:
                if isinstance(value, dspy.streaming.StatusMessage):
                    timestamps.append(time.time())

    # timestamps[0]: tool start message
    # timestamps[1]: tool end message
    # There should be ~1 second delay between the tool start and end messages because we explicitly sleep for 1 second
    # in the tool.
    assert timestamps[1] - timestamps[0] >= 1
682+
683+
684+
@pytest.mark.anyio
async def test_status_message_non_blocking_async_program():
    """Verify that status-message streaming does not block an async tool call.

    Async counterpart of the sync test: the tool awaits for 1 second, and the
    tool start/end status messages must straddle that delay.
    """

    async def dummy_tool():
        # Simulate slow, non-blocking (awaitable) tool work.
        await asyncio.sleep(1)
        return "dummy_tool_output"

    class MyProgram(dspy.Module):
        async def aforward(self, question, **kwargs):
            # Async tool invocation inside an async streamed program.
            await dspy.Tool(dummy_tool).acall()
            return dspy.Prediction(answer="dummy_tool_output")

    program = dspy.streamify(MyProgram(), status_message_provider=StatusMessageProvider(), is_async_program=True)

    with mock.patch("litellm.acompletion", new_callable=AsyncMock, side_effect=[dummy_tool]):
        with dspy.context(lm=dspy.LM("openai/gpt-4o-mini", cache=False)):
            output = program(question="why did a chicken cross the kitchen?")
            timestamps = []
            async for value in output:
                if isinstance(value, dspy.streaming.StatusMessage):
                    timestamps.append(time.time())

    # timestamps[0]: tool start message
    # timestamps[1]: tool end message
    # There should be ~1 second delay between the tool start and end messages because we explicitly sleep for 1 second
    # in the tool.
    assert timestamps[1] - timestamps[0] >= 1

0 commit comments

Comments
 (0)