11
11
from shiny .session import session_context
12
12
from shiny .types import MISSING
13
13
from shiny .ui import Chat
14
- from shiny .ui ._chat import as_transformed_message
15
14
from shiny .ui ._chat_normalize import normalize_message , normalize_message_chunk
16
- from shiny .ui ._chat_types import ChatMessage
15
+ from shiny .ui ._chat_types import ChatMessage , ChatUIMessage
17
16
18
17
# ----------------------------------------------------------------------
19
18
# Helpers
@@ -52,31 +51,22 @@ def generate_content(token_count: int) -> str:
52
51
return " " .join (["foo" for _ in range (1 , n )])
53
52
54
53
msgs = (
55
- as_transformed_message (
56
- {
57
- "content" : generate_content (102 ),
58
- "role" : "system" ,
59
- }
60
- ),
54
+ ChatUIMessage (
55
+ content = generate_content (102 ), role = "system"
56
+ ).as_transformed_message (),
61
57
)
62
58
63
59
# Throws since system message is too long
64
60
with pytest .raises (ValueError ):
65
61
chat ._trim_messages (msgs , token_limits = (100 , 0 ), format = MISSING )
66
62
67
63
msgs = (
68
- as_transformed_message (
69
- {
70
- "content" : generate_content (100 ),
71
- "role" : "system" ,
72
- }
73
- ),
74
- as_transformed_message (
75
- {
76
- "content" : generate_content (2 ),
77
- "role" : "user" ,
78
- }
79
- ),
64
+ ChatUIMessage (
65
+ content = generate_content (100 ), role = "system"
66
+ ).as_transformed_message (),
67
+ ChatUIMessage (
68
+ content = generate_content (2 ), role = "user"
69
+ ).as_transformed_message (),
80
70
)
81
71
82
72
# Throws since only the system message fits
@@ -92,30 +82,24 @@ def generate_content(token_count: int) -> str:
92
82
content3 = generate_content (2 )
93
83
94
84
msgs = (
95
- as_transformed_message (
96
- {
97
- "content" : content1 ,
98
- "role" : "system" ,
99
- }
100
- ),
101
- as_transformed_message (
102
- {
103
- "content" : content2 ,
104
- "role" : "user" ,
105
- }
106
- ),
107
- as_transformed_message (
108
- {
109
- "content" : content3 ,
110
- "role" : "user" ,
111
- }
112
- ),
85
+ ChatUIMessage (
86
+ content = content1 ,
87
+ role = "system" ,
88
+ ).as_transformed_message (),
89
+ ChatUIMessage (
90
+ content = content2 ,
91
+ role = "user" ,
92
+ ).as_transformed_message (),
93
+ ChatUIMessage (
94
+ content = content3 ,
95
+ role = "user" ,
96
+ ).as_transformed_message (),
113
97
)
114
98
115
99
# Should discard the 1st user message
116
100
trimmed = chat ._trim_messages (msgs , token_limits = (103 , 0 ), format = MISSING )
117
101
assert len (trimmed ) == 2
118
- contents = [msg [ " content_server" ] for msg in trimmed ]
102
+ contents = [msg . content_server for msg in trimmed ]
119
103
assert contents == [content1 , content3 ]
120
104
121
105
content1 = generate_content (50 )
@@ -124,38 +108,48 @@ def generate_content(token_count: int) -> str:
124
108
content4 = generate_content (2 )
125
109
126
110
msgs = (
127
- as_transformed_message (
128
- {"content" : content1 , "role" : "system" },
129
- ),
130
- as_transformed_message (
131
- {"content" : content2 , "role" : "user" },
132
- ),
133
- as_transformed_message (
134
- {"content" : content3 , "role" : "system" },
135
- ),
136
- as_transformed_message (
137
- {"content" : content4 , "role" : "user" },
138
- ),
111
+ ChatUIMessage (
112
+ content = content1 ,
113
+ role = "system" ,
114
+ ).as_transformed_message (),
115
+ ChatUIMessage (
116
+ content = content2 ,
117
+ role = "user" ,
118
+ ).as_transformed_message (),
119
+ ChatUIMessage (
120
+ content = content3 ,
121
+ role = "system" ,
122
+ ).as_transformed_message (),
123
+ ChatUIMessage (
124
+ content = content4 ,
125
+ role = "user" ,
126
+ ).as_transformed_message (),
139
127
)
140
128
141
129
# Should discard the 1st user message
142
130
trimmed = chat ._trim_messages (msgs , token_limits = (103 , 0 ), format = MISSING )
143
131
assert len (trimmed ) == 3
144
- contents = [msg [ " content_server" ] for msg in trimmed ]
132
+ contents = [msg . content_server for msg in trimmed ]
145
133
assert contents == [content1 , content3 , content4 ]
146
134
147
135
content1 = generate_content (50 )
148
136
content2 = generate_content (10 )
149
137
150
138
msgs = (
151
- as_transformed_message ({"content" : content1 , "role" : "assistant" }),
152
- as_transformed_message ({"content" : content2 , "role" : "user" }),
139
+ ChatUIMessage (
140
+ content = content1 ,
141
+ role = "assistant" ,
142
+ ).as_transformed_message (),
143
+ ChatUIMessage (
144
+ content = content2 ,
145
+ role = "user" ,
146
+ ).as_transformed_message (),
153
147
)
154
148
155
149
# Anthropic requires 1st message to be a user message
156
150
trimmed = chat ._trim_messages (msgs , token_limits = (30 , 0 ), format = "anthropic" )
157
151
assert len (trimmed ) == 1
158
- contents = [msg [ " content_server" ] for msg in trimmed ]
152
+ contents = [msg . content_server for msg in trimmed ]
159
153
assert contents == [content2 ]
160
154
161
155
@@ -172,13 +166,15 @@ def generate_content(token_count: int) -> str:
172
166
173
167
174
168
def test_string_normalization():
    """A bare string chunk normalizes to an assistant-role message with that content."""
    msg = normalize_message_chunk("Hello world!")
    assert (msg.content, msg.role) == ("Hello world!", "assistant")
