Skip to content

Commit 1bd16ec

Browse files
committed
Support nested streams and simplify logic
1 parent 648d9cc commit 1bd16ec

File tree

2 files changed

+120
-130
lines changed

2 files changed

+120
-130
lines changed

shiny/ui/_chat.py

Lines changed: 87 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
import inspect
4-
import warnings
54
from contextlib import asynccontextmanager
65
from typing import (
76
Any,
@@ -83,7 +82,7 @@
8382

8483
ChunkOption = Literal["start", "end", True, False]
8584

86-
PendingMessage = Tuple[Any, ChunkOption, Union[str, None]]
85+
PendingMessage = Tuple[Any, Literal["start", "end", True], Union[str, None]]
8786

8887

8988
@add_example(ex_dir="../templates/chat/starters/hello")
@@ -199,15 +198,12 @@ def __init__(
199198
self.on_error = on_error
200199

201200
# Chunked messages get accumulated (using this property) before changing state
202-
self._current_stream_message = ""
201+
self._current_stream_message: str = ""
203202
self._current_stream_id: str | None = None
204203
self._pending_messages: list[PendingMessage] = []
205204

206-
# Identifier for a manual stream (i.e., one started with `.start_message_stream()`)
207-
self._manual_stream_id: str | None = None
208-
# If a manual stream gets nested within another stream, we need to keep track of
209-
# the accumulated message separately
210-
self._nested_stream_message: str = ""
205+
# For tracking message stream state when entering/exiting nested streams
206+
self._message_stream_checkpoint: str = ""
211207

212208
# If a user input message is transformed into a response, we need to cancel
213209
# the next user input submit handling
@@ -576,7 +572,16 @@ async def append_message(
576572
similar) is specified in model's completion method.
577573
:::
578574
"""
579-
await self._append_message(message, icon=icon)
575+
msg = normalize_message(message)
576+
msg = await self._transform_message(msg)
577+
if msg is None:
578+
return
579+
self._store_message(msg)
580+
await self._send_append_message(
581+
message=msg,
582+
chunk=False,
583+
icon=icon,
584+
)
580585

581586
async def append_message_chunk(
582587
self,
@@ -618,9 +623,8 @@ async def append_message_chunk(
618623
"Use .message_stream() or .append_message_stream() to start one."
619624
)
620625

621-
return await self._append_message(
626+
return await self._append_message_chunk(
622627
message_chunk,
623-
chunk=True,
624628
stream_id=stream_id,
625629
operation=operation,
626630
)
@@ -641,75 +645,39 @@ async def message_stream(self):
641645
to display "ephemeral" content, then eventually show a final state
642646
with `.append_message_chunk(operation="replace")`.
643647
"""
644-
await self._start_stream()
648+
# Save the current stream state in a checkpoint (so that we can handle
649+
# `.append_message_chunk(operation="replace")` correctly)
650+
old_checkpoint = self._message_stream_checkpoint
651+
self._message_stream_checkpoint = self._current_stream_message
652+
653+
# If no stream currently exists, start one
654+
is_root_stream = not self._current_stream_id
655+
if is_root_stream:
656+
await self._append_message_chunk(
657+
"",
658+
chunk="start",
659+
stream_id=_utils.private_random_id(),
660+
)
661+
645662
try:
646663
yield
647664
finally:
648-
await self._end_stream()
649-
650-
async def _start_stream(self):
651-
if self._manual_stream_id is not None:
652-
# TODO: support this?
653-
raise ValueError("Nested .message_stream() isn't currently supported.")
654-
# If we're currently streaming (i.e., through append_message_stream()), then
655-
# end the client message stream (since we start a new one below)
656-
if self._current_stream_id is not None:
657-
await self._send_append_message(
658-
message=ChatMessage(content="", role="assistant"),
659-
chunk="end",
660-
operation="append",
661-
)
662-
# Regardless whether this is an "inner" stream, we start a new message on the
663-
# client so it can handle `operation="replace"` without having to track where
664-
# the inner stream started.
665-
self._manual_stream_id = _utils.private_random_id()
666-
stream_id = self._current_stream_id or self._manual_stream_id
667-
return await self._append_message(
668-
"",
669-
chunk="start",
670-
stream_id=stream_id,
671-
# TODO: find a cleaner way to do this, and remove the gap between the messages
672-
icon=(
673-
HTML("<span class='border-0'><span>")
674-
if self._is_nested_stream
675-
else None
676-
),
677-
)
678-
679-
async def _end_stream(self):
680-
if self._manual_stream_id is None and self._current_stream_id is None:
681-
warnings.warn(
682-
"Tried to end a message stream, but one isn't currently active.",
683-
stacklevel=2,
684-
)
685-
return
686-
687-
if self._is_nested_stream:
688-
# If inside another stream, just update server-side message state
689-
self._current_stream_message += self._nested_stream_message
690-
self._nested_stream_message = ""
691-
else:
692-
# Otherwise, end this "manual" message stream
693-
await self._append_message(
694-
"", chunk="end", stream_id=self._manual_stream_id
695-
)
696-
697-
self._manual_stream_id = None
698-
return
699-
700-
@property
701-
def _is_nested_stream(self):
702-
return (
703-
self._current_stream_id is not None
704-
and self._manual_stream_id is not None
705-
and self._current_stream_id != self._manual_stream_id
706-
)
665+
# Restore the previous stream state
666+
self._message_stream_checkpoint = old_checkpoint
667+
668+
# If this was the root stream, end it
669+
if is_root_stream:
670+
await self._append_message_chunk(
671+
"",
672+
chunk="end",
673+
stream_id=self._current_stream_id,
674+
)
707675

708-
async def _append_message(
676+
async def _append_message_chunk(
709677
self,
710678
message: Any,
711679
*,
712-
chunk: ChunkOption = False,
680+
chunk: Literal[True, "start", "end"] = True,
713681
operation: Literal["append", "replace"] = "append",
714682
stream_id: str | None = None,
715683
icon: HTML | Tag | TagList | None = None,
@@ -724,37 +692,40 @@ async def _append_message(
724692
if chunk == "end":
725693
self._current_stream_id = None
726694

727-
if chunk is False:
728-
msg = normalize_message(message)
695+
# Normalize into a ChatMessage()
696+
msg = normalize_message_chunk(message)
697+
698+
# Remember this content chunk for passing to transformer
699+
this_chunk = msg.content
700+
701+
# Transforming requires replacing
702+
if self._needs_transform(msg):
703+
operation = "replace"
704+
705+
if operation == "replace":
706+
# Replace up to the latest checkpoint
707+
self._current_stream_message = self._message_stream_checkpoint + this_chunk
708+
msg.content = self._current_stream_message
729709
else:
730-
msg = normalize_message_chunk(message)
731-
if self._is_nested_stream:
732-
if operation == "replace":
733-
self._nested_stream_message = ""
734-
self._nested_stream_message += msg.content
735-
else:
736-
if operation == "replace":
737-
self._current_stream_message = ""
738-
self._current_stream_message += msg.content
710+
self._current_stream_message += msg.content
739711

740712
try:
741-
msg = await self._transform_message(msg, chunk=chunk)
742-
# Act like nothing happened if transformed to None
743-
if msg is None:
744-
return
745-
msg_store = msg
746-
# Transforming requires *replacing* content
747-
if isinstance(msg, TransformedMessage):
748-
operation = "replace"
713+
if self._needs_transform(msg):
714+
msg = await self._transform_message(
715+
msg, chunk=chunk, chunk_content=this_chunk
716+
)
717+
# Act like nothing happened if transformed to None
718+
if msg is None:
719+
return
720+
if chunk == "end":
721+
self._store_message(msg)
749722
elif chunk == "end":
750-
# When not transforming, ensure full message is stored
751-
msg_store = ChatMessage(
752-
content=self._current_stream_message,
753-
role="assistant",
723+
# When `operation="append"`, msg.content is just a chunk, but we must
724+
# store the full message
725+
self._store_message(
726+
ChatMessage(content=self._current_stream_message, role=msg.role)
754727
)
755-
# Only store full messages
756-
if chunk is False or chunk == "end":
757-
self._store_message(msg_store)
728+
758729
# Send the message to the client
759730
await self._send_append_message(
760731
message=msg,
@@ -764,10 +735,8 @@ async def _append_message(
764735
)
765736
finally:
766737
if chunk == "end":
767-
if self._is_nested_stream:
768-
self._nested_stream_message = ""
769-
else:
770-
self._current_stream_message = ""
738+
self._current_stream_message = ""
739+
self._message_stream_checkpoint = ""
771740

772741
async def append_message_stream(
773742
self,
@@ -898,21 +867,21 @@ async def _append_message_stream(
898867
id = _utils.private_random_id()
899868

900869
empty = ChatMessageDict(content="", role="assistant")
901-
await self._append_message(empty, chunk="start", stream_id=id, icon=icon)
870+
await self._append_message_chunk(empty, chunk="start", stream_id=id, icon=icon)
902871

903872
try:
904873
async for msg in message:
905-
await self._append_message(msg, chunk=True, stream_id=id)
874+
await self._append_message_chunk(msg, chunk=True, stream_id=id)
906875
return self._current_stream_message
907876
finally:
908-
await self._append_message(empty, chunk="end", stream_id=id)
877+
await self._append_message_chunk(empty, chunk="end", stream_id=id)
909878
await self._flush_pending_messages()
910879

911880
async def _flush_pending_messages(self):
912881
still_pending: list[PendingMessage] = []
913882
for msg, chunk, stream_id in self._pending_messages:
914883
if self._can_append_message(stream_id):
915-
await self._append_message(msg, chunk=chunk, stream_id=stream_id)
884+
await self._append_message_chunk(msg, chunk=chunk, stream_id=stream_id)
916885
else:
917886
still_pending.append((msg, chunk, stream_id))
918887
self._pending_messages = still_pending
@@ -1093,23 +1062,20 @@ async def _transform_message(
10931062
self,
10941063
message: ChatMessage,
10951064
chunk: ChunkOption = False,
1096-
) -> ChatMessage | TransformedMessage | None:
1065+
chunk_content: str = "",
1066+
) -> TransformedMessage | None:
10971067
res = TransformedMessage.from_chat_message(message)
10981068

10991069
if message.role == "user" and self._transform_user is not None:
11001070
content = await self._transform_user(message.content)
11011071
elif message.role == "assistant" and self._transform_assistant is not None:
1102-
all_content = (
1103-
message.content if chunk is False else self._current_stream_message
1104-
)
1105-
setattr(res, res.pre_transform_key, all_content)
11061072
content = await self._transform_assistant(
1107-
all_content,
11081073
message.content,
1074+
chunk_content,
11091075
chunk == "end" or chunk is False,
11101076
)
11111077
else:
1112-
return message
1078+
return res
11131079

11141080
if content is None:
11151081
return None
@@ -1118,6 +1084,13 @@ async def _transform_message(
11181084

11191085
return res
11201086

1087+
def _needs_transform(self, message: ChatMessage) -> bool:
1088+
if message.role == "user" and self._transform_user is not None:
1089+
return True
1090+
elif message.role == "assistant" and self._transform_assistant is not None:
1091+
return True
1092+
return False
1093+
11211094
# Just before storing, handle chunk msg type and calculate tokens
11221095
def _store_message(
11231096
self,

tests/playwright/shiny/components/chat/inject/app.py

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from shiny import reactive
44
from shiny.express import input, render, ui
55

6-
ui.page_opts(title="Hello Chat")
6+
ui.page_opts(title="Hello message streams")
77

88
chat = ui.Chat(id="chat")
99
chat.ui()
@@ -15,30 +15,47 @@ async def _():
1515
await chat.append_message_stream(mock_stream())
1616

1717

18+
SLEEP_TIME = 0.25
19+
20+
1821
async def mock_stream():
1922
yield "Starting outer stream...\n\n"
20-
await asyncio.sleep(0.5)
23+
await asyncio.sleep(SLEEP_TIME)
2124
await mock_tool()
22-
await asyncio.sleep(0.5)
25+
await asyncio.sleep(SLEEP_TIME)
2326
yield "\n\n...outer stream complete"
2427

2528

26-
# While the "outer" `.append_message_stream()` is running,
27-
# start an "inner" stream with .message_stream()
2829
async def mock_tool():
29-
steps = [
30-
"Starting inner stream 🔄...\n\n",
31-
"Progress: 0%...",
32-
"Progress: 50%...",
33-
"Progress: 100%...",
34-
]
30+
# While the "outer" `.append_message_stream()` is running,
31+
# start an "inner" stream with .message_stream()
3532
async with chat.message_stream():
36-
for chunk in steps:
37-
await chat.append_message_chunk(chunk)
38-
await asyncio.sleep(0.5)
33+
await chat.append_message_chunk("\n\nStarting inner stream 1 🔄...")
34+
await asyncio.sleep(SLEEP_TIME)
35+
await chat.append_message_chunk("Progress: 0%")
36+
await asyncio.sleep(SLEEP_TIME)
37+
38+
async with chat.message_stream():
39+
await chat.append_message_chunk("\n\nStarting nested stream 2 🔄...")
40+
await asyncio.sleep(SLEEP_TIME)
41+
await chat.append_message_chunk("Progress: 0%")
42+
await asyncio.sleep(SLEEP_TIME)
43+
await chat.append_message_chunk(" Progress: 50%")
44+
await asyncio.sleep(SLEEP_TIME)
45+
await chat.append_message_chunk(" Progress: 100%")
46+
await asyncio.sleep(SLEEP_TIME)
47+
await chat.append_message_chunk(
48+
"\n\nCompleted _another_ inner stream ✅", operation="replace"
49+
)
50+
51+
await chat.append_message_chunk("\n\nBack to stream 1...")
52+
await chat.append_message_chunk(" Progress: 50%")
53+
await asyncio.sleep(SLEEP_TIME)
54+
await chat.append_message_chunk(" Progress: 100%")
55+
await asyncio.sleep(SLEEP_TIME)
56+
3957
await chat.append_message_chunk(
40-
"Completed inner stream ✅",
41-
operation="replace",
58+
"\n\nCompleted inner _and nested_ stream ✅", operation="replace"
4259
)
4360

4461

0 commit comments

Comments
 (0)