diff --git a/dspy/streaming/streaming_listener.py b/dspy/streaming/streaming_listener.py
index 1cedeeaeda..91489e2073 100644
--- a/dspy/streaming/streaming_listener.py
+++ b/dspy/streaming/streaming_listener.py
@@ -7,12 +7,15 @@
 from dspy.adapters.chat_adapter import ChatAdapter
 from dspy.adapters.json_adapter import JSONAdapter
+from dspy.adapters.xml_adapter import XMLAdapter
 from dspy.dsp.utils.settings import settings
 from dspy.streaming.messages import StreamResponse
 
 if TYPE_CHECKING:
     from dspy.primitives.module import Module
 
+ADAPTER_SUPPORT_STREAMING = [ChatAdapter, XMLAdapter, JSONAdapter]
+
 
 class StreamListener:
     """Class that listens to the stream to capture the streeaming of a specific output field of a predictor."""
@@ -45,11 +48,23 @@ def __init__(
         self.cache_hit = False
         self.allow_reuse = allow_reuse
 
-        self.json_adapter_start_identifier = f'"{self.signature_field_name}":'
-        self.json_adapter_end_identifier = re.compile(r"\w*\"(,|\s*})")
-
-        self.chat_adapter_start_identifier = f"[[ ## {self.signature_field_name} ## ]]"
-        self.chat_adapter_end_identifier = re.compile(r"\[\[ ## (\w+) ## \]\]")
+        self.adapter_identifiers = {
+            "ChatAdapter": {
+                "start_identifier": f"[[ ## {self.signature_field_name} ## ]]",
+                "end_identifier": re.compile(r"\[\[ ## (\w+) ## \]\]"),
+                "start_indicator": "[",
+            },
+            "JSONAdapter": {
+                "start_identifier": f'"{self.signature_field_name}":',
+                "end_identifier": re.compile(r"\w*\"(,|\s*})"),
+                "start_indicator": '"',
+            },
+            "XMLAdapter": {
+                "start_identifier": f"<{self.signature_field_name}>",
+                "end_identifier": re.compile(rf"</{self.signature_field_name}>"),
+                "start_indicator": "<",
+            },
+        }
 
     def _buffered_message_end_with_start_identifier(self, concat_message: str, start_identifier: str) -> str:
         for i in range(len(concat_message)):
@@ -58,21 +73,15 @@
         return False
 
     def receive(self, chunk: ModelResponseStream):
-        if isinstance(settings.adapter, JSONAdapter):
-            start_identifier = self.json_adapter_start_identifier
-            end_identifier = self.json_adapter_end_identifier
-
-            start_indicator = '"'
-        elif isinstance(settings.adapter, ChatAdapter) or settings.adapter is None:
-            start_identifier = self.chat_adapter_start_identifier
-            end_identifier = self.chat_adapter_end_identifier
-
-            start_indicator = "["
-        else:
+        adapter_name = settings.adapter.__class__.__name__ if settings.adapter else "ChatAdapter"
+        if adapter_name not in self.adapter_identifiers:
             raise ValueError(
-                f"Unsupported adapter for streaming: {settings.adapter}, please use either ChatAdapter or "
-                "JSONAdapter for streaming purposes."
+                f"Unsupported adapter for streaming: {adapter_name}, please use one of the following adapters: "
+                f"{', '.join([a.__name__ for a in ADAPTER_SUPPORT_STREAMING])}"
             )
+        start_identifier = self.adapter_identifiers[adapter_name]["start_identifier"]
+        end_identifier = self.adapter_identifiers[adapter_name]["end_identifier"]
+        start_indicator = self.adapter_identifiers[adapter_name]["start_indicator"]
 
         if self.stream_end:
             if self.allow_reuse:
@@ -175,13 +184,18 @@ def flush(self) -> str:
             else:
                 boundary_index = len(last_tokens)
             return last_tokens[:boundary_index]
+        elif isinstance(settings.adapter, XMLAdapter):
+            boundary_index = last_tokens.find(f"</{self.signature_field_name}>")
+            if boundary_index == -1:
+                boundary_index = len(last_tokens)
+            return last_tokens[:boundary_index]
         elif isinstance(settings.adapter, ChatAdapter) or settings.adapter is None:
             boundary_index = last_tokens.find("[[")
             return last_tokens[:boundary_index]
         else:
             raise ValueError(
-                f"Unsupported adapter for streaming: {settings.adapter}, please use either ChatAdapter or "
-                "JSONAdapter for streaming purposes."
+ f"Unsupported adapter for streaming: {settings.adapter}, please use one of the following adapters: " + f"{', '.join([a.__name__ for a in ADAPTER_SUPPORT_STREAMING])}" ) diff --git a/tests/streaming/test_streaming.py b/tests/streaming/test_streaming.py index d79402d2ef..8a1b6941e2 100644 --- a/tests/streaming/test_streaming.py +++ b/tests/streaming/test_streaming.py @@ -388,7 +388,7 @@ async def gpt_4o_mini_stream_2(): async def completion_side_effect(*args, **kwargs): return stream_generators.pop(0)() # return new async generator instance - with mock.patch("litellm.acompletion", side_effect=completion_side_effect) as mock_completion: + with mock.patch("litellm.acompletion", side_effect=completion_side_effect): program = dspy.streamify( MyProgram(), stream_listeners=[ @@ -484,7 +484,7 @@ async def gpt_4o_mini_stream_2(*args, **kwargs): with mock.patch( "litellm.acompletion", new_callable=AsyncMock, side_effect=[gpt_4o_mini_stream_1(), gpt_4o_mini_stream_2()] - ) as mock_completion: + ): program = dspy.streamify( MyProgram(), stream_listeners=[ @@ -763,3 +763,78 @@ async def completion_side_effect(*args, **kwargs): concat_message = "".join([chunk.chunk for chunk in all_chunks]) # The listener functions twice. assert concat_message == "To get to the other side!To get to the other side!" 
+ +@pytest.mark.anyio +async def test_stream_listener_returns_correct_chunk_xml_adapter(): + class MyProgram(dspy.Module): + def __init__(self): + super().__init__() + self.predict1 = dspy.Predict("question->answer") + self.predict2 = dspy.Predict("question,answer->judgement") + + def forward(self, question, **kwargs): + answer = self.predict1(question=question, **kwargs).answer + judgement = self.predict2(question=question, answer=answer, **kwargs) + return judgement + + async def xml_stream_1(*args, **kwargs): + yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="<"))]) + yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="answer"))]) + yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=">"))]) + yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="To"))]) + yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" get"))]) + yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" to"))]) + yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" the"))]) + yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" other"))]) + yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" side"))]) + yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="!"))]) + yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="<"))]) + yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="/answer"))]) + yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=">"))]) + yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="<"))]) + yield 
ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="completed"))]) + yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=">"))]) + + async def xml_stream_2(*args, **kwargs): + yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="<"))]) + yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="judgement"))]) + yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=">"))]) + yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="The"))]) + yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" answer"))]) + yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" is"))]) + yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" humorous"))]) + yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="."))]) + yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="<"))]) + yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="/judgement"))]) + yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=">"))]) + yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="<"))]) + yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="completed"))]) + yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=">"))]) + + stream_generators = [xml_stream_1, xml_stream_2] + + async def completion_side_effect(*args, **kwargs): + return stream_generators.pop(0)() + + with mock.patch("litellm.acompletion", side_effect=completion_side_effect): + program = dspy.streamify( + 
MyProgram(), + stream_listeners=[ + dspy.streaming.StreamListener(signature_field_name="answer"), + dspy.streaming.StreamListener(signature_field_name="judgement"), + ], + ) + with dspy.context(lm=dspy.LM("openai/gpt-4o-mini", cache=False), adapter=dspy.XMLAdapter()): + output = program(question="why did a chicken cross the kitchen?") + all_chunks = [] + async for value in output: + if isinstance(value, dspy.streaming.StreamResponse): + all_chunks.append(value) + + assert all_chunks[0].predict_name == "predict1" + assert all_chunks[0].signature_field_name == "answer" + assert all_chunks[0].chunk == "To get to the other side!" + + assert all_chunks[1].predict_name == "predict2" + assert all_chunks[1].signature_field_name == "judgement" + assert all_chunks[1].chunk == "The answer is humorous."