177
172
178
173
179
174
def test_dict_normalization():
    """A plain dict chunk normalizes to a message exposing content/role attributes."""
    payload = {"content": "Hello world!", "role": "assistant"}
    msg = normalize_message_chunk(payload)
    assert (msg.content, msg.role) == ("Hello world!", "assistant")
182
178
183
179
184
180
def test_langchain_normalization ():
@@ -194,11 +190,15 @@ def test_langchain_normalization():
194
190
195
191
# Mock & normalize return value of BaseChatModel.invoke()
196
192
msg = BaseMessage (content = "Hello world!" , role = "assistant" , type = "foo" )
197
- assert normalize_message (msg ) == {"content" : "Hello world!" , "role" : "assistant" }
193
+ m = normalize_message (msg )
194
+ assert m .content == "Hello world!"
195
+ assert m .role == "assistant"
198
196
199
197
# Mock & normalize return value of BaseChatModel.stream()
200
198
chunk = BaseMessageChunk (content = "Hello " , type = "foo" )
201
- assert normalize_message_chunk (chunk ) == {"content" : "Hello " , "role" : "assistant" }
199
+ m = normalize_message_chunk (chunk )
200
+ assert m .content == "Hello "
201
+ assert m .role == "assistant"
202
202
203
203
204
204
def test_google_normalization ():
@@ -255,7 +255,9 @@ def test_anthropic_normalization():
255
255
usage = Usage (input_tokens = 0 , output_tokens = 0 ),
256
256
)
257
257
258
- assert normalize_message (msg ) == {"content" : "Hello world!" , "role" : "assistant" }
258
+ m = normalize_message (msg )
259
+ assert m .content == "Hello world!"
260
+ assert m .role == "assistant"
259
261
260
262
# Mock return object from Anthropic().messages.create(stream=True)
261
263
chunk = RawContentBlockDeltaEvent (
@@ -264,7 +266,9 @@ def test_anthropic_normalization():
264
266
index = 0 ,
265
267
)
266
268
267
- assert normalize_message_chunk (chunk ) == {"content" : "Hello " , "role" : "assistant" }
269
+ m = normalize_message_chunk (chunk )
270
+ assert m .content == "Hello "
271
+ assert m .role == "assistant"
268
272
269
273
270
274
def test_openai_normalization ():
@@ -309,8 +313,9 @@ def test_openai_normalization():
309
313
created = int (datetime .now ().timestamp ()),
310
314
)
311
315
312
- msg = normalize_message (completion )
313
- assert msg == {"content" : "Hello world!" , "role" : "assistant" }
316
+ m = normalize_message (completion )
317
+ assert m .content == "Hello world!"
318
+ assert m .role == "assistant"
314
319
315
320
# Mock return object from OpenAI().chat.completions.create(stream=True)
316
321
chunk = ChatCompletionChunk (
@@ -329,8 +334,9 @@ def test_openai_normalization():
329
334
],
330
335
)
331
336
332
- msg = normalize_message_chunk (chunk )
333
- assert msg == {"content" : "Hello " , "role" : "assistant" }
337
+ m = normalize_message_chunk (chunk )
338
+ assert m .content == "Hello "
339
+ assert m .role == "assistant"
334
340
335
341
336
342
def test_ollama_normalization ():
@@ -343,8 +349,13 @@ def test_ollama_normalization():
343
349
)
344
350
345
351
msg_dict = {"content" : "Hello world!" , "role" : "assistant" }
346
- assert normalize_message (msg ) == msg_dict
347
- assert normalize_message_chunk (msg ) == msg_dict
352
+ m = normalize_message (msg )
353
+ assert m .content == msg_dict ["content" ]
354
+ assert m .role == msg_dict ["role" ]
355
+
356
+ m = normalize_message_chunk (msg )
357
+ assert m .content == msg_dict ["content" ]
358
+ assert m .role == msg_dict ["role" ]
348
359
349
360
350
361
# ------------------------------------------------------------------------------------
@@ -403,9 +414,7 @@ def test_as_google_message():
403
414
404
415
405
416
def test_as_langchain_message ():
406
- from langchain_core .language_models .base import (
407
- LanguageModelInput ,
408
- )
417
+ from langchain_core .language_models .base import LanguageModelInput
409
418
from langchain_core .language_models .base import (
410
419
Sequence as LangchainSequence , # pyright: ignore[reportPrivateImportUsage]
411
420
)
0 commit comments