@@ -1,4 +1,6 @@
+import asyncio
 import os
+import time
 from unittest import mock
 from unittest.mock import AsyncMock
 
@@ -649,3 +651,59 @@ async def gemini_stream_2(*args, **kwargs):
     assert all_chunks[1].predict_name == "predict2"
     assert all_chunks[1].signature_field_name == "judgement"
     assert all_chunks[1].chunk == "The answer provides a humorous and relevant punchline to the classic joke setup."
+
+
+@pytest.mark.anyio
+async def test_status_message_non_blocking():
+    def dummy_tool():
+        time.sleep(1)
+        return "dummy_tool_output"
+
+    class MyProgram(dspy.Module):
+        def forward(self, question, **kwargs):
+            dspy.Tool(dummy_tool)()
+            return dspy.Prediction(answer="dummy_tool_output")
+
+    program = dspy.streamify(MyProgram(), status_message_provider=StatusMessageProvider())
+
+    with mock.patch("litellm.acompletion", new_callable=AsyncMock, side_effect=[dummy_tool]):
+        with dspy.context(lm=dspy.LM("openai/gpt-4o-mini", cache=False)):
+            output = program(question="why did a chicken cross the kitchen?")
+            timestamps = []
+            async for value in output:
+                if isinstance(value, dspy.streaming.StatusMessage):
+                    timestamps.append(time.time())
+
+    # timestamps[0]: tool start message
+    # timestamps[1]: tool end message
+    # There should be a ~1 second delay between the tool start and end messages because we explicitly
+    # sleep for 1 second in the tool.
+    assert timestamps[1] - timestamps[0] >= 1
+
+
+@pytest.mark.anyio
+async def test_status_message_non_blocking_async_program():
+    async def dummy_tool():
+        await asyncio.sleep(1)
+        return "dummy_tool_output"
+
+    class MyProgram(dspy.Module):
+        async def aforward(self, question, **kwargs):
+            await dspy.Tool(dummy_tool).acall()
+            return dspy.Prediction(answer="dummy_tool_output")
+
+    program = dspy.streamify(MyProgram(), status_message_provider=StatusMessageProvider(), is_async_program=True)
+
+    with mock.patch("litellm.acompletion", new_callable=AsyncMock, side_effect=[dummy_tool]):
+        with dspy.context(lm=dspy.LM("openai/gpt-4o-mini", cache=False)):
+            output = program(question="why did a chicken cross the kitchen?")
+            timestamps = []
+            async for value in output:
+                if isinstance(value, dspy.streaming.StatusMessage):
+                    timestamps.append(time.time())
+
+    # timestamps[0]: tool start message
+    # timestamps[1]: tool end message
+    # There should be a ~1 second delay between the tool start and end messages because we explicitly
+    # sleep for 1 second in the tool.
+    assert timestamps[1] - timestamps[0] >= 1
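
Note: both tests stream the default StatusMessageProvider messages. As a minimal
sketch of customizing those messages, assuming the overridable hook names
tool_start_status_message and tool_end_status_message on
dspy.streaming.StatusMessageProvider (check the installed dspy version), a
provider subclass could look like this:

import dspy

class VerboseStatusProvider(dspy.streaming.StatusMessageProvider):
    # Hook names here are an assumption based on the provider used in the tests
    # above; adjust to the actual dspy.streaming API if they differ.
    def tool_start_status_message(self, instance, inputs):
        return f"Started tool with inputs: {inputs}"

    def tool_end_status_message(self, outputs):
        return f"Tool finished with outputs: {outputs}"

# Drop-in replacement for the default provider used in the tests:
# program = dspy.streamify(MyProgram(), status_message_provider=VerboseStatusProvider())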
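
The >= 1 second gap between the two StatusMessage timestamps shows that the
start message is delivered while the tool is still sleeping, i.e. status
delivery is not blocked by tool execution. A minimal consumption sketch
(assuming a streamified program as in the tests, and that StatusMessage
carries a .message attribute):

import asyncio

import dspy

async def consume(program):
    # program is the return value of dspy.streamify(...); calling it yields an
    # async generator interleaving StatusMessage objects with the final Prediction.
    output = program(question="why did a chicken cross the kitchen?")
    async for value in output:
        if isinstance(value, dspy.streaming.StatusMessage):
            print("status:", value.message)
        elif isinstance(value, dspy.Prediction):
            print("answer:", value.answer)

# asyncio.run(consume(program))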