Skip to content

Commit cd43fdb

Browse files
committed
Fix: test_base_llm_flow.py - patch LLM flow streaming unit test
1 parent 9bd539e commit cd43fdb

File tree

2 files changed

+45
-1
lines changed

2 files changed

+45
-1
lines changed

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ async def _receive_from_model(
217217
) -> AsyncGenerator[Event, None]:
218218
"""Receive data from model and process events using BaseLlmConnection."""
219219

220-
def get_author_for_event(llm_response):
220+
def get_author_for_event(llm_response: LlmResponse) -> str:
221221
"""Get the author of the event.
222222
223223
When the model returns transcription, the author is "user". Otherwise, the
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import pytest
2+
from unittest.mock import MagicMock
3+
from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow, LlmResponse, Event, ConnectionClosedOK
4+
5+
@pytest.mark.asyncio
6+
async def test_receive_from_model_yields_events():
7+
flow = BaseLlmFlow()
8+
9+
fake_response_1 = LlmResponse(
10+
content=MagicMock(role='assistant'),
11+
error_code=None,
12+
interrupted=False
13+
)
14+
fake_response_2 = LlmResponse(
15+
content=MagicMock(role='user'),
16+
error_code=None,
17+
interrupted=False
18+
)
19+
20+
async def fake_receive():
21+
yield fake_response_1
22+
yield fake_response_2
23+
raise ConnectionClosedOK(rcvd=None, sent=None)
24+
25+
26+
llm_connection = MagicMock()
27+
llm_connection.receive = fake_receive
28+
29+
invocation_context = MagicMock()
30+
invocation_context.agent.name = "TestAgent"
31+
invocation_context.live_request_queue = MagicMock()
32+
invocation_context.transcription_cache = []
33+
invocation_context.invocation_id = "test_invocation_id_123"
34+
35+
events = []
36+
async for event in flow._receive_from_model(
37+
llm_connection, event_id="test_event", invocation_context=invocation_context, llm_request=MagicMock()
38+
):
39+
events.append(event)
40+
41+
# Add your assertions here, e.g.:
42+
assert len(events) == 2
43+
assert events[0].author == "TestAgent"
44+
assert events[1].author == "user"

0 commit comments

Comments
 (0)