|
23 | 23 | import dataclasses_json |
24 | 24 | from genai_processors import content_api |
25 | 25 | from genai_processors import processor |
| 26 | +import termcolor |
26 | 27 |
|
27 | 28 |
|
28 | 29 | _MAX_LOOP_COUNT = 1000 |
@@ -229,7 +230,7 @@ def _extract_part( |
229 | 230 | ] |
230 | 231 | break |
231 | 232 | offset += len(c.text) |
232 | | - to_process = left_over + part_buffer[part_idx + 1:] |
| 233 | + to_process = left_over + part_buffer[part_idx + 1 :] |
233 | 234 | return to_yield, to_process |
234 | 235 |
|
235 | 236 | async def call( |
@@ -414,3 +415,53 @@ async def terminal_input( |
414 | 415 | except EOFError: |
415 | 416 | # Exit on ctrl+D. |
416 | 417 | return |
| 418 | + |
| 419 | + |
| 420 | +async def terminal_output( |
| 421 | + content: AsyncIterable[content_api.ProcessorPartTypes], |
| 422 | + prompt: str = '', |
| 423 | +) -> None: |
| 424 | + """Prints the part to the terminal. |
| 425 | +
|
| 426 | + Consumes all the content and prints it to the terminal. Prints the prompt |
| 427 | + when an `end_of_turn` part is encountered. |
| 428 | +
|
| 429 | + The parts are printed with their role in green, red or yellow. The text parts |
| 430 | + are printed in bold. |
| 431 | +
|
| 432 | + Args: |
| 433 | + content: The content to print. |
| 434 | + prompt: The prompt to print when the model is done. |
| 435 | + """ |
| 436 | + old_part_role = None |
| 437 | + async for part in content: |
| 438 | + match part.role: |
| 439 | + case 'user': |
| 440 | + color = 'green' |
| 441 | + case 'model': |
| 442 | + color = 'red' |
| 443 | + case _: |
| 444 | + color = 'yellow' |
| 445 | + part_role = part.role or 'default' |
| 446 | + if content_api.is_text(part.mimetype): |
| 447 | + content_text = part.text |
| 448 | + else: |
| 449 | + content_text = f'<{part.mimetype}>' |
| 450 | + if part_role != old_part_role: |
| 451 | + old_part_role = part_role |
| 452 | + print( |
| 453 | + termcolor.colored( |
| 454 | + f'\n{part_role}: {content_text}', color, attrs=['bold'] |
| 455 | + ), |
| 456 | + end='', |
| 457 | + flush=True, |
| 458 | + ) |
| 459 | + else: |
| 460 | + print( |
| 461 | + termcolor.colored(f'{content_text}', color, attrs=['bold']), |
| 462 | + end='', |
| 463 | + flush=True, |
| 464 | + ) |
| 465 | + # Reprint the prompt when the model is done. |
| 466 | + if content_api.is_end_of_turn(part): |
| 467 | + print('\n' + prompt, end='', flush=True) |
0 commit comments