Skip to content

Commit 4b1e04e

Browse files
committed
Update tests
1 parent dadd0c4 commit 4b1e04e

File tree

1 file changed

+81
-72
lines changed

1 file changed

+81
-72
lines changed

tests/pytest/test_chat.py

Lines changed: 81 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,8 @@
1111
from shiny.session import session_context
1212
from shiny.types import MISSING
1313
from shiny.ui import Chat
14-
from shiny.ui._chat import as_transformed_message
1514
from shiny.ui._chat_normalize import normalize_message, normalize_message_chunk
16-
from shiny.ui._chat_types import ChatMessage
15+
from shiny.ui._chat_types import ChatMessage, ChatUIMessage
1716

1817
# ----------------------------------------------------------------------
1918
# Helpers
@@ -52,31 +51,22 @@ def generate_content(token_count: int) -> str:
5251
return " ".join(["foo" for _ in range(1, n)])
5352

5453
msgs = (
55-
as_transformed_message(
56-
{
57-
"content": generate_content(102),
58-
"role": "system",
59-
}
60-
),
54+
ChatUIMessage(
55+
content=generate_content(102), role="system"
56+
).as_transformed_message(),
6157
)
6258

6359
# Throws since system message is too long
6460
with pytest.raises(ValueError):
6561
chat._trim_messages(msgs, token_limits=(100, 0), format=MISSING)
6662

6763
msgs = (
68-
as_transformed_message(
69-
{
70-
"content": generate_content(100),
71-
"role": "system",
72-
}
73-
),
74-
as_transformed_message(
75-
{
76-
"content": generate_content(2),
77-
"role": "user",
78-
}
79-
),
64+
ChatUIMessage(
65+
content=generate_content(100), role="system"
66+
).as_transformed_message(),
67+
ChatUIMessage(
68+
content=generate_content(2), role="user"
69+
).as_transformed_message(),
8070
)
8171

8272
# Throws since only the system message fits
@@ -92,30 +82,24 @@ def generate_content(token_count: int) -> str:
9282
content3 = generate_content(2)
9383

9484
msgs = (
95-
as_transformed_message(
96-
{
97-
"content": content1,
98-
"role": "system",
99-
}
100-
),
101-
as_transformed_message(
102-
{
103-
"content": content2,
104-
"role": "user",
105-
}
106-
),
107-
as_transformed_message(
108-
{
109-
"content": content3,
110-
"role": "user",
111-
}
112-
),
85+
ChatUIMessage(
86+
content=content1,
87+
role="system",
88+
).as_transformed_message(),
89+
ChatUIMessage(
90+
content=content2,
91+
role="user",
92+
).as_transformed_message(),
93+
ChatUIMessage(
94+
content=content3,
95+
role="user",
96+
).as_transformed_message(),
11397
)
11498

11599
# Should discard the 1st user message
116100
trimmed = chat._trim_messages(msgs, token_limits=(103, 0), format=MISSING)
117101
assert len(trimmed) == 2
118-
contents = [msg["content_server"] for msg in trimmed]
102+
contents = [msg.content_server for msg in trimmed]
119103
assert contents == [content1, content3]
120104

121105
content1 = generate_content(50)
@@ -124,38 +108,48 @@ def generate_content(token_count: int) -> str:
124108
content4 = generate_content(2)
125109

126110
msgs = (
127-
as_transformed_message(
128-
{"content": content1, "role": "system"},
129-
),
130-
as_transformed_message(
131-
{"content": content2, "role": "user"},
132-
),
133-
as_transformed_message(
134-
{"content": content3, "role": "system"},
135-
),
136-
as_transformed_message(
137-
{"content": content4, "role": "user"},
138-
),
111+
ChatUIMessage(
112+
content=content1,
113+
role="system",
114+
).as_transformed_message(),
115+
ChatUIMessage(
116+
content=content2,
117+
role="user",
118+
).as_transformed_message(),
119+
ChatUIMessage(
120+
content=content3,
121+
role="system",
122+
).as_transformed_message(),
123+
ChatUIMessage(
124+
content=content4,
125+
role="user",
126+
).as_transformed_message(),
139127
)
140128

141129
# Should discard the 1st user message
142130
trimmed = chat._trim_messages(msgs, token_limits=(103, 0), format=MISSING)
143131
assert len(trimmed) == 3
144-
contents = [msg["content_server"] for msg in trimmed]
132+
contents = [msg.content_server for msg in trimmed]
145133
assert contents == [content1, content3, content4]
146134

147135
content1 = generate_content(50)
148136
content2 = generate_content(10)
149137

150138
msgs = (
151-
as_transformed_message({"content": content1, "role": "assistant"}),
152-
as_transformed_message({"content": content2, "role": "user"}),
139+
ChatUIMessage(
140+
content=content1,
141+
role="assistant",
142+
).as_transformed_message(),
143+
ChatUIMessage(
144+
content=content2,
145+
role="user",
146+
).as_transformed_message(),
153147
)
154148

155149
# Anthropic requires 1st message to be a user message
156150
trimmed = chat._trim_messages(msgs, token_limits=(30, 0), format="anthropic")
157151
assert len(trimmed) == 1
158-
contents = [msg["content_server"] for msg in trimmed]
152+
contents = [msg.content_server for msg in trimmed]
159153
assert contents == [content2]
160154

161155

@@ -172,13 +166,15 @@ def generate_content(token_count: int) -> str:
172166

173167

174168
def test_string_normalization():
175-
msg = normalize_message_chunk("Hello world!")
176-
assert msg == {"content": "Hello world!", "role": "assistant"}
169+
m = normalize_message_chunk("Hello world!")
170+
assert m.content == "Hello world!"
171+
assert m.role == "assistant"
177172

178173

179174
def test_dict_normalization():
180-
msg = normalize_message_chunk({"content": "Hello world!", "role": "assistant"})
181-
assert msg == {"content": "Hello world!", "role": "assistant"}
175+
m = normalize_message_chunk({"content": "Hello world!", "role": "assistant"})
176+
assert m.content == "Hello world!"
177+
assert m.role == "assistant"
182178

183179

184180
def test_langchain_normalization():
@@ -194,11 +190,15 @@ def test_langchain_normalization():
194190

195191
# Mock & normalize return value of BaseChatModel.invoke()
196192
msg = BaseMessage(content="Hello world!", role="assistant", type="foo")
197-
assert normalize_message(msg) == {"content": "Hello world!", "role": "assistant"}
193+
m = normalize_message(msg)
194+
assert m.content == "Hello world!"
195+
assert m.role == "assistant"
198196

199197
# Mock & normalize return value of BaseChatModel.stream()
200198
chunk = BaseMessageChunk(content="Hello ", type="foo")
201-
assert normalize_message_chunk(chunk) == {"content": "Hello ", "role": "assistant"}
199+
m = normalize_message_chunk(chunk)
200+
assert m.content == "Hello "
201+
assert m.role == "assistant"
202202

203203

204204
def test_google_normalization():
@@ -255,7 +255,9 @@ def test_anthropic_normalization():
255255
usage=Usage(input_tokens=0, output_tokens=0),
256256
)
257257

258-
assert normalize_message(msg) == {"content": "Hello world!", "role": "assistant"}
258+
m = normalize_message(msg)
259+
assert m.content == "Hello world!"
260+
assert m.role == "assistant"
259261

260262
# Mock return object from Anthropic().messages.create(stream=True)
261263
chunk = RawContentBlockDeltaEvent(
@@ -264,7 +266,9 @@ def test_anthropic_normalization():
264266
index=0,
265267
)
266268

267-
assert normalize_message_chunk(chunk) == {"content": "Hello ", "role": "assistant"}
269+
m = normalize_message_chunk(chunk)
270+
assert m.content == "Hello "
271+
assert m.role == "assistant"
268272

269273

270274
def test_openai_normalization():
@@ -309,8 +313,9 @@ def test_openai_normalization():
309313
created=int(datetime.now().timestamp()),
310314
)
311315

312-
msg = normalize_message(completion)
313-
assert msg == {"content": "Hello world!", "role": "assistant"}
316+
m = normalize_message(completion)
317+
assert m.content == "Hello world!"
318+
assert m.role == "assistant"
314319

315320
# Mock return object from OpenAI().chat.completions.create(stream=True)
316321
chunk = ChatCompletionChunk(
@@ -329,8 +334,9 @@ def test_openai_normalization():
329334
],
330335
)
331336

332-
msg = normalize_message_chunk(chunk)
333-
assert msg == {"content": "Hello ", "role": "assistant"}
337+
m = normalize_message_chunk(chunk)
338+
assert m.content == "Hello "
339+
assert m.role == "assistant"
334340

335341

336342
def test_ollama_normalization():
@@ -343,8 +349,13 @@ def test_ollama_normalization():
343349
)
344350

345351
msg_dict = {"content": "Hello world!", "role": "assistant"}
346-
assert normalize_message(msg) == msg_dict
347-
assert normalize_message_chunk(msg) == msg_dict
352+
m = normalize_message(msg)
353+
assert m.content == msg_dict["content"]
354+
assert m.role == msg_dict["role"]
355+
356+
m = normalize_message_chunk(msg)
357+
assert m.content == msg_dict["content"]
358+
assert m.role == msg_dict["role"]
348359

349360

350361
# ------------------------------------------------------------------------------------
@@ -403,9 +414,7 @@ def test_as_google_message():
403414

404415

405416
def test_as_langchain_message():
406-
from langchain_core.language_models.base import (
407-
LanguageModelInput,
408-
)
417+
from langchain_core.language_models.base import LanguageModelInput
409418
from langchain_core.language_models.base import (
410419
Sequence as LangchainSequence, # pyright: ignore[reportPrivateImportUsage]
411420
)

0 commit comments

Comments
 (0)