Skip to content

Commit dadd0c4

Browse files
committed
Move away from inheriting from TypedDict for internal Chat classes
1 parent c066da0 commit dadd0c4

File tree

4 files changed

+107
-108
lines changed

4 files changed

+107
-108
lines changed

shiny/session/_utils.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@
1010
from contextvars import ContextVar, Token
1111
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar
1212

13-
from htmltools import TagChild
14-
1513
if TYPE_CHECKING:
1614
from ._session import Session
1715

@@ -134,17 +132,6 @@ def require_active_session(session: Optional[Session]) -> Session:
134132
return session
135133

136134

137-
def process_ui(ui: TagChild) -> tuple[str, list[dict[str, str]]]:
138-
"""
139-
Process a UI element with the session, returning the HTML and dependencies.
140-
"""
141-
if isinstance(ui, (str, float, int)):
142-
return str(ui), []
143-
session = require_active_session(None)
144-
res = session._process_ui(ui)
145-
return res["html"], res["deps"]
146-
147-
148135
# Ideally I'd love not to limit the types for T, but if I don't, the type checker has
149136
# trouble figuring out what `T` is supposed to be when run_thunk is actually used. For
150137
# now, just keep expanding the possible types, as needed.

shiny/ui/_chat.py

Lines changed: 31 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
as_provider_message,
3939
)
4040
from ._chat_tokenizer import TokenEncoding, TokenizersEncoding, get_default_tokenizer
41-
from ._chat_types import ChatMessage, ClientMessage, TransformedMessage
41+
from ._chat_types import ChatMessage, ChatUIMessage, ClientMessage, TransformedMessage
4242
from ._html_deps_py_shiny import chat_deps
4343
from .fill import as_fill_item, as_fillable_container
4444

@@ -240,7 +240,7 @@ async def _init_chat():
240240
@reactive.effect(priority=9999)
241241
@reactive.event(self._user_input)
242242
async def _on_user_input():
243-
msg = ChatMessage(content=self._user_input(), role="user")
243+
msg = ChatUIMessage(content=self._user_input(), role="user")
244244
# It's possible that during the transform, a message is appended, so get
245245
# the length now, so we can insert the new message at the right index
246246
n_pre = len(self._messages())
@@ -251,7 +251,7 @@ async def _on_user_input():
251251
else:
252252
# A transformed value of None is a special signal to suspend input
253253
# handling (i.e., don't generate a response)
254-
self._store_message(as_transformed_message(msg), index=n_pre)
254+
self._store_message(msg.as_transformed_message(), index=n_pre)
255255
await self._remove_loading_message()
256256
self._suspend_input_handler = True
257257

