Commit 64881c7

Allow reusing the stream listener (#8461)
* Allow reusing the stream listener
* split out
1 parent 097f857 commit 64881c7

File tree

dspy/streaming/streaming_listener.py
tests/streaming/test_streaming.py

2 files changed: +76 −3 lines changed


dspy/streaming/streaming_listener.py

Lines changed: 20 additions & 3 deletions
@@ -17,14 +17,22 @@
 class StreamListener:
     """Class that listens to the stream to capture the streaming of a specific output field of a predictor."""

-    def __init__(self, signature_field_name: str, predict: Any = None, predict_name: str | None = None):
+    def __init__(
+        self,
+        signature_field_name: str,
+        predict: Any = None,
+        predict_name: str | None = None,
+        allow_reuse: bool = False,
+    ):
         """
         Args:
             signature_field_name: The name of the field to listen to.
             predict: The predictor to listen to. If None, when calling `streamify()` it will automatically look for
                 the predictor that has the `signature_field_name` in its signature.
             predict_name: The name of the predictor to listen to. If None, when calling `streamify()` it will
                 automatically look for the predictor that has the `signature_field_name` in its signature.
+            allow_reuse: If True, the stream listener can be reused for multiple streams. Note that this can
+                hurt performance because the same stream chunk is sent to multiple listeners.
         """
         self.signature_field_name = signature_field_name
         self.predict = predict

@@ -35,6 +43,7 @@ def __init__(self, signature_field_name: str, predict: Any = None, predict_name:
         self.stream_start = False
         self.stream_end = False
         self.cache_hit = False
+        self.allow_reuse = allow_reuse

         self.json_adapter_start_identifier = f'"{self.signature_field_name}":'
         self.json_adapter_end_identifier = re.compile(r"\w*\"(,|\s*})")

@@ -53,7 +62,7 @@ def receive(self, chunk: ModelResponseStream):
             start_identifier = self.json_adapter_start_identifier
             end_identifier = self.json_adapter_end_identifier

-            start_indicator = "{"
+            start_indicator = '"'
         elif isinstance(settings.adapter, ChatAdapter) or settings.adapter is None:
             start_identifier = self.chat_adapter_start_identifier
             end_identifier = self.chat_adapter_end_identifier

@@ -66,7 +75,15 @@ def receive(self, chunk: ModelResponseStream):
             )

         if self.stream_end:
-            return
+            if self.allow_reuse:
+                # Clear up the state for the next stream.
+                self.stream_end = False
+                self.cache_hit = False
+                self.field_start_queue = []
+                self.field_end_queue = Queue()
+                self.stream_start = False
+            else:
+                return

         try:
             chunk_message = chunk.choices[0].delta.content
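
In effect, a listener constructed with allow_reuse=True resets itself after each completed stream instead of going silent, so one listener can serve a program that invokes the same predictor more than once. A minimal usage sketch, mirroring the new test below — the module, question text, and model name are illustrative, and a configured OpenAI API key is assumed:

import asyncio

import dspy

class MyProgram(dspy.Module):
    # Illustrative module: the same predictor runs twice, producing two streams.
    def __init__(self):
        super().__init__()
        self.predict = dspy.Predict("question->answer")

    def forward(self, question, **kwargs):
        self.predict(question=question, **kwargs)  # first stream
        return self.predict(question=question, **kwargs)  # second stream, same listener

dspy.configure(lm=dspy.LM("openai/gpt-4o-mini"))  # assumes an OpenAI key is set

program = dspy.streamify(
    MyProgram(),
    stream_listeners=[
        # Without allow_reuse=True, the listener would stop after the first stream ends.
        dspy.streaming.StreamListener(signature_field_name="answer", allow_reuse=True),
    ],
)

async def main():
    output = program(question="why did a chicken cross the kitchen?")
    async for value in output:
        if isinstance(value, dspy.streaming.StreamResponse):
            print(value.chunk, end="")  # "answer" chunks from both predictor calls

asyncio.run(main())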

tests/streaming/test_streaming.py

Lines changed: 56 additions & 0 deletions
@@ -707,3 +707,59 @@ async def aforward(self, question, **kwargs):
     # There should be ~1 second delay between the tool start and end messages because we explicitly sleep for 1 second
     # in the tool.
     assert timestamps[1] - timestamps[0] >= 1
+
+
+@pytest.mark.anyio
+async def test_stream_listener_allow_reuse():
+    class MyProgram(dspy.Module):
+        def __init__(self):
+            super().__init__()
+            self.predict = dspy.Predict("question->answer")
+
+        def forward(self, question, **kwargs):
+            self.predict(question=question, **kwargs)
+            return self.predict(question=question, **kwargs)
+
+    program = dspy.streamify(
+        MyProgram(),
+        stream_listeners=[
+            dspy.streaming.StreamListener(signature_field_name="answer", allow_reuse=True),
+        ],
+    )
+
+    async def gpt_4o_mini_stream(*args, **kwargs):
+        # Recorded streaming from openai/gpt-4o-mini
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="[["))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" ##"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" answer"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" ##"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" ]]\n\n"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="To"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" get"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" to"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" the"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" other"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" side"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="!"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="\n\n"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="[[ ##"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" completed"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" ##"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" ]]"))])
+
+    stream_generators = [gpt_4o_mini_stream, gpt_4o_mini_stream]
+
+    async def completion_side_effect(*args, **kwargs):
+        return stream_generators.pop(0)()  # return new async generator instance
+
+    with mock.patch("litellm.acompletion", side_effect=completion_side_effect):
+        with dspy.context(lm=dspy.LM("openai/gpt-4o-mini", cache=False)):
+            output = program(question="why did a chicken cross the kitchen?")
+            all_chunks = []
+            async for value in output:
+                if isinstance(value, dspy.streaming.StreamResponse):
+                    all_chunks.append(value)
+
+    concat_message = "".join([chunk.chunk for chunk in all_chunks])
+    # The listener functions twice.
+    assert concat_message == "To get to the other side!To get to the other side!"
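
A note on the design: state is reset lazily inside receive(), at the moment a finished listener sees the first chunk of the next stream, so callers never have to rebuild listeners between runs. The inline branch added above is equivalent to a hypothetical helper like this — not part of the commit, shown only to name the pieces of state involved:

from queue import Queue  # same Queue used by StreamListener's buffers

def _reset_for_next_stream(self) -> None:
    # Hypothetical refactor of the inline branch in receive(): restore the
    # listener to its pre-stream state so the next chunk starts a fresh capture.
    self.stream_end = False
    self.stream_start = False
    self.cache_hit = False
    self.field_start_queue = []     # token buffer used for field-start detection
    self.field_end_queue = Queue()  # token buffer used for field-end detection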
