Commit 2c64a30

avidelatm authored and copybara-github committed
fix: make LiteLLM streaming truly asynchronous
Merge google#1451

## Description

Fixes google#1306 by using `async for` with `await self.llm_client.acompletion()` instead of a synchronous `for` loop.

## Changes

- Updated test mocks to properly handle async streaming by creating an async generator
- Ensured proper parameter handling to avoid a duplicate `stream` parameter

## Testing Plan

- All unit tests now pass with the async streaming implementation
- Verified with `pytest tests/unittests/models/test_litellm.py` that all streaming tests pass
- Manually tested with a sample agent using LiteLLM to confirm streaming works properly

# Test Evidence: https://youtu.be/hSp3otI79DM

Let me know if you need anything else from me for this PR.

COPYBARA_INTEGRATE_REVIEW=google#1451 from avidelatm:fix/litellm-async-streaming d35b9dc
PiperOrigin-RevId: 774835130
1 parent 53de35a commit 2c64a30
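
For context on the bug being fixed: inside a coroutine, a plain `for` over a blocking generator never yields control back to the event loop, so every other task stalls until the stream is exhausted. Below is a minimal, self-contained sketch of that failure mode; all names (`blocking_stream`, `heartbeat`) are illustrative and not from the ADK codebase.

```python
# Sketch of why the synchronous loop was a problem: consume_sync never
# awaits, so the heartbeat task is starved until the stream finishes.
import asyncio
import time


def blocking_stream():
  for i in range(3):
    time.sleep(1)  # stands in for a blocking network read per chunk
    yield i


async def heartbeat() -> None:
  for _ in range(3):
    await asyncio.sleep(0.5)
    print("heartbeat")  # delayed while the sync loop below holds the loop


async def consume_sync() -> None:
  for chunk in blocking_stream():  # blocks the event loop between chunks
    print("chunk", chunk)


async def main() -> None:
  await asyncio.gather(heartbeat(), consume_sync())


asyncio.run(main())
```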

File tree

2 files changed: +21 −4


src/google/adk/models/lite_llm.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -679,7 +679,7 @@ async def generate_content_async(
       aggregated_llm_response_with_tool_call = None
       usage_metadata = None
       fallback_index = 0
-      for part in self.llm_client.completion(**completion_args):
+      async for part in await self.llm_client.acompletion(**completion_args):
         for chunk, finish_reason in _model_response_to_chunk(part):
           if isinstance(chunk, FunctionChunk):
             index = chunk.index or fallback_index
```
tests/unittests/models/test_litellm.py

Lines changed: 20 additions & 3 deletions
```diff
@@ -416,9 +416,26 @@ def __init__(self, acompletion_mock, completion_mock):
     self.completion_mock = completion_mock
 
   async def acompletion(self, model, messages, tools, **kwargs):
-    return await self.acompletion_mock(
-        model=model, messages=messages, tools=tools, **kwargs
-    )
+    if kwargs.get("stream", False):
+      kwargs_copy = dict(kwargs)
+      kwargs_copy.pop("stream", None)
+
+      async def stream_generator():
+        stream_data = self.completion_mock(
+            model=model,
+            messages=messages,
+            tools=tools,
+            stream=True,
+            **kwargs_copy,
+        )
+        for item in stream_data:
+          yield item
+
+      return stream_generator()
+    else:
+      return await self.acompletion_mock(
+          model=model, messages=messages, tools=tools, **kwargs
+      )
 
   def completion(self, model, messages, tools, stream, **kwargs):
     return self.completion_mock(
```