Skip to content

Commit aa497d2

Browse files
aelisseekibergus
authored andcommitted
adds a terminal_output function to pretty-print model output in CLI.
PiperOrigin-RevId: 820159448
1 parent 87db38a commit aa497d2

File tree

2 files changed

+56
-19
lines changed

2 files changed

+56
-19
lines changed

examples/realtime_simple_cli.py

Lines changed: 4 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@
5858
from genai_processors.core import text_to_speech
5959
from genai_processors.examples import models
6060
import pyaudio
61-
import termcolor
6261

6362
# You need to define the project id in the environment variables.
6463
GOOGLE_PROJECT_ID = os.environ['GOOGLE_PROJECT_ID']
@@ -144,23 +143,10 @@ async def run_conversation() -> None:
144143
+ realtime.LiveProcessor(turn_processor=genai_processor + tts)
145144
+ play_output
146145
)
147-
148-
async for part in conversation_agent(text.terminal_input()):
149-
# Print the transcription and the output of the model (should include status
150-
# parts and other metadata parts)
151-
match part.role:
152-
case 'user':
153-
color = 'green'
154-
case 'model':
155-
color = 'red'
156-
case _:
157-
color = 'yellow'
158-
part_role = part.role or 'default'
159-
print(
160-
termcolor.colored(
161-
f'{part_role}: {part.text}', color, 'on_grey', attrs=['bold']
162-
)
163-
)
146+
prompt = 'USER (ctrl+D to end)> '
147+
await text.terminal_output(
148+
conversation_agent(text.terminal_input(prompt=prompt)), prompt=prompt
149+
)
164150

165151

166152
def main(argv: Sequence[str]):

genai_processors/core/text.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import dataclasses_json
2424
from genai_processors import content_api
2525
from genai_processors import processor
26+
import termcolor
2627

2728

2829
_MAX_LOOP_COUNT = 1000
@@ -229,7 +230,7 @@ def _extract_part(
229230
]
230231
break
231232
offset += len(c.text)
232-
to_process = left_over + part_buffer[part_idx + 1:]
233+
to_process = left_over + part_buffer[part_idx + 1 :]
233234
return to_yield, to_process
234235

235236
async def call(
@@ -414,3 +415,53 @@ async def terminal_input(
414415
except EOFError:
415416
# Exit on ctrl+D.
416417
return
418+
419+
420+
async def terminal_output(
421+
content: AsyncIterable[content_api.ProcessorPartTypes],
422+
prompt: str = '',
423+
) -> None:
424+
"""Prints the part to the terminal.
425+
426+
Consumes all the content and prints it to the terminal. Prints the prompt
427+
when an `end_of_turn` part is encountered.
428+
429+
The parts are printed with their role in green, red or yellow. The text parts
430+
are printed in bold.
431+
432+
Args:
433+
content: The content to print.
434+
prompt: The prompt to print when the model is done.
435+
"""
436+
old_part_role = None
437+
async for part in content:
438+
match part.role:
439+
case 'user':
440+
color = 'green'
441+
case 'model':
442+
color = 'red'
443+
case _:
444+
color = 'yellow'
445+
part_role = part.role or 'default'
446+
if content_api.is_text(part.mimetype):
447+
content_text = part.text
448+
else:
449+
content_text = f'<{part.mimetype}>'
450+
if part_role != old_part_role:
451+
old_part_role = part_role
452+
print(
453+
termcolor.colored(
454+
f'\n{part_role}: {content_text}', color, attrs=['bold']
455+
),
456+
end='',
457+
flush=True,
458+
)
459+
else:
460+
print(
461+
termcolor.colored(f'{content_text}', color, attrs=['bold']),
462+
end='',
463+
flush=True,
464+
)
465+
# Reprint the prompt when the model is done.
466+
if content_api.is_end_of_turn(part):
467+
print('\n' + prompt, end='', flush=True)

0 commit comments

Comments
 (0)