Skip to content

Commit 982e20f

Browse files
Add token streaming support for XMLAdapter (#8478)
* Add XMLAdapter support to StreamListener * Add XMLAdapter streaming and tests * fix comments
1 parent ecb9c5e commit 982e20f

File tree

2 files changed

+111
-22
lines changed

2 files changed

+111
-22
lines changed

dspy/streaming/streaming_listener.py

Lines changed: 34 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,15 @@
77

88
from dspy.adapters.chat_adapter import ChatAdapter
99
from dspy.adapters.json_adapter import JSONAdapter
10+
from dspy.adapters.xml_adapter import XMLAdapter
1011
from dspy.dsp.utils.settings import settings
1112
from dspy.streaming.messages import StreamResponse
1213

1314
if TYPE_CHECKING:
1415
from dspy.primitives.module import Module
1516

17+
ADAPTER_SUPPORT_STREAMING = [ChatAdapter, XMLAdapter, JSONAdapter]
18+
1619

1720
class StreamListener:
1821
"""Class that listens to the stream to capture the streaming of a specific output field of a predictor."""
@@ -45,11 +48,23 @@ def __init__(
4548
self.cache_hit = False
4649
self.allow_reuse = allow_reuse
4750

48-
self.json_adapter_start_identifier = f'"{self.signature_field_name}":'
49-
self.json_adapter_end_identifier = re.compile(r"\w*\"(,|\s*})")
50-
51-
self.chat_adapter_start_identifier = f"[[ ## {self.signature_field_name} ## ]]"
52-
self.chat_adapter_end_identifier = re.compile(r"\[\[ ## (\w+) ## \]\]")
51+
self.adapter_identifiers = {
52+
"ChatAdapter": {
53+
"start_identifier": f"[[ ## {self.signature_field_name} ## ]]",
54+
"end_identifier": re.compile(r"\[\[ ## (\w+) ## \]\]"),
55+
"start_indicator": "[",
56+
},
57+
"JSONAdapter": {
58+
"start_identifier": f'"{self.signature_field_name}":',
59+
"end_identifier": re.compile(r"\w*\"(,|\s*})"),
60+
"start_indicator": '"',
61+
},
62+
"XMLAdapter": {
63+
"start_identifier": f"<{self.signature_field_name}>",
64+
"end_identifier": re.compile(rf"</{self.signature_field_name}>"),
65+
"start_indicator": "<",
66+
},
67+
}
5368

5469
def _buffered_message_end_with_start_identifier(self, concat_message: str, start_identifier: str) -> str:
5570
for i in range(len(concat_message)):
@@ -58,21 +73,15 @@ def _buffered_message_end_with_start_identifier(self, concat_message: str, start
5873
return False
5974

6075
def receive(self, chunk: ModelResponseStream):
61-
if isinstance(settings.adapter, JSONAdapter):
62-
start_identifier = self.json_adapter_start_identifier
63-
end_identifier = self.json_adapter_end_identifier
64-
65-
start_indicator = '"'
66-
elif isinstance(settings.adapter, ChatAdapter) or settings.adapter is None:
67-
start_identifier = self.chat_adapter_start_identifier
68-
end_identifier = self.chat_adapter_end_identifier
69-
70-
start_indicator = "["
71-
else:
76+
adapter_name = settings.adapter.__class__.__name__ if settings.adapter else "ChatAdapter"
77+
if adapter_name not in self.adapter_identifiers:
7278
raise ValueError(
73-
f"Unsupported adapter for streaming: {settings.adapter}, please use either ChatAdapter or "
74-
"JSONAdapter for streaming purposes."
79+
f"Unsupported adapter for streaming: {adapter_name}, please use one of the following adapters: "
80+
f"{', '.join([a.__name__ for a in ADAPTER_SUPPORT_STREAMING])}"
7581
)
82+
start_identifier = self.adapter_identifiers[adapter_name]["start_identifier"]
83+
end_identifier = self.adapter_identifiers[adapter_name]["end_identifier"]
84+
start_indicator = self.adapter_identifiers[adapter_name]["start_indicator"]
7685

7786
if self.stream_end:
7887
if self.allow_reuse:
@@ -175,13 +184,18 @@ def flush(self) -> str:
175184
else:
176185
boundary_index = len(last_tokens)
177186
return last_tokens[:boundary_index]
187+
elif isinstance(settings.adapter, XMLAdapter):
188+
boundary_index = last_tokens.find(f"</{self.signature_field_name}>")
189+
if boundary_index == -1:
190+
boundary_index = len(last_tokens)
191+
return last_tokens[:boundary_index]
178192
elif isinstance(settings.adapter, ChatAdapter) or settings.adapter is None:
179193
boundary_index = last_tokens.find("[[")
180194
return last_tokens[:boundary_index]
181195
else:
182196
raise ValueError(
183-
f"Unsupported adapter for streaming: {settings.adapter}, please use either ChatAdapter or "
184-
"JSONAdapter for streaming purposes."
197+
f"Unsupported adapter for streaming: {settings.adapter}, please use one of the following adapters: "
198+
f"{', '.join([a.__name__ for a in ADAPTER_SUPPORT_STREAMING])}"
185199
)
186200

187201

tests/streaming/test_streaming.py

Lines changed: 77 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,7 @@ async def gpt_4o_mini_stream_2():
387387
async def completion_side_effect(*args, **kwargs):
388388
return stream_generators.pop(0)() # return new async generator instance
389389

390-
with mock.patch("litellm.acompletion", side_effect=completion_side_effect) as mock_completion:
390+
with mock.patch("litellm.acompletion", side_effect=completion_side_effect):
391391
program = dspy.streamify(
392392
MyProgram(),
393393
stream_listeners=[
@@ -483,7 +483,7 @@ async def gpt_4o_mini_stream_2(*args, **kwargs):
483483

484484
with mock.patch(
485485
"litellm.acompletion", new_callable=AsyncMock, side_effect=[gpt_4o_mini_stream_1(), gpt_4o_mini_stream_2()]
486-
) as mock_completion:
486+
):
487487
program = dspy.streamify(
488488
MyProgram(),
489489
stream_listeners=[
@@ -762,3 +762,78 @@ async def completion_side_effect(*args, **kwargs):
762762
concat_message = "".join([chunk.chunk for chunk in all_chunks])
763763
# The listener functions twice.
764764
assert concat_message == "To get to the other side!To get to the other side!"
765+
766+
@pytest.mark.anyio
767+
async def test_stream_listener_returns_correct_chunk_xml_adapter():
768+
class MyProgram(dspy.Module):
769+
def __init__(self):
770+
super().__init__()
771+
self.predict1 = dspy.Predict("question->answer")
772+
self.predict2 = dspy.Predict("question,answer->judgement")
773+
774+
def forward(self, question, **kwargs):
775+
answer = self.predict1(question=question, **kwargs).answer
776+
judgement = self.predict2(question=question, answer=answer, **kwargs)
777+
return judgement
778+
779+
async def xml_stream_1(*args, **kwargs):
780+
yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="<"))])
781+
yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="answer"))])
782+
yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=">"))])
783+
yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="To"))])
784+
yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" get"))])
785+
yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" to"))])
786+
yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" the"))])
787+
yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" other"))])
788+
yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" side"))])
789+
yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="!"))])
790+
yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="<"))])
791+
yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="/answer"))])
792+
yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=">"))])
793+
yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="<"))])
794+
yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="completed"))])
795+
yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=">"))])
796+
797+
async def xml_stream_2(*args, **kwargs):
798+
yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="<"))])
799+
yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="judgement"))])
800+
yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=">"))])
801+
yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="The"))])
802+
yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" answer"))])
803+
yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" is"))])
804+
yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" humorous"))])
805+
yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="."))])
806+
yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="<"))])
807+
yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="/judgement"))])
808+
yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=">"))])
809+
yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="<"))])
810+
yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="completed"))])
811+
yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=">"))])
812+
813+
stream_generators = [xml_stream_1, xml_stream_2]
814+
815+
async def completion_side_effect(*args, **kwargs):
816+
return stream_generators.pop(0)()
817+
818+
with mock.patch("litellm.acompletion", side_effect=completion_side_effect):
819+
program = dspy.streamify(
820+
MyProgram(),
821+
stream_listeners=[
822+
dspy.streaming.StreamListener(signature_field_name="answer"),
823+
dspy.streaming.StreamListener(signature_field_name="judgement"),
824+
],
825+
)
826+
with dspy.context(lm=dspy.LM("openai/gpt-4o-mini", cache=False), adapter=dspy.XMLAdapter()):
827+
output = program(question="why did a chicken cross the kitchen?")
828+
all_chunks = []
829+
async for value in output:
830+
if isinstance(value, dspy.streaming.StreamResponse):
831+
all_chunks.append(value)
832+
833+
assert all_chunks[0].predict_name == "predict1"
834+
assert all_chunks[0].signature_field_name == "answer"
835+
assert all_chunks[0].chunk == "To get to the other side!"
836+
837+
assert all_chunks[1].predict_name == "predict2"
838+
assert all_chunks[1].signature_field_name == "judgement"
839+
assert all_chunks[1].chunk == "The answer is humorous."

0 commit comments

Comments
 (0)