Skip to content

Commit 20279d9

Browse files
seanzhougooglecopybara-github
authored andcommitted
fix: Save output in state via output_key only when the event is authored by current agent
PiperOrigin-RevId: 776640671
1 parent 09e487d commit 20279d9

File tree

2 files changed

+220
-0
lines changed

2 files changed

+220
-0
lines changed

src/google/adk/agents/llm_agent.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,12 +431,22 @@ def _llm_flow(self) -> BaseLlmFlow:
431431

432432
def __maybe_save_output_to_state(self, event: Event):
433433
"""Saves the model output to state if needed."""
434+
# skip if the event was authored by some other agent (e.g. current agent
435+
# transferred to another agent)
436+
if event.author != self.name:
437+
logger.debug(
438+
'Skipping output save for agent %s: event authored by %s',
439+
self.name,
440+
event.author,
441+
)
442+
return
434443
if (
435444
self.output_key
436445
and event.is_final_response()
437446
and event.content
438447
and event.content.parts
439448
):
449+
440450
result = ''.join(
441451
[part.text if part.text else '' for part in event.content.parts]
442452
)
Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Unit tests for LlmAgent output saving functionality."""
16+
17+
from unittest.mock import Mock
18+
from unittest.mock import patch
19+
20+
from google.adk.agents.llm_agent import LlmAgent
21+
from google.adk.events.event import Event
22+
from google.adk.events.event_actions import EventActions
23+
from google.genai import types
24+
from pydantic import BaseModel
25+
import pytest
26+
27+
28+
class MockOutputSchema(BaseModel):
29+
message: str
30+
confidence: float
31+
32+
33+
def create_test_event(
34+
author: str = "test_agent",
35+
content_text: str = "Hello world",
36+
is_final: bool = True,
37+
invocation_id: str = "test_invocation",
38+
) -> Event:
39+
"""Helper to create test events."""
40+
# Create mock content
41+
parts = [types.Part.from_text(text=content_text)] if content_text else []
42+
content = types.Content(role="model", parts=parts) if parts else None
43+
44+
# Create event
45+
event = Event(
46+
invocation_id=invocation_id,
47+
author=author,
48+
content=content,
49+
actions=EventActions(),
50+
)
51+
52+
# Mock is_final_response if needed
53+
if not is_final:
54+
event.partial = True
55+
56+
return event
57+
58+
59+
class TestLlmAgentOutputSave:
60+
"""Test suite for LlmAgent output saving functionality."""
61+
62+
def test_maybe_save_output_to_state_skips_different_author(self, caplog):
63+
"""Test that output is not saved when event author differs from agent name."""
64+
agent = LlmAgent(name="agent_a", output_key="result")
65+
event = create_test_event(author="agent_b", content_text="Response from B")
66+
67+
with caplog.at_level("DEBUG"):
68+
agent._LlmAgent__maybe_save_output_to_state(event)
69+
70+
# Should not add anything to state_delta
71+
assert len(event.actions.state_delta) == 0
72+
73+
# Should log the skip
74+
assert (
75+
"Skipping output save for agent agent_a: event authored by agent_b"
76+
in caplog.text
77+
)
78+
79+
def test_maybe_save_output_to_state_saves_same_author(self):
80+
"""Test that output is saved when event author matches agent name."""
81+
agent = LlmAgent(name="test_agent", output_key="result")
82+
event = create_test_event(author="test_agent", content_text="Test response")
83+
84+
agent._LlmAgent__maybe_save_output_to_state(event)
85+
86+
# Should save to state_delta
87+
assert event.actions.state_delta["result"] == "Test response"
88+
89+
def test_maybe_save_output_to_state_no_output_key(self):
90+
"""Test that nothing is saved when output_key is not set."""
91+
agent = LlmAgent(name="test_agent") # No output_key
92+
event = create_test_event(author="test_agent", content_text="Test response")
93+
94+
agent._LlmAgent__maybe_save_output_to_state(event)
95+
96+
# Should not save anything
97+
assert len(event.actions.state_delta) == 0
98+
99+
def test_maybe_save_output_to_state_not_final_response(self):
100+
"""Test that output is not saved for non-final responses."""
101+
agent = LlmAgent(name="test_agent", output_key="result")
102+
event = create_test_event(
103+
author="test_agent", content_text="Partial response", is_final=False
104+
)
105+
106+
agent._LlmAgent__maybe_save_output_to_state(event)
107+
108+
# Should not save partial responses
109+
assert len(event.actions.state_delta) == 0
110+
111+
def test_maybe_save_output_to_state_no_content(self):
112+
"""Test that nothing is saved when event has no content."""
113+
agent = LlmAgent(name="test_agent", output_key="result")
114+
event = create_test_event(author="test_agent", content_text="")
115+
116+
agent._LlmAgent__maybe_save_output_to_state(event)
117+
118+
# Should not save empty content
119+
assert len(event.actions.state_delta) == 0
120+
121+
def test_maybe_save_output_to_state_with_output_schema(self):
122+
"""Test that output is processed with schema when output_schema is set."""
123+
agent = LlmAgent(
124+
name="test_agent", output_key="result", output_schema=MockOutputSchema
125+
)
126+
127+
# Create event with JSON content
128+
json_content = '{"message": "Hello", "confidence": 0.95}'
129+
event = create_test_event(author="test_agent", content_text=json_content)
130+
131+
agent._LlmAgent__maybe_save_output_to_state(event)
132+
133+
# Should save parsed and validated output
134+
expected_output = {"message": "Hello", "confidence": 0.95}
135+
assert event.actions.state_delta["result"] == expected_output
136+
137+
def test_maybe_save_output_to_state_multiple_parts(self):
138+
"""Test that multiple text parts are concatenated."""
139+
agent = LlmAgent(name="test_agent", output_key="result")
140+
141+
# Create event with multiple text parts
142+
parts = [
143+
types.Part.from_text(text="Hello "),
144+
types.Part.from_text(text="world"),
145+
types.Part.from_text(text="!"),
146+
]
147+
content = types.Content(role="model", parts=parts)
148+
149+
event = Event(
150+
invocation_id="test_invocation",
151+
author="test_agent",
152+
content=content,
153+
actions=EventActions(),
154+
)
155+
156+
agent._LlmAgent__maybe_save_output_to_state(event)
157+
158+
# Should concatenate all text parts
159+
assert event.actions.state_delta["result"] == "Hello world!"
160+
161+
def test_maybe_save_output_to_state_agent_transfer_scenario(self, caplog):
162+
"""Test realistic agent transfer scenario."""
163+
# Scenario: Agent A transfers to Agent B, Agent B produces output
164+
# Agent A should not save Agent B's output
165+
166+
agent_a = LlmAgent(name="support_agent", output_key="support_result")
167+
agent_b_event = create_test_event(
168+
author="billing_agent", content_text="Your bill is $100"
169+
)
170+
171+
with caplog.at_level("DEBUG"):
172+
agent_a._LlmAgent__maybe_save_output_to_state(agent_b_event)
173+
174+
# Agent A should not save Agent B's output
175+
assert len(agent_b_event.actions.state_delta) == 0
176+
assert (
177+
"Skipping output save for agent support_agent: event authored by"
178+
" billing_agent"
179+
in caplog.text
180+
)
181+
182+
def test_maybe_save_output_to_state_case_sensitive_names(self, caplog):
183+
"""Test that agent name comparison is case-sensitive."""
184+
agent = LlmAgent(name="TestAgent", output_key="result")
185+
event = create_test_event(author="testagent", content_text="Test response")
186+
187+
with caplog.at_level("DEBUG"):
188+
agent._LlmAgent__maybe_save_output_to_state(event)
189+
190+
# Should not save due to case mismatch
191+
assert len(event.actions.state_delta) == 0
192+
assert (
193+
"Skipping output save for agent TestAgent: event authored by testagent"
194+
in caplog.text
195+
)
196+
197+
@patch("google.adk.agents.llm_agent.logger")
198+
def test_maybe_save_output_to_state_logging(self, mock_logger):
199+
"""Test that debug logging works correctly."""
200+
agent = LlmAgent(name="agent1", output_key="result")
201+
event = create_test_event(author="agent2", content_text="Test response")
202+
203+
agent._LlmAgent__maybe_save_output_to_state(event)
204+
205+
# Should call logger.debug with correct parameters
206+
mock_logger.debug.assert_called_once_with(
207+
"Skipping output save for agent %s: event authored by %s",
208+
"agent1",
209+
"agent2",
210+
)

0 commit comments

Comments
 (0)