Skip to content

Commit eb8a1ff

Browse files
committed
Close #1621. Add a normalizer function argument to Chat.append_message_stream()
1 parent c97f09b commit eb8a1ff

File tree

1 file changed

+34
-5
lines changed

1 file changed

+34
-5
lines changed

shiny/ui/_chat.py

Lines changed: 34 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -507,7 +507,12 @@ async def append_message(self, message: Any) -> None:
507507
await self._append_message(message)
508508

509509
async def _append_message(
510-
self, message: Any, *, chunk: ChunkOption = False, stream_id: str | None = None
510+
self,
511+
message: Any,
512+
*,
513+
chunk: ChunkOption = False,
514+
stream_id: str | None = None,
515+
normalizer: Callable[[object], str] | None = None,
511516
) -> None:
512517
# If currently we're in a stream, handle other messages (outside the stream) later
513518
if not self._can_append_message(stream_id):
@@ -519,6 +524,15 @@ async def _append_message(
519524
if chunk == "end":
520525
self._current_stream_id = None
521526

527+
# Apply the user provided normalizer, if any
528+
if normalizer is not None:
529+
res = normalizer(message)
530+
if not isinstance(res, str):
531+
raise ValueError(
532+
f"Normalizer function must return a string, got {type(res)}"
533+
)
534+
message = {"content": res, "role": "assistant"}
535+
522536
if chunk is False:
523537
msg = normalize_message(message)
524538
chunk_content = None
@@ -539,7 +553,11 @@ async def _append_message(
539553
msg = self._store_message(msg, chunk=chunk)
540554
await self._send_append_message(msg, chunk=chunk)
541555

542-
async def append_message_stream(self, message: Iterable[Any] | AsyncIterable[Any]):
556+
async def append_message_stream(
557+
self,
558+
message: Iterable[Any] | AsyncIterable[Any],
559+
normalizer: Callable[[object], str] | None = None,
560+
) -> None:
543561
"""
544562
Append a message as a stream of message chunks.
545563
@@ -550,6 +568,11 @@ async def append_message_stream(self, message: Iterable[Any] | AsyncIterable[Any
550568
message chunk formats are supported, including a string, a dictionary with
551569
`content` and `role` keys, or a relevant chat completion object from
552570
platforms like OpenAI, Anthropic, Ollama, and others.
571+
normalizer
572+
A function to apply to each message chunk (i.e., each item of the `message`
573+
iterator) before appending it to the chat. This is useful for handling
574+
response formats that `Chat` may not already natively support. The function
575+
should take a message chunk and return a string.
553576
554577
Note
555578
----
@@ -562,7 +585,7 @@ async def append_message_stream(self, message: Iterable[Any] | AsyncIterable[Any
562585
# Run the stream in the background to get non-blocking behavior
563586
@reactive.extended_task
564587
async def _stream_task():
565-
await self._append_message_stream(message)
588+
await self._append_message_stream(message, normalizer)
566589

567590
_stream_task()
568591

@@ -582,15 +605,21 @@ async def _handle_error():
582605
ctx.on_invalidate(_handle_error.destroy)
583606
self._effects.append(_handle_error)
584607

585-
async def _append_message_stream(self, message: AsyncIterable[Any]):
608+
async def _append_message_stream(
609+
self,
610+
message: AsyncIterable[Any],
611+
normalizer: Callable[[object], str] | None = None,
612+
) -> None:
586613
id = _utils.private_random_id()
587614

588615
empty = ChatMessage(content="", role="assistant")
589616
await self._append_message(empty, chunk="start", stream_id=id)
590617

591618
try:
592619
async for msg in message:
593-
await self._append_message(msg, chunk=True, stream_id=id)
620+
await self._append_message(
621+
msg, chunk=True, stream_id=id, normalizer=normalizer
622+
)
594623
finally:
595624
await self._append_message(empty, chunk="end", stream_id=id)
596625
await self._flush_pending_messages()

0 commit comments

Comments (0)