     as_provider_message,
 )
 from ._chat_tokenizer import TokenEncoding, TokenizersEncoding, get_default_tokenizer
-from ._chat_types import ChatMessage, ClientMessage, TransformedMessage
+from ._chat_types import ChatMessage, ChatUIMessage, ClientMessage, TransformedMessage
 from ._html_deps_py_shiny import chat_deps
 from .fill import as_fill_item, as_fillable_container
 
@@ -240,7 +240,7 @@ async def _init_chat():
             @reactive.effect(priority=9999)
             @reactive.event(self._user_input)
             async def _on_user_input():
-                msg = ChatMessage(content=self._user_input(), role="user")
+                msg = ChatUIMessage(content=self._user_input(), role="user")
                 # It's possible that during the transform, a message is appended, so get
                 # the length now, so we can insert the new message at the right index
                 n_pre = len(self._messages())
@@ -251,7 +251,7 @@ async def _on_user_input():
                 else:
                     # A transformed value of None is a special signal to suspend input
                     # handling (i.e., don't generate a response)
-                    self._store_message(as_transformed_message(msg), index=n_pre)
+                    self._store_message(msg.as_transformed_message(), index=n_pre)
                     await self._remove_loading_message()
                     self._suspend_input_handler = True
 
@@ -492,14 +492,17 @@ def messages(
         res: list[ChatMessage | ProviderMessage] = []
         for i, m in enumerate(messages):
             transform = False
-            if m["role"] == "assistant":
+            if m.role == "assistant":
                 transform = transform_assistant
-            elif m["role"] == "user":
+            elif m.role == "user":
                 transform = transform_user == "all" or (
                     transform_user == "last" and i == len(messages) - 1
                 )
-            content_key = m["transform_key" if transform else "pre_transform_key"]
-            chat_msg = ChatMessage(content=str(m[content_key]), role=m["role"])
+            content_key = getattr(
+                m, "transform_key" if transform else "pre_transform_key"
+            )
+            content = getattr(m, content_key)
+            chat_msg = ChatMessage(content=str(content), role=m.role)
             if not isinstance(format, MISSING_TYPE):
                 chat_msg = as_provider_message(chat_msg, format)
             res.append(chat_msg)
@@ -593,9 +596,9 @@ async def _append_message(
         else:
             msg = normalize_message_chunk(message)
             # Update the current stream message
-            chunk_content = msg["content"]
+            chunk_content = msg.content
             self._current_stream_message += chunk_content
-            msg["content"] = self._current_stream_message
+            msg.content = self._current_stream_message
             if chunk == "end":
                 self._current_stream_message = ""
 
@@ -771,7 +774,7 @@ async def _send_append_message(
         chunk: ChunkOption = False,
         icon: HTML | Tag | TagList | None = None,
     ):
-        if message["role"] == "system":
+        if message.role == "system":
             # System messages are not displayed in the UI
             return
 
@@ -786,21 +789,21 @@ async def _send_append_message(
         elif chunk == "end":
             chunk_type = "message_end"
 
-        content = message["content_client"]
+        content = message.content_client
         content_type = "html" if isinstance(content, HTML) else "markdown"
 
         # TODO: pass along dependencies for both content and icon (if any)
         msg = ClientMessage(
             content=str(content),
-            role=message["role"],
+            role=message.role,
             content_type=content_type,
             chunk_type=chunk_type,
         )
 
         if icon is not None:
             msg["icon"] = str(icon)
 
-        deps = message.get("html_deps", [])
+        deps = message.html_deps
         if deps:
             msg["html_deps"] = deps
 
@@ -928,19 +931,19 @@ async def _transform_wrapper(content: str, chunk: str, done: bool):
 
     async def _transform_message(
         self,
-        message: ChatMessage,
+        message: ChatUIMessage,
         chunk: ChunkOption = False,
         chunk_content: str | None = None,
     ) -> TransformedMessage | None:
-        res = as_transformed_message(message)
-        key = res["transform_key"]
+        res = message.as_transformed_message()
+        key = res.transform_key
 
-        if message["role"] == "user" and self._transform_user is not None:
-            content = await self._transform_user(message["content"])
+        if message.role == "user" and self._transform_user is not None:
+            content = await self._transform_user(message.content)
 
-        elif message["role"] == "assistant" and self._transform_assistant is not None:
+        elif message.role == "assistant" and self._transform_assistant is not None:
             content = await self._transform_assistant(
-                message["content"],
+                message.content,
                 chunk_content or "",
                 chunk == "end" or chunk is False,
             )
@@ -975,7 +978,7 @@ def _store_message(
         messages.insert(index, message)
 
         self._messages.set(tuple(messages))
-        if message["role"] == "user":
+        if message.role == "user":
            self._latest_user_input.set(message)
 
        return None
@@ -1000,9 +1003,9 @@ def _trim_messages(
         n_other_messages: int = 0
         token_counts: list[int] = []
         for m in messages:
-            count = self._get_token_count(m["content_server"])
+            count = self._get_token_count(m.content_server)
             token_counts.append(count)
-            if m["role"] == "system":
+            if m.role == "system":
                 n_system_tokens += count
                 n_system_messages += 1
             else:
@@ -1023,7 +1026,7 @@ def _trim_messages(
         n_other_messages2: int = 0
         token_counts.reverse()
         for i, m in enumerate(reversed(messages)):
-            if m["role"] == "system":
+            if m.role == "system":
                 messages2.append(m)
                 continue
             remaining_non_system_tokens -= token_counts[i]
@@ -1046,13 +1049,13 @@ def _trim_anthropic_messages(
         self,
         messages: tuple[TransformedMessage, ...],
     ) -> tuple[TransformedMessage, ...]:
-        if any(m["role"] == "system" for m in messages):
+        if any(m.role == "system" for m in messages):
             raise ValueError(
                 "Anthropic requires a system prompt to be specified in it's `.create()` method "
                 "(not in the chat messages with `role: system`)."
             )
         for i, m in enumerate(messages):
-            if m["role"] == "user":
+            if m.role == "user":
                 return messages[i:]
 
         return ()
@@ -1098,7 +1101,8 @@ def user_input(self, transform: bool = False) -> str | None:
         if msg is None:
             return None
         key = "content_server" if transform else "content_client"
-        return str(msg[key])
+        val = getattr(msg, key)
+        return str(val)
 
     def _user_input(self) -> str:
         id = self.user_input_id
@@ -1361,27 +1365,4 @@ def chat_ui(
     return res
 
 
-def as_transformed_message(message: ChatMessage) -> TransformedMessage:
-    if message["role"] == "user":
-        transform_key = "content_server"
-        pre_transform_key = "content_client"
-    else:
-        transform_key = "content_client"
-        pre_transform_key = "content_server"
-
-    res = TransformedMessage(
-        content_client=message["content"],
-        content_server=message["content"],
-        role=message["role"],
-        transform_key=transform_key,
-        pre_transform_key=pre_transform_key,
-    )
-
-    deps = message.get("html_deps", [])
-    if deps:
-        res["html_deps"] = deps
-
-    return res
-
-
 CHAT_INSTANCES: WeakValueDictionary[str, Chat] = WeakValueDictionary()