@@ -492,14 +492,17 @@ def messages(
492492
res: list[ChatMessage | ProviderMessage] = []
493493
for i, m in enumerate(messages):
494494
transform = False
495-
if m["role"] == "assistant":
495+
if m.role == "assistant":
496496
transform = transform_assistant
497-
elif m["role"] == "user":
497+
elif m.role == "user":
498498
transform = transform_user == "all" or (
499499
transform_user == "last" and i == len(messages) - 1
500500
)
501-
content_key = m["transform_key" if transform else "pre_transform_key"]
502-
chat_msg = ChatMessage(content=str(m[content_key]), role=m["role"])
501+
content_key = getattr(
502+
m, "transform_key" if transform else "pre_transform_key"
503+
)
504+
content = getattr(m, content_key)
505+
chat_msg = ChatMessage(content=str(content), role=m.role)
503506
if not isinstance(format, MISSING_TYPE):
504507
chat_msg = as_provider_message(chat_msg, format)
505508
res.append(chat_msg)
@@ -593,9 +596,9 @@ async def _append_message(
593596
else:
594597
msg = normalize_message_chunk(message)
595598
# Update the current stream message
596-
chunk_content = msg["content"]
599+
chunk_content = msg.content
597600
self._current_stream_message += chunk_content
598-
msg["content"] = self._current_stream_message
601+
msg.content = self._current_stream_message
599602
if chunk == "end":
600603
self._current_stream_message = ""
601604

@@ -771,7 +774,7 @@ async def _send_append_message(
771774
chunk: ChunkOption = False,
772775
icon: HTML | Tag | TagList | None = None,
773776
):
774-
if message["role"] == "system":
777+
if message.role == "system":
775778
# System messages are not displayed in the UI
776779
return
777780

@@ -786,21 +789,21 @@ async def _send_append_message(
786789
elif chunk == "end":
787790
chunk_type = "message_end"
788791

789-
content = message["content_client"]
792+
content = message.content_client
790793
content_type = "html" if isinstance(content, HTML) else "markdown"
791794

792795
# TODO: pass along dependencies for both content and icon (if any)
793796
msg = ClientMessage(
794797
content=str(content),
795-
role=message["role"],
798+
role=message.role,
796799
content_type=content_type,
797800
chunk_type=chunk_type,
798801
)
799802

800803
if icon is not None:
801804
msg["icon"] = str(icon)
802805

803-
deps = message.get("html_deps", [])
806+
deps = message.html_deps
804807
if deps:
805808
msg["html_deps"] = deps
806809

@@ -928,19 +931,19 @@ async def _transform_wrapper(content: str, chunk: str, done: bool):
928931

929932
async def _transform_message(
930933
self,
931-
message: ChatMessage,
934+
message: ChatUIMessage,
932935
chunk: ChunkOption = False,
933936
chunk_content: str | None = None,
934937
) -> TransformedMessage | None:
935-
res = as_transformed_message(message)
936-
key = res["transform_key"]
938+
res = message.as_transformed_message()
939+
key = res.transform_key
937940

938-
if message["role"] == "user" and self._transform_user is not None:
939-
content = await self._transform_user(message["content"])
941+
if message.role == "user" and self._transform_user is not None:
942+
content = await self._transform_user(message.content)
940943

941-
elif message["role"] == "assistant" and self._transform_assistant is not None:
944+
elif message.role == "assistant" and self._transform_assistant is not None:
942945
content = await self._transform_assistant(
943-
message["content"],
946+
message.content,
944947
chunk_content or "",
945948
chunk == "end" or chunk is False,
946949
)
@@ -975,7 +978,7 @@ def _store_message(
975978
messages.insert(index, message)
976979

977980
self._messages.set(tuple(messages))
978-
if message["role"] == "user":
981+
if message.role == "user":
979982
self._latest_user_input.set(message)
980983

981984
return None
@@ -1000,9 +1003,9 @@ def _trim_messages(
10001003
n_other_messages: int = 0
10011004
token_counts: list[int] = []
10021005
for m in messages:
1003-
count = self._get_token_count(m["content_server"])
1006+
count = self._get_token_count(m.content_server)
10041007
token_counts.append(count)
1005-
if m["role"] == "system":
1008+
if m.role == "system":
10061009
n_system_tokens += count
10071010
n_system_messages += 1
10081011
else:
@@ -1023,7 +1026,7 @@ def _trim_messages(
10231026
n_other_messages2: int = 0
10241027
token_counts.reverse()
10251028
for i, m in enumerate(reversed(messages)):
1026-
if m["role"] == "system":
1029+
if m.role == "system":
10271030
messages2.append(m)
10281031
continue
10291032
remaining_non_system_tokens -= token_counts[i]
@@ -1046,13 +1049,13 @@ def _trim_anthropic_messages(
10461049
self,
10471050
messages: tuple[TransformedMessage, ...],
10481051
) -> tuple[TransformedMessage, ...]:
1049-
if any(m["role"] == "system" for m in messages):
1052+
if any(m.role == "system" for m in messages):
10501053
raise ValueError(
10511054
"Anthropic requires a system prompt to be specified in it's `.create()` method "
10521055
"(not in the chat messages with `role: system`)."
10531056
)
10541057
for i, m in enumerate(messages):
1055-
if m["role"] == "user":
1058+
if m.role == "user":
10561059
return messages[i:]
10571060

10581061
return ()
@@ -1098,7 +1101,8 @@ def user_input(self, transform: bool = False) -> str | None:
10981101
if msg is None:
10991102
return None
11001103
key = "content_server" if transform else "content_client"
1101-
return str(msg[key])
1104+
val = getattr(msg, key)
1105+
return str(val)
11021106

11031107
def _user_input(self) -> str:
11041108
id = self.user_input_id
@@ -1361,27 +1365,4 @@ def chat_ui(
13611365
return res
13621366

13631367

1364-
def as_transformed_message(message: ChatMessage) -> TransformedMessage:
1365-
if message["role"] == "user":
1366-
transform_key = "content_server"
1367-
pre_transform_key = "content_client"
1368-
else:
1369-
transform_key = "content_client"
1370-
pre_transform_key = "content_server"
1371-
1372-
res = TransformedMessage(
1373-
content_client=message["content"],
1374-
content_server=message["content"],
1375-
role=message["role"],
1376-
transform_key=transform_key,
1377-
pre_transform_key=pre_transform_key,
1378-
)
1379-
1380-
deps = message.get("html_deps", [])
1381-
if deps:
1382-
res["html_deps"] = deps
1383-
1384-
return res
1385-
1386-
13871368
CHAT_INSTANCES: WeakValueDictionary[str, Chat] = WeakValueDictionary()

0 commit comments

Comments
 (0)