Skip to content

Commit 90360f4

Browse files
pwwpche authored and copybara-github committed
feat: add new callbacks to handle tool and model errors
This CL adds new callbacks to the plugin system: - `on_tool_error_callback` - `on_model_error_callback` These allow the user to create plugins that can handle errors. PiperOrigin-RevId: 783052800
1 parent 31fa5d9 commit 90360f4

File tree

9 files changed

+339
-5
lines changed

9 files changed

+339
-5
lines changed

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -521,7 +521,13 @@ async def _call_llm_async(
521521
with tracer.start_as_current_span('call_llm'):
522522
if invocation_context.run_config.support_cfc:
523523
invocation_context.live_request_queue = LiveRequestQueue()
524-
async for llm_response in self.run_live(invocation_context):
524+
responses_generator = self.run_live(invocation_context)
525+
async for llm_response in self._run_and_handle_error(
526+
responses_generator,
527+
invocation_context,
528+
llm_request,
529+
model_response_event,
530+
):
525531
# Runs after_model_callback if it exists.
526532
if altered_llm_response := await self._handle_after_model_callback(
527533
invocation_context, llm_response, model_response_event
@@ -540,10 +546,16 @@ async def _call_llm_async(
540546
# the counter beyond the max set value, then the execution is stopped
541547
# right here, and exception is thrown.
542548
invocation_context.increment_llm_call_count()
543-
async for llm_response in llm.generate_content_async(
549+
responses_generator = llm.generate_content_async(
544550
llm_request,
545551
stream=invocation_context.run_config.streaming_mode
546552
== StreamingMode.SSE,
553+
)
554+
async for llm_response in self._run_and_handle_error(
555+
responses_generator,
556+
invocation_context,
557+
llm_request,
558+
model_response_event,
547559
):
548560
trace_call_llm(
549561
invocation_context,
@@ -660,6 +672,43 @@ def _finalize_model_response_event(
660672

661673
return model_response_event
662674

675+
async def _run_and_handle_error(
676+
self,
677+
response_generator: AsyncGenerator[LlmResponse, None],
678+
invocation_context: InvocationContext,
679+
llm_request: LlmRequest,
680+
model_response_event: Event,
681+
) -> AsyncGenerator[LlmResponse, None]:
682+
"""Runs the response generator and processes the error with plugins.
683+
684+
Args:
685+
response_generator: The response generator to run.
686+
invocation_context: The invocation context.
687+
llm_request: The LLM request.
688+
model_response_event: The model response event.
689+
690+
Yields:
691+
A generator of LlmResponse.
692+
"""
693+
try:
694+
async for response in response_generator:
695+
yield response
696+
except Exception as model_error:
697+
callback_context = CallbackContext(
698+
invocation_context, event_actions=model_response_event.actions
699+
)
700+
error_response = (
701+
await invocation_context.plugin_manager.run_on_model_error_callback(
702+
callback_context=callback_context,
703+
llm_request=llm_request,
704+
error=model_error,
705+
)
706+
)
707+
if error_response is not None:
708+
yield error_response
709+
else:
710+
raise model_error
711+
663712
def __get_llm(self, invocation_context: InvocationContext) -> BaseLlm:
664713
from ...agents.llm_agent import LlmAgent
665714

src/google/adk/flows/llm_flows/functions.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -176,9 +176,21 @@ async def handle_function_calls_async(
176176

177177
# Step 3: Otherwise, proceed calling the tool normally.
178178
if function_response is None:
179-
function_response = await __call_tool_async(
180-
tool, args=function_args, tool_context=tool_context
181-
)
179+
try:
180+
function_response = await __call_tool_async(
181+
tool, args=function_args, tool_context=tool_context
182+
)
183+
except Exception as tool_error:
184+
error_response = await invocation_context.plugin_manager.run_on_tool_error_callback(
185+
tool=tool,
186+
tool_args=function_args,
187+
tool_context=tool_context,
188+
error=tool_error,
189+
)
190+
if error_response is not None:
191+
function_response = error_response
192+
else:
193+
raise tool_error
182194

183195
# Step 4: Check if plugin after_tool_callback overrides the function
184196
# response.

src/google/adk/plugins/base_plugin.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,31 @@ async def after_model_callback(
265265
"""
266266
pass
267267

268+
async def on_model_error_callback(
269+
self,
270+
*,
271+
callback_context: CallbackContext,
272+
llm_request: LlmRequest,
273+
error: Exception,
274+
) -> Optional[LlmResponse]:
275+
"""Callback executed when a model call encounters an error.
276+
277+
This callback provides an opportunity to handle model errors gracefully,
278+
potentially providing alternative responses or recovery mechanisms.
279+
280+
Args:
281+
callback_context: The context for the current agent call.
282+
llm_request: The request that was sent to the model when the error
283+
occurred.
284+
error: The exception that was raised during model execution.
285+
286+
Returns:
287+
An optional LlmResponse. If an LlmResponse is returned, it will be used
288+
instead of propagating the error. Returning `None` allows the original
289+
error to be raised.
290+
"""
291+
pass
292+
268293
async def before_tool_callback(
269294
self,
270295
*,
@@ -315,3 +340,29 @@ async def after_tool_callback(
315340
result.
316341
"""
317342
pass
343+
344+
async def on_tool_error_callback(
345+
self,
346+
*,
347+
tool: BaseTool,
348+
tool_args: dict[str, Any],
349+
tool_context: ToolContext,
350+
error: Exception,
351+
) -> Optional[dict]:
352+
"""Callback executed when a tool call encounters an error.
353+
354+
This callback provides an opportunity to handle tool errors gracefully,
355+
potentially providing alternative responses or recovery mechanisms.
356+
357+
Args:
358+
tool: The tool instance that encountered an error.
359+
tool_args: The arguments that were passed to the tool.
360+
tool_context: The context specific to the tool execution.
361+
error: The exception that was raised during tool execution.
362+
363+
Returns:
364+
An optional dictionary. If a dictionary is returned, it will be used as
365+
the tool response instead of propagating the error. Returning `None`
366+
allows the original error to be raised.
367+
"""
368+
pass

src/google/adk/plugins/plugin_manager.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@
4848
"after_tool_callback",
4949
"before_model_callback",
5050
"after_model_callback",
51+
"on_tool_error_callback",
52+
"on_model_error_callback",
5153
]
5254

5355
logger = logging.getLogger("google_adk." + __name__)
@@ -195,6 +197,21 @@ async def run_after_tool_callback(
195197
result=result,
196198
)
197199

200+
async def run_on_model_error_callback(
201+
self,
202+
*,
203+
callback_context: CallbackContext,
204+
llm_request: LlmRequest,
205+
error: Exception,
206+
) -> Optional[LlmResponse]:
207+
"""Runs the `on_model_error_callback` for all plugins."""
208+
return await self._run_callbacks(
209+
"on_model_error_callback",
210+
callback_context=callback_context,
211+
llm_request=llm_request,
212+
error=error,
213+
)
214+
198215
async def run_before_model_callback(
199216
self, *, callback_context: CallbackContext, llm_request: LlmRequest
200217
) -> Optional[LlmResponse]:
@@ -215,6 +232,23 @@ async def run_after_model_callback(
215232
llm_response=llm_response,
216233
)
217234

235+
async def run_on_tool_error_callback(
236+
self,
237+
*,
238+
tool: BaseTool,
239+
tool_args: dict[str, Any],
240+
tool_context: ToolContext,
241+
error: Exception,
242+
) -> Optional[dict]:
243+
"""Runs the `on_tool_error_callback` for all plugins."""
244+
return await self._run_callbacks(
245+
"on_tool_error_callback",
246+
tool=tool,
247+
tool_args=tool_args,
248+
tool_context=tool_context,
249+
error=error,
250+
)
251+
218252
async def _run_callbacks(
219253
self, callback_name: PluginCallbackName, **kwargs: Any
220254
) -> Optional[Any]:

tests/unittests/flows/llm_flows/test_plugin_model_callbacks.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,33 @@
2020
from google.adk.models import LlmResponse
2121
from google.adk.plugins.base_plugin import BasePlugin
2222
from google.genai import types
23+
from google.genai.errors import ClientError
2324
import pytest
2425

2526
from ... import testing_utils
2627

28+
mock_error = ClientError(
29+
code=429,
30+
response_json={
31+
'error': {
32+
'code': 429,
33+
'message': 'Quota exceeded.',
34+
'status': 'RESOURCE_EXHAUSTED',
35+
}
36+
},
37+
)
38+
2739

2840
class MockPlugin(BasePlugin):
2941
before_model_text = 'before_model_text from MockPlugin'
3042
after_model_text = 'after_model_text from MockPlugin'
43+
on_model_error_text = 'on_model_error_text from MockPlugin'
3144

3245
def __init__(self, name='mock_plugin'):
3346
self.name = name
3447
self.enable_before_model_callback = False
3548
self.enable_after_model_callback = False
49+
self.enable_on_model_error_callback = False
3650
self.before_model_response = LlmResponse(
3751
content=testing_utils.ModelContent(
3852
[types.Part.from_text(text=self.before_model_text)]
@@ -43,6 +57,11 @@ def __init__(self, name='mock_plugin'):
4357
[types.Part.from_text(text=self.after_model_text)]
4458
)
4559
)
60+
self.on_model_error_response = LlmResponse(
61+
content=testing_utils.ModelContent(
62+
[types.Part.from_text(text=self.on_model_error_text)]
63+
)
64+
)
4665

4766
async def before_model_callback(
4867
self, *, callback_context: CallbackContext, llm_request: LlmRequest
@@ -58,6 +77,17 @@ async def after_model_callback(
5877
return None
5978
return self.after_model_response
6079

80+
async def on_model_error_callback(
81+
self,
82+
*,
83+
callback_context: CallbackContext,
84+
llm_request: LlmRequest,
85+
error: Exception,
86+
) -> Optional[LlmResponse]:
87+
if not self.enable_on_model_error_callback:
88+
return None
89+
return self.on_model_error_response
90+
6191

6292
CANONICAL_MODEL_CALLBACK_CONTENT = 'canonical_model_callback_content'
6393

@@ -124,5 +154,36 @@ def test_before_model_callback_fallback_model(mock_plugin):
124154
]
125155

126156

157+
def test_on_model_error_callback_with_plugin(mock_plugin):
158+
"""Tests that the model error is handled by the plugin."""
159+
mock_model = testing_utils.MockModel.create(error=mock_error, responses=[])
160+
mock_plugin.enable_on_model_error_callback = True
161+
agent = Agent(
162+
name='root_agent',
163+
model=mock_model,
164+
)
165+
166+
runner = testing_utils.InMemoryRunner(agent, plugins=[mock_plugin])
167+
168+
assert testing_utils.simplify_events(runner.run('test')) == [
169+
('root_agent', mock_plugin.on_model_error_text),
170+
]
171+
172+
173+
def test_on_model_error_callback_fallback_to_runner(mock_plugin):
174+
"""Tests that the model error is not handled and falls back to raise from runner."""
175+
mock_model = testing_utils.MockModel.create(error=mock_error, responses=[])
176+
mock_plugin.enable_on_model_error_callback = False
177+
agent = Agent(
178+
name='root_agent',
179+
model=mock_model,
180+
)
181+
182+
try:
183+
testing_utils.InMemoryRunner(agent, plugins=[mock_plugin])
184+
except Exception as e:
185+
assert e == mock_error
186+
187+
127188
if __name__ == '__main__':
128189
pytest.main([__file__])

0 commit comments

Comments
 (0)