Skip to content

Commit f54b9b6

Browse files
seanzhougooglecopybara-github
authored andcommitted
chore: Add unit tests for contents.py
PiperOrigin-RevId: 775713101
1 parent 6729edd commit f54b9b6

File tree

1 file changed

+361
-0
lines changed

1 file changed

+361
-0
lines changed
Lines changed: 361 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,361 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from google.adk.agents import Agent
16+
from google.adk.events.event import Event
17+
from google.adk.flows.llm_flows import contents
18+
from google.adk.flows.llm_flows.contents import _convert_foreign_event
19+
from google.adk.flows.llm_flows.contents import _get_contents
20+
from google.adk.flows.llm_flows.contents import _merge_function_response_events
21+
from google.adk.flows.llm_flows.contents import _rearrange_events_for_async_function_responses_in_history
22+
from google.adk.flows.llm_flows.contents import _rearrange_events_for_latest_function_response
23+
from google.adk.models import LlmRequest
24+
from google.genai import types
25+
import pytest
26+
27+
from ... import testing_utils
28+
29+
30+
@pytest.mark.asyncio
31+
async def test_content_processor_no_contents():
32+
"""Test ContentLlmRequestProcessor when include_contents is 'none'."""
33+
agent = Agent(model="gemini-1.5-flash", name="agent", include_contents="none")
34+
llm_request = LlmRequest(model="gemini-1.5-flash")
35+
invocation_context = await testing_utils.create_invocation_context(
36+
agent=agent
37+
)
38+
39+
# Collect events from async generator
40+
events = []
41+
async for event in contents.request_processor.run_async(
42+
invocation_context, llm_request
43+
):
44+
events.append(event)
45+
46+
# Should not yield any events
47+
assert len(events) == 0
48+
# Contents should not be set when include_contents is 'none'
49+
assert llm_request.contents == []
50+
51+
52+
@pytest.mark.asyncio
53+
async def test_content_processor_with_contents():
54+
"""Test ContentLlmRequestProcessor when include_contents is not 'none'."""
55+
agent = Agent(model="gemini-1.5-flash", name="agent")
56+
llm_request = LlmRequest(model="gemini-1.5-flash")
57+
invocation_context = await testing_utils.create_invocation_context(
58+
agent=agent
59+
)
60+
61+
# Add some test events to the session
62+
test_event = Event(
63+
invocation_id="test_inv",
64+
author="user",
65+
content=types.Content(
66+
role="user", parts=[types.Part.from_text(text="Hello")]
67+
),
68+
)
69+
invocation_context.session.events = [test_event]
70+
71+
# Collect events from async generator
72+
events = []
73+
async for event in contents.request_processor.run_async(
74+
invocation_context, llm_request
75+
):
76+
events.append(event)
77+
78+
# Should not yield any events (processor doesn't emit events, just modifies request)
79+
assert len(events) == 0
80+
# Contents should be set
81+
assert llm_request.contents is not None
82+
assert len(llm_request.contents) == 1
83+
assert llm_request.contents[0].role == "user"
84+
assert llm_request.contents[0].parts[0].text == "Hello"
85+
86+
87+
@pytest.mark.asyncio
88+
async def test_content_processor_non_llm_agent():
89+
"""Test ContentLlmRequestProcessor with non-LLM agent."""
90+
from google.adk.agents.base_agent import BaseAgent
91+
92+
# Create a base agent (not LLM agent)
93+
agent = BaseAgent(name="base_agent")
94+
llm_request = LlmRequest(model="gemini-1.5-flash")
95+
invocation_context = await testing_utils.create_invocation_context(
96+
agent=agent
97+
)
98+
99+
# Collect events from async generator
100+
events = []
101+
async for event in contents.request_processor.run_async(
102+
invocation_context, llm_request
103+
):
104+
events.append(event)
105+
106+
# Should not yield any events and not modify request
107+
assert len(events) == 0
108+
assert llm_request.contents == []
109+
110+
111+
def test_get_contents_empty_events():
112+
"""Test _get_contents with empty events list."""
113+
contents_result = _get_contents(None, [], "test_agent")
114+
assert contents_result == []
115+
116+
117+
def test_get_contents_with_events():
118+
"""Test _get_contents with valid events."""
119+
test_event = Event(
120+
invocation_id="test_inv",
121+
author="user",
122+
content=types.Content(
123+
role="user", parts=[types.Part.from_text(text="Hello")]
124+
),
125+
)
126+
127+
contents_result = _get_contents(None, [test_event], "test_agent")
128+
assert len(contents_result) == 1
129+
assert contents_result[0].role == "user"
130+
assert contents_result[0].parts[0].text == "Hello"
131+
132+
133+
def test_get_contents_filters_empty_events():
134+
"""Test _get_contents filters out events with empty content."""
135+
# Event with empty text
136+
empty_event = Event(
137+
invocation_id="test_inv",
138+
author="user",
139+
content=types.Content(role="user", parts=[types.Part.from_text(text="")]),
140+
)
141+
142+
# Event without content
143+
no_content_event = Event(
144+
invocation_id="test_inv",
145+
author="user",
146+
)
147+
148+
# Valid event
149+
valid_event = Event(
150+
invocation_id="test_inv",
151+
author="user",
152+
content=types.Content(
153+
role="user", parts=[types.Part.from_text(text="Hello")]
154+
),
155+
)
156+
157+
contents_result = _get_contents(
158+
None, [empty_event, no_content_event, valid_event], "test_agent"
159+
)
160+
assert len(contents_result) == 1
161+
assert contents_result[0].role == "user"
162+
assert contents_result[0].parts[0].text == "Hello"
163+
164+
165+
def test_convert_foreign_event():
166+
"""Test _convert_foreign_event function."""
167+
agent_event = Event(
168+
invocation_id="test_inv",
169+
author="agent1",
170+
content=types.Content(
171+
role="model", parts=[types.Part.from_text(text="Agent response")]
172+
),
173+
)
174+
175+
converted_event = _convert_foreign_event(agent_event)
176+
177+
assert converted_event.author == "user"
178+
assert converted_event.content.role == "user"
179+
assert len(converted_event.content.parts) == 2
180+
assert converted_event.content.parts[0].text == "For context:"
181+
assert (
182+
"[agent1] said: Agent response" in converted_event.content.parts[1].text
183+
)
184+
185+
186+
def test_convert_event_with_function_call():
187+
"""Test _convert_foreign_event with function call."""
188+
function_call = types.FunctionCall(
189+
id="func_123", name="test_function", args={"param": "value"}
190+
)
191+
192+
agent_event = Event(
193+
invocation_id="test_inv",
194+
author="agent1",
195+
content=types.Content(
196+
role="model", parts=[types.Part(function_call=function_call)]
197+
),
198+
)
199+
200+
converted_event = _convert_foreign_event(agent_event)
201+
202+
assert converted_event.author == "user"
203+
assert converted_event.content.role == "user"
204+
assert len(converted_event.content.parts) == 2
205+
assert converted_event.content.parts[0].text == "For context:"
206+
assert (
207+
"[agent1] called tool `test_function`"
208+
in converted_event.content.parts[1].text
209+
)
210+
assert "{'param': 'value'}" in converted_event.content.parts[1].text
211+
212+
213+
def test_convert_event_with_function_response():
214+
"""Test _convert_foreign_event with function response."""
215+
function_response = types.FunctionResponse(
216+
id="func_123", name="test_function", response={"result": "success"}
217+
)
218+
219+
agent_event = Event(
220+
invocation_id="test_inv",
221+
author="agent1",
222+
content=types.Content(
223+
role="user", parts=[types.Part(function_response=function_response)]
224+
),
225+
)
226+
227+
converted_event = _convert_foreign_event(agent_event)
228+
229+
assert converted_event.author == "user"
230+
assert converted_event.content.role == "user"
231+
assert len(converted_event.content.parts) == 2
232+
assert converted_event.content.parts[0].text == "For context:"
233+
assert (
234+
"[agent1] `test_function` tool returned result:"
235+
in converted_event.content.parts[1].text
236+
)
237+
assert "{'result': 'success'}" in converted_event.content.parts[1].text
238+
239+
240+
def test_merge_function_response_events():
241+
"""Test _merge_function_response_events function."""
242+
# Create initial function response event
243+
function_response1 = types.FunctionResponse(
244+
id="func_123", name="test_function", response={"status": "pending"}
245+
)
246+
247+
initial_event = Event(
248+
invocation_id="test_inv",
249+
author="user",
250+
content=types.Content(
251+
role="user", parts=[types.Part(function_response=function_response1)]
252+
),
253+
)
254+
255+
# Create final function response event
256+
function_response2 = types.FunctionResponse(
257+
id="func_123", name="test_function", response={"result": "success"}
258+
)
259+
260+
final_event = Event(
261+
invocation_id="test_inv2",
262+
author="user",
263+
content=types.Content(
264+
role="user", parts=[types.Part(function_response=function_response2)]
265+
),
266+
)
267+
268+
merged_event = _merge_function_response_events([initial_event, final_event])
269+
270+
assert (
271+
merged_event.invocation_id == "test_inv"
272+
) # Should keep initial event ID
273+
assert len(merged_event.content.parts) == 1
274+
# The first part should be replaced with the final response
275+
assert merged_event.content.parts[0].function_response.response == {
276+
"result": "success"
277+
}
278+
279+
280+
def test_rearrange_events_for_async_function_responses():
281+
"""Test _rearrange_events_for_async_function_responses_in_history function."""
282+
# Create function call event
283+
function_call = types.FunctionCall(
284+
id="func_123", name="test_function", args={"param": "value"}
285+
)
286+
287+
call_event = Event(
288+
invocation_id="test_inv1",
289+
author="agent",
290+
content=types.Content(
291+
role="model", parts=[types.Part(function_call=function_call)]
292+
),
293+
)
294+
295+
# Create function response event
296+
function_response = types.FunctionResponse(
297+
id="func_123", name="test_function", response={"result": "success"}
298+
)
299+
300+
response_event = Event(
301+
invocation_id="test_inv2",
302+
author="user",
303+
content=types.Content(
304+
role="user", parts=[types.Part(function_response=function_response)]
305+
),
306+
)
307+
308+
# Test rearrangement
309+
events = [call_event, response_event]
310+
rearranged = _rearrange_events_for_async_function_responses_in_history(events)
311+
312+
# Should have both events in correct order
313+
assert len(rearranged) == 2
314+
assert rearranged[0] == call_event
315+
assert rearranged[1] == response_event
316+
317+
318+
def test_rearrange_events_for_latest_function_response():
319+
"""Test _rearrange_events_for_latest_function_response function."""
320+
# Create function call event
321+
function_call = types.FunctionCall(
322+
id="func_123", name="test_function", args={"param": "value"}
323+
)
324+
325+
call_event = Event(
326+
invocation_id="test_inv1",
327+
author="agent",
328+
content=types.Content(
329+
role="model", parts=[types.Part(function_call=function_call)]
330+
),
331+
)
332+
333+
# Create intermediate event
334+
intermediate_event = Event(
335+
invocation_id="test_inv2",
336+
author="agent",
337+
content=types.Content(
338+
role="model", parts=[types.Part.from_text(text="Processing...")]
339+
),
340+
)
341+
342+
# Create function response event
343+
function_response = types.FunctionResponse(
344+
id="func_123", name="test_function", response={"result": "success"}
345+
)
346+
347+
response_event = Event(
348+
invocation_id="test_inv3",
349+
author="user",
350+
content=types.Content(
351+
role="user", parts=[types.Part(function_response=function_response)]
352+
),
353+
)
354+
355+
# Test with matching function call and response
356+
events = [call_event, intermediate_event, response_event]
357+
rearranged = _rearrange_events_for_latest_function_response(events)
358+
359+
# Should remove intermediate events and merge responses
360+
assert len(rearranged) == 2
361+
assert rearranged[0] == call_event

0 commit comments

Comments
 (0)