|
| 1 | +import pytest |
| 2 | +from unittest.mock import MagicMock |
| 3 | +from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow, LlmResponse, Event, ConnectionClosedOK |
| 4 | + |
| 5 | +@pytest.mark.asyncio |
| 6 | +async def test_receive_from_model_yields_events(): |
| 7 | + flow = BaseLlmFlow() |
| 8 | + |
| 9 | + fake_response_1 = LlmResponse( |
| 10 | + content=MagicMock(role='assistant'), |
| 11 | + error_code=None, |
| 12 | + interrupted=False |
| 13 | + ) |
| 14 | + fake_response_2 = LlmResponse( |
| 15 | + content=MagicMock(role='user'), |
| 16 | + error_code=None, |
| 17 | + interrupted=False |
| 18 | + ) |
| 19 | + |
| 20 | + async def fake_receive(): |
| 21 | + yield fake_response_1 |
| 22 | + yield fake_response_2 |
| 23 | + raise ConnectionClosedOK(rcvd=None, sent=None) |
| 24 | + |
| 25 | + |
| 26 | + llm_connection = MagicMock() |
| 27 | + llm_connection.receive = fake_receive |
| 28 | + |
| 29 | + invocation_context = MagicMock() |
| 30 | + invocation_context.agent.name = "TestAgent" |
| 31 | + invocation_context.live_request_queue = MagicMock() |
| 32 | + invocation_context.transcription_cache = [] |
| 33 | + invocation_context.invocation_id = "test_invocation_id_123" |
| 34 | + |
| 35 | + events = [] |
| 36 | + async for event in flow._receive_from_model( |
| 37 | + llm_connection, event_id="test_event", invocation_context=invocation_context, llm_request=MagicMock() |
| 38 | + ): |
| 39 | + events.append(event) |
| 40 | + |
| 41 | + # Add your assertions here, e.g.: |
| 42 | + assert len(events) == 2 |
| 43 | + assert events[0].author == "TestAgent" |
| 44 | + assert events[1].author == "user" |
0 commit comments