1
1
from __future__ import annotations
2
2
3
3
import inspect
4
- import warnings
5
4
from contextlib import asynccontextmanager
6
5
from typing import (
7
6
Any ,
83
82
84
83
ChunkOption = Literal ["start" , "end" , True , False ]
85
84
86
- PendingMessage = Tuple [Any , ChunkOption , Union [str , None ]]
85
+ PendingMessage = Tuple [Any , Literal [ "start" , "end" , True ] , Union [str , None ]]
87
86
88
87
89
88
@add_example (ex_dir = "../templates/chat/starters/hello" )
@@ -199,15 +198,12 @@ def __init__(
199
198
self .on_error = on_error
200
199
201
200
# Chunked messages get accumulated (using this property) before changing state
202
- self ._current_stream_message = ""
201
+ self ._current_stream_message : str = ""
203
202
self ._current_stream_id : str | None = None
204
203
self ._pending_messages : list [PendingMessage ] = []
205
204
206
- # Identifier for a manual stream (i.e., one started with `.start_message_stream()`)
207
- self ._manual_stream_id : str | None = None
208
- # If a manual stream gets nested within another stream, we need to keep track of
209
- # the accumulated message separately
210
- self ._nested_stream_message : str = ""
205
+ # For tracking message stream state when entering/exiting nested streams
206
+ self ._message_stream_checkpoint : str = ""
211
207
212
208
# If a user input message is transformed into a response, we need to cancel
213
209
# the next user input submit handling
@@ -576,7 +572,16 @@ async def append_message(
576
572
similar) is specified in model's completion method.
577
573
:::
578
574
"""
579
- await self ._append_message (message , icon = icon )
575
+ msg = normalize_message (message )
576
+ msg = await self ._transform_message (msg )
577
+ if msg is None :
578
+ return
579
+ self ._store_message (msg )
580
+ await self ._send_append_message (
581
+ message = msg ,
582
+ chunk = False ,
583
+ icon = icon ,
584
+ )
580
585
581
586
async def append_message_chunk (
582
587
self ,
@@ -618,9 +623,8 @@ async def append_message_chunk(
618
623
"Use .message_stream() or .append_message_stream() to start one."
619
624
)
620
625
621
- return await self ._append_message (
626
+ return await self ._append_message_chunk (
622
627
message_chunk ,
623
- chunk = True ,
624
628
stream_id = stream_id ,
625
629
operation = operation ,
626
630
)
@@ -641,75 +645,39 @@ async def message_stream(self):
641
645
to display "ephemeral" content, then eventually show a final state
642
646
with `.append_message_chunk(operation="replace")`.
643
647
"""
644
- await self ._start_stream ()
648
+ # Save the current stream state in a checkpoint (so that we can handle
649
+ # `.append_message_chunk(operation="replace")` correctly)
650
+ old_checkpoint = self ._message_stream_checkpoint
651
+ self ._message_stream_checkpoint = self ._current_stream_message
652
+
653
+ # If no stream currently exists, this is the root stream, so start one
654
+ is_root_stream = not self ._current_stream_id
655
+ if is_root_stream :
656
+ await self ._append_message_chunk (
657
+ "" ,
658
+ chunk = "start" ,
659
+ stream_id = _utils .private_random_id (),
660
+ )
661
+
645
662
try :
646
663
yield
647
664
finally :
648
- await self ._end_stream ()
649
-
650
- async def _start_stream (self ):
651
- if self ._manual_stream_id is not None :
652
- # TODO: support this?
653
- raise ValueError ("Nested .message_stream() isn't currently supported." )
654
- # If we're currently streaming (i.e., through append_message_stream()), then
655
- # end the client message stream (since we start a new one below)
656
- if self ._current_stream_id is not None :
657
- await self ._send_append_message (
658
- message = ChatMessage (content = "" , role = "assistant" ),
659
- chunk = "end" ,
660
- operation = "append" ,
661
- )
662
- # Regardless whether this is an "inner" stream, we start a new message on the
663
- # client so it can handle `operation="replace"` without having to track where
664
- # the inner stream started.
665
- self ._manual_stream_id = _utils .private_random_id ()
666
- stream_id = self ._current_stream_id or self ._manual_stream_id
667
- return await self ._append_message (
668
- "" ,
669
- chunk = "start" ,
670
- stream_id = stream_id ,
671
- # TODO: find a cleaner way to do this, and remove the gap between the messages
672
- icon = (
673
- HTML ("<span class='border-0'><span>" )
674
- if self ._is_nested_stream
675
- else None
676
- ),
677
- )
678
-
679
- async def _end_stream (self ):
680
- if self ._manual_stream_id is None and self ._current_stream_id is None :
681
- warnings .warn (
682
- "Tried to end a message stream, but one isn't currently active." ,
683
- stacklevel = 2 ,
684
- )
685
- return
686
-
687
- if self ._is_nested_stream :
688
- # If inside another stream, just update server-side message state
689
- self ._current_stream_message += self ._nested_stream_message
690
- self ._nested_stream_message = ""
691
- else :
692
- # Otherwise, end this "manual" message stream
693
- await self ._append_message (
694
- "" , chunk = "end" , stream_id = self ._manual_stream_id
695
- )
696
-
697
- self ._manual_stream_id = None
698
- return
699
-
700
- @property
701
- def _is_nested_stream (self ):
702
- return (
703
- self ._current_stream_id is not None
704
- and self ._manual_stream_id is not None
705
- and self ._current_stream_id != self ._manual_stream_id
706
- )
665
+ # Restore the previous stream state
666
+ self ._message_stream_checkpoint = old_checkpoint
667
+
668
+ # If this was the root stream, end it
669
+ if is_root_stream :
670
+ await self ._append_message_chunk (
671
+ "" ,
672
+ chunk = "end" ,
673
+ stream_id = self ._current_stream_id ,
674
+ )
707
675
708
- async def _append_message (
676
+ async def _append_message_chunk (
709
677
self ,
710
678
message : Any ,
711
679
* ,
712
- chunk : ChunkOption = False ,
680
+ chunk : Literal [ True , "start" , "end" ] = True ,
713
681
operation : Literal ["append" , "replace" ] = "append" ,
714
682
stream_id : str | None = None ,
715
683
icon : HTML | Tag | TagList | None = None ,
@@ -724,37 +692,40 @@ async def _append_message(
724
692
if chunk == "end" :
725
693
self ._current_stream_id = None
726
694
727
- if chunk is False :
728
- msg = normalize_message (message )
695
+ # Normalize into a ChatMessage()
696
+ msg = normalize_message_chunk (message )
697
+
698
+ # Remember this content chunk for passing to transformer
699
+ this_chunk = msg .content
700
+
701
+ # Transforming requires replacing the message content (rather than appending to it)
702
+ if self ._needs_transform (msg ):
703
+ operation = "replace"
704
+
705
+ if operation == "replace" :
706
+ # Replace up to the latest checkpoint
707
+ self ._current_stream_message = self ._message_stream_checkpoint + this_chunk
708
+ msg .content = self ._current_stream_message
729
709
else :
730
- msg = normalize_message_chunk (message )
731
- if self ._is_nested_stream :
732
- if operation == "replace" :
733
- self ._nested_stream_message = ""
734
- self ._nested_stream_message += msg .content
735
- else :
736
- if operation == "replace" :
737
- self ._current_stream_message = ""
738
- self ._current_stream_message += msg .content
710
+ self ._current_stream_message += msg .content
739
711
740
712
try :
741
- msg = await self ._transform_message (msg , chunk = chunk )
742
- # Act like nothing happened if transformed to None
743
- if msg is None :
744
- return
745
- msg_store = msg
746
- # Transforming requires *replacing* content
747
- if isinstance (msg , TransformedMessage ):
748
- operation = "replace"
713
+ if self ._needs_transform (msg ):
714
+ msg = await self ._transform_message (
715
+ msg , chunk = chunk , chunk_content = this_chunk
716
+ )
717
+ # Act like nothing happened if transformed to None
718
+ if msg is None :
719
+ return
720
+ if chunk == "end" :
721
+ self ._store_message (msg )
749
722
elif chunk == "end" :
750
- # When not transforming, ensure full message is stored
751
- msg_store = ChatMessage (
752
- content = self ._current_stream_message ,
753
- role = "assistant" ,
723
+ # When `operation="append"`, msg.content is just a chunk, but we must
724
+ # store the full message
725
+ self ._store_message (
726
+ ChatMessage ( content = self . _current_stream_message , role = msg . role )
754
727
)
755
- # Only store full messages
756
- if chunk is False or chunk == "end" :
757
- self ._store_message (msg_store )
728
+
758
729
# Send the message to the client
759
730
await self ._send_append_message (
760
731
message = msg ,
@@ -764,10 +735,8 @@ async def _append_message(
764
735
)
765
736
finally :
766
737
if chunk == "end" :
767
- if self ._is_nested_stream :
768
- self ._nested_stream_message = ""
769
- else :
770
- self ._current_stream_message = ""
738
+ self ._current_stream_message = ""
739
+ self ._message_stream_checkpoint = ""
771
740
772
741
async def append_message_stream (
773
742
self ,
@@ -898,21 +867,21 @@ async def _append_message_stream(
898
867
id = _utils .private_random_id ()
899
868
900
869
empty = ChatMessageDict (content = "" , role = "assistant" )
901
- await self ._append_message (empty , chunk = "start" , stream_id = id , icon = icon )
870
+ await self ._append_message_chunk (empty , chunk = "start" , stream_id = id , icon = icon )
902
871
903
872
try :
904
873
async for msg in message :
905
- await self ._append_message (msg , chunk = True , stream_id = id )
874
+ await self ._append_message_chunk (msg , chunk = True , stream_id = id )
906
875
return self ._current_stream_message
907
876
finally :
908
- await self ._append_message (empty , chunk = "end" , stream_id = id )
877
+ await self ._append_message_chunk (empty , chunk = "end" , stream_id = id )
909
878
await self ._flush_pending_messages ()
910
879
911
880
async def _flush_pending_messages (self ):
912
881
still_pending : list [PendingMessage ] = []
913
882
for msg , chunk , stream_id in self ._pending_messages :
914
883
if self ._can_append_message (stream_id ):
915
- await self ._append_message (msg , chunk = chunk , stream_id = stream_id )
884
+ await self ._append_message_chunk (msg , chunk = chunk , stream_id = stream_id )
916
885
else :
917
886
still_pending .append ((msg , chunk , stream_id ))
918
887
self ._pending_messages = still_pending
@@ -1093,23 +1062,20 @@ async def _transform_message(
1093
1062
self ,
1094
1063
message : ChatMessage ,
1095
1064
chunk : ChunkOption = False ,
1096
- ) -> ChatMessage | TransformedMessage | None :
1065
+ chunk_content : str = "" ,
1066
+ ) -> TransformedMessage | None :
1097
1067
res = TransformedMessage .from_chat_message (message )
1098
1068
1099
1069
if message .role == "user" and self ._transform_user is not None :
1100
1070
content = await self ._transform_user (message .content )
1101
1071
elif message .role == "assistant" and self ._transform_assistant is not None :
1102
- all_content = (
1103
- message .content if chunk is False else self ._current_stream_message
1104
- )
1105
- setattr (res , res .pre_transform_key , all_content )
1106
1072
content = await self ._transform_assistant (
1107
- all_content ,
1108
1073
message .content ,
1074
+ chunk_content ,
1109
1075
chunk == "end" or chunk is False ,
1110
1076
)
1111
1077
else :
1112
- return message
1078
+ return res
1113
1079
1114
1080
if content is None :
1115
1081
return None
@@ -1118,6 +1084,13 @@ async def _transform_message(
1118
1084
1119
1085
return res
1120
1086
1087
+ def _needs_transform (self , message : ChatMessage ) -> bool :
1088
+ if message .role == "user" and self ._transform_user is not None :
1089
+ return True
1090
+ elif message .role == "assistant" and self ._transform_assistant is not None :
1091
+ return True
1092
+ return False
1093
+
1121
1094
# Just before storing, handle chunk msg type and calculate tokens
1122
1095
def _store_message (
1123
1096
self ,
0 commit comments