
Commit 781aeb9

pwwpche authored and copybara-github committed

feat: add new callbacks to handle tool and model errors

This CL adds new callbacks to the plugin system:

- `on_tool_error_callback`
- `on_model_error_callback`

These allow users to create plugins that handle errors.

PiperOrigin-RevId: 783052800

1 parent b977d12 · commit 781aeb9

File tree

9 files changed: +362 −38 lines

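Taken together, the two hooks let an application intercept failures that previously aborted the run. Below is a minimal sketch of a consumer plugin, not part of this commit: the class name `ErrorHandlingPlugin` and the canned reply texts are illustrative, and the import paths assume the adk-python layout visible in the diffs that follow.

```python
from typing import Any, Optional

from google.adk.agents.callback_context import CallbackContext
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse
from google.adk.plugins.base_plugin import BasePlugin
from google.adk.tools.base_tool import BaseTool
from google.adk.tools.tool_context import ToolContext
from google.genai import types


class ErrorHandlingPlugin(BasePlugin):
  """Illustrative plugin exercising the two new error hooks."""

  def __init__(self):
    super().__init__(name='error_handling_plugin')

  async def on_model_error_callback(
      self,
      *,
      callback_context: CallbackContext,
      llm_request: LlmRequest,
      error: Exception,
  ) -> Optional[LlmResponse]:
    # Returning an LlmResponse swallows the error; returning None re-raises it.
    return LlmResponse(
        content=types.Content(
            role='model',
            parts=[types.Part.from_text(text=f'Model call failed: {error}')],
        )
    )

  async def on_tool_error_callback(
      self,
      *,
      tool: BaseTool,
      tool_args: dict[str, Any],
      tool_context: ToolContext,
      error: Exception,
  ) -> Optional[dict]:
    # Returning a dict substitutes it for the tool's result; None re-raises.
    return {'status': 'error', 'tool': tool.name, 'message': str(error)}
```

The new unit tests below pass plugins to their runner the same way, e.g. `InMemoryRunner(agent, plugins=[ErrorHandlingPlugin()])`.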

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 74 additions & 35 deletions

```diff
@@ -513,45 +513,61 @@ async def _call_llm_async(
     # Calls the LLM.
     llm = self.__get_llm(invocation_context)
     with tracer.start_as_current_span('call_llm'):
-      if invocation_context.run_config.support_cfc:
-        invocation_context.live_request_queue = LiveRequestQueue()
-        async for llm_response in self.run_live(invocation_context):
-          # Runs after_model_callback if it exists.
-          if altered_llm_response := await self._handle_after_model_callback(
-              invocation_context, llm_response, model_response_event
-          ):
-            llm_response = altered_llm_response
-          # only yield partial response in SSE streaming mode
-          if (
-              invocation_context.run_config.streaming_mode == StreamingMode.SSE
-              or not llm_response.partial
-          ):
-            yield llm_response
-          if llm_response.turn_complete:
-            invocation_context.live_request_queue.close()
-      else:
-        # Check if we can make this llm call or not. If the current call pushes
-        # the counter beyond the max set value, then the execution is stopped
-        # right here, and exception is thrown.
-        invocation_context.increment_llm_call_count()
-        async for llm_response in llm.generate_content_async(
-            llm_request,
-            stream=invocation_context.run_config.streaming_mode
-            == StreamingMode.SSE,
+      try:
+        if (
+            invocation_context.run_config
+            and invocation_context.run_config.support_cfc
         ):
-          trace_call_llm(
-              invocation_context,
-              model_response_event.id,
+          invocation_context.live_request_queue = LiveRequestQueue()
+          async for llm_response in self.run_live(invocation_context):
+            # Runs after_model_callback if it exists.
+            if altered_llm_response := await self._handle_after_model_callback(
+                invocation_context, llm_response, model_response_event
+            ):
+              llm_response = altered_llm_response
+            # only yield partial response in SSE streaming mode
+            if (
+                invocation_context.run_config.streaming_mode
+                == StreamingMode.SSE
+                or not llm_response.partial
+            ):
+              yield llm_response
+            if llm_response.turn_complete:
+              invocation_context.live_request_queue.close()
+        else:
+          # Check if we can make this llm call or not. If the current call pushes
+          # the counter beyond the max set value, then the execution is stopped
+          # right here, and exception is thrown.
+          invocation_context.increment_llm_call_count()
+          async for llm_response in llm.generate_content_async(
               llm_request,
-            llm_response,
-          )
-          # Runs after_model_callback if it exists.
-          if altered_llm_response := await self._handle_after_model_callback(
-              invocation_context, llm_response, model_response_event
+              stream=invocation_context.run_config.streaming_mode
+              == StreamingMode.SSE,
           ):
-            llm_response = altered_llm_response
+            trace_call_llm(
+                invocation_context,
+                model_response_event.id,
+                llm_request,
+                llm_response,
+            )
+            # Runs after_model_callback if it exists.
+            if altered_llm_response := await self._handle_after_model_callback(
+                invocation_context, llm_response, model_response_event
+            ):
+              llm_response = altered_llm_response
 
-          yield llm_response
+            yield llm_response
+      except Exception as model_error:
+        if (
+            invocation_context.run_config
+            and invocation_context.run_config.support_cfc
+            and invocation_context.live_request_queue
+        ):
+          invocation_context.live_request_queue.close()
+        error_response = await self._handle_model_error(
+            invocation_context, llm_request, model_response_event, model_error
+        )
+        yield error_response
 
   async def _handle_before_model_callback(
       self,
@@ -592,6 +608,29 @@ async def _handle_before_model_callback(
     if callback_response:
       return callback_response
 
+  async def _handle_model_error(
+      self,
+      invocation_context: InvocationContext,
+      llm_request: LlmRequest,
+      model_response_event: Event,
+      model_error: Exception,
+  ) -> LlmResponse:
+    """Handle model errors through plugin system."""
+    callback_context = CallbackContext(
+        invocation_context, event_actions=model_response_event.actions
+    )
+    error_response = (
+        await invocation_context.plugin_manager.run_on_model_error_callback(
+            callback_context=callback_context,
+            llm_request=llm_request,
+            error=model_error,
+        )
+    )
+    if error_response is not None:
+      return error_response
+    else:
+      raise model_error
+
   async def _handle_after_model_callback(
       self,
       invocation_context: InvocationContext,
```

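`_handle_model_error` consults the plugin manager and either yields the plugin's substitute response or re-raises the original exception. A plugin can therefore be selective about which errors it absorbs. The sketch below is hypothetical (the name `QuotaFallbackPlugin` is mine); the assumption that `ClientError` exposes a `code` attribute matches how the new unit tests construct their mock 429 error.

```python
from typing import Optional

from google.adk.agents.callback_context import CallbackContext
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse
from google.adk.plugins.base_plugin import BasePlugin
from google.genai import types
from google.genai.errors import ClientError


class QuotaFallbackPlugin(BasePlugin):
  """Hypothetical plugin: answers 429 quota errors, re-raises the rest."""

  def __init__(self):
    super().__init__(name='quota_fallback_plugin')

  async def on_model_error_callback(
      self,
      *,
      callback_context: CallbackContext,
      llm_request: LlmRequest,
      error: Exception,
  ) -> Optional[LlmResponse]:
    if isinstance(error, ClientError) and error.code == 429:
      # _handle_model_error yields this response instead of raising.
      return LlmResponse(
          content=types.Content(
              role='model',
              parts=[types.Part.from_text(
                  text='The model is over quota; please retry shortly.'
              )],
          )
      )
    return None  # Any other exception propagates exactly as before.
```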
src/google/adk/flows/llm_flows/functions.py

Lines changed: 15 additions & 3 deletions

```diff
@@ -176,9 +176,21 @@ async def handle_function_calls_async(
 
     # Step 3: Otherwise, proceed calling the tool normally.
     if function_response is None:
-      function_response = await __call_tool_async(
-          tool, args=function_args, tool_context=tool_context
-      )
+      try:
+        function_response = await __call_tool_async(
+            tool, args=function_args, tool_context=tool_context
+        )
+      except Exception as tool_error:
+        error_response = await invocation_context.plugin_manager.run_on_tool_error_callback(
+            tool=tool,
+            tool_args=function_args,
+            tool_context=tool_context,
+            error=tool_error,
+        )
+        if error_response is not None:
+          function_response = error_response
+        else:
+          raise tool_error
 
     # Step 4: Check if plugin after_tool_callback overrides the function
     # response.
```

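Because the plugin's dict is assigned to `function_response`, it flows into the same Step 4 path as a normal tool result, so the model sees it on the next turn. A plugin also receives enough context to retry. The following is a sketch under stated assumptions: `RetryOncePlugin` is a hypothetical name, and it assumes `BaseTool.run_async(args=..., tool_context=...)` is the public entry point that the private `__call_tool_async` helper wraps — verify against your ADK version.

```python
from typing import Any, Optional

from google.adk.plugins.base_plugin import BasePlugin
from google.adk.tools.base_tool import BaseTool
from google.adk.tools.tool_context import ToolContext


class RetryOncePlugin(BasePlugin):
  """Hypothetical plugin: retries a failed tool once, then surfaces the error."""

  def __init__(self):
    super().__init__(name='retry_once_plugin')

  async def on_tool_error_callback(
      self,
      *,
      tool: BaseTool,
      tool_args: dict[str, Any],
      tool_context: ToolContext,
      error: Exception,
  ) -> Optional[dict]:
    try:
      # One retry with the same arguments (assumed BaseTool.run_async API).
      return await tool.run_async(args=tool_args, tool_context=tool_context)
    except Exception as retry_error:
      # Hand the model a structured failure instead of aborting the run.
      return {
          'status': 'error',
          'tool': tool.name,
          'message': str(retry_error),
      }
```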
src/google/adk/plugins/base_plugin.py

Lines changed: 51 additions & 0 deletions

```diff
@@ -265,6 +265,31 @@ async def after_model_callback(
     """
     pass
 
+  async def on_model_error_callback(
+      self,
+      *,
+      callback_context: CallbackContext,
+      llm_request: LlmRequest,
+      error: Exception,
+  ) -> Optional[LlmResponse]:
+    """Callback executed when a model call encounters an error.
+
+    This callback provides an opportunity to handle model errors gracefully,
+    potentially providing alternative responses or recovery mechanisms.
+
+    Args:
+      callback_context: The context for the current agent call.
+      llm_request: The request that was sent to the model when the error
+        occurred.
+      error: The exception that was raised during model execution.
+
+    Returns:
+      An optional LlmResponse. If an LlmResponse is returned, it will be used
+      instead of propagating the error. Returning `None` allows the original
+      error to be raised.
+    """
+    pass
+
   async def before_tool_callback(
       self,
       *,
@@ -315,3 +340,29 @@ async def after_tool_callback(
       result.
     """
     pass
+
+  async def on_tool_error_callback(
+      self,
+      *,
+      tool: BaseTool,
+      tool_args: dict[str, Any],
+      tool_context: ToolContext,
+      error: Exception,
+  ) -> Optional[dict]:
+    """Callback executed when a tool call encounters an error.
+
+    This callback provides an opportunity to handle tool errors gracefully,
+    potentially providing alternative responses or recovery mechanisms.
+
+    Args:
+      tool: The tool instance that encountered an error.
+      tool_args: The arguments that were passed to the tool.
+      tool_context: The context specific to the tool execution.
+      error: The exception that was raised during tool execution.
+
+    Returns:
+      An optional dictionary. If a dictionary is returned, it will be used as
+      the tool response instead of propagating the error. Returning `None`
+      allows the original error to be raised.
+    """
+    pass
```

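Both base implementations are no-ops returning `None`, so a plugin overrides only the hooks it needs and can observe errors without changing control flow. A sketch of that observer pattern follows; `ErrorLoggingPlugin` is a hypothetical name, and it assumes `CallbackContext` exposes an `agent_name` property as in ADK's read-only context.

```python
import logging
from typing import Any, Optional

from google.adk.agents.callback_context import CallbackContext
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse
from google.adk.plugins.base_plugin import BasePlugin
from google.adk.tools.base_tool import BaseTool
from google.adk.tools.tool_context import ToolContext

logger = logging.getLogger(__name__)


class ErrorLoggingPlugin(BasePlugin):
  """Hypothetical observer: records every error but lets it propagate."""

  def __init__(self):
    super().__init__(name='error_logging_plugin')

  async def on_model_error_callback(
      self,
      *,
      callback_context: CallbackContext,
      llm_request: LlmRequest,
      error: Exception,
  ) -> Optional[LlmResponse]:
    logger.error(
        'Model error for agent %s: %s', callback_context.agent_name, error
    )
    return None  # Defer: the error is re-raised downstream.

  async def on_tool_error_callback(
      self,
      *,
      tool: BaseTool,
      tool_args: dict[str, Any],
      tool_context: ToolContext,
      error: Exception,
  ) -> Optional[dict]:
    logger.error('Tool %s failed with args %s: %s', tool.name, tool_args, error)
    return None
```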
src/google/adk/plugins/plugin_manager.py

Lines changed: 34 additions & 0 deletions

```diff
@@ -48,6 +48,8 @@
     "after_tool_callback",
     "before_model_callback",
     "after_model_callback",
+    "on_tool_error_callback",
+    "on_model_error_callback",
 ]
 
 logger = logging.getLogger("google_adk." + __name__)
@@ -195,6 +197,21 @@ async def run_after_tool_callback(
         result=result,
     )
 
+  async def run_on_model_error_callback(
+      self,
+      *,
+      callback_context: CallbackContext,
+      llm_request: LlmRequest,
+      error: Exception,
+  ) -> Optional[LlmResponse]:
+    """Runs the `on_model_error_callback` for all plugins."""
+    return await self._run_callbacks(
+        "on_model_error_callback",
+        callback_context=callback_context,
+        llm_request=llm_request,
+        error=error,
+    )
+
   async def run_before_model_callback(
       self, *, callback_context: CallbackContext, llm_request: LlmRequest
   ) -> Optional[LlmResponse]:
@@ -215,6 +232,23 @@ async def run_after_model_callback(
         llm_response=llm_response,
     )
 
+  async def run_on_tool_error_callback(
+      self,
+      *,
+      tool: BaseTool,
+      tool_args: dict[str, Any],
+      tool_context: ToolContext,
+      error: Exception,
+  ) -> Optional[dict]:
+    """Runs the `on_tool_error_callback` for all plugins."""
+    return await self._run_callbacks(
+        "on_tool_error_callback",
+        tool=tool,
+        tool_args=tool_args,
+        tool_context=tool_context,
+        error=error,
+    )
+
   async def _run_callbacks(
       self, callback_name: PluginCallbackName, **kwargs: Any
   ) -> Optional[Any]:
```

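Both runners delegate to the unchanged `_run_callbacks`, which iterates plugins in registration order and stops at the first non-`None` return value, so an earlier plugin's answer shadows later ones. A runnable sketch of that contract, under stated assumptions: `PluginManager(plugins=[...])` matches the constructor in the ADK source, and the `None` stub arguments pass through only because the manager merely forwards keyword arguments to each plugin.

```python
import asyncio
from typing import Optional

from google.adk.plugins.base_plugin import BasePlugin
from google.adk.plugins.plugin_manager import PluginManager


class Recovering(BasePlugin):

  def __init__(self):
    super().__init__(name='recovering')

  async def on_tool_error_callback(
      self, *, tool, tool_args, tool_context, error
  ) -> Optional[dict]:
    return {'status': 'recovered'}


class NeverAsked(BasePlugin):

  def __init__(self):
    super().__init__(name='never_asked')

  async def on_tool_error_callback(
      self, *, tool, tool_args, tool_context, error
  ) -> Optional[dict]:
    raise AssertionError('unreachable: an earlier plugin already answered')


async def main():
  manager = PluginManager(plugins=[Recovering(), NeverAsked()])
  result = await manager.run_on_tool_error_callback(
      tool=None,  # Stubs suffice here: the manager only forwards kwargs.
      tool_args={},
      tool_context=None,
      error=RuntimeError('boom'),
  )
  print(result)  # {'status': 'recovered'}; NeverAsked is never consulted.


asyncio.run(main())
```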
tests/unittests/flows/llm_flows/test_plugin_model_callbacks.py

Lines changed: 61 additions & 0 deletions

```diff
@@ -20,19 +20,33 @@
 from google.adk.models import LlmResponse
 from google.adk.plugins.base_plugin import BasePlugin
 from google.genai import types
+from google.genai.errors import ClientError
 import pytest
 
 from ... import testing_utils
 
+mock_error = ClientError(
+    code=429,
+    response_json={
+        'error': {
+            'code': 429,
+            'message': 'Quota exceeded.',
+            'status': 'RESOURCE_EXHAUSTED',
+        }
+    },
+)
+
 
 class MockPlugin(BasePlugin):
   before_model_text = 'before_model_text from MockPlugin'
   after_model_text = 'after_model_text from MockPlugin'
+  on_model_error_text = 'on_model_error_text from MockPlugin'
 
   def __init__(self, name='mock_plugin'):
     self.name = name
     self.enable_before_model_callback = False
     self.enable_after_model_callback = False
+    self.enable_on_model_error_callback = False
     self.before_model_response = LlmResponse(
         content=testing_utils.ModelContent(
             [types.Part.from_text(text=self.before_model_text)]
@@ -43,6 +57,11 @@ def __init__(self, name='mock_plugin'):
             [types.Part.from_text(text=self.after_model_text)]
         )
     )
+    self.on_model_error_response = LlmResponse(
+        content=testing_utils.ModelContent(
+            [types.Part.from_text(text=self.on_model_error_text)]
+        )
+    )
 
   async def before_model_callback(
       self, *, callback_context: CallbackContext, llm_request: LlmRequest
@@ -58,6 +77,17 @@ async def after_model_callback(
       return None
     return self.after_model_response
 
+  async def on_model_error_callback(
+      self,
+      *,
+      callback_context: CallbackContext,
+      llm_request: LlmRequest,
+      error: Exception,
+  ) -> Optional[LlmResponse]:
+    if not self.enable_on_model_error_callback:
+      return None
+    return self.on_model_error_response
+
 
 CANONICAL_MODEL_CALLBACK_CONTENT = 'canonical_model_callback_content'
 
@@ -124,5 +154,36 @@ def test_before_model_callback_fallback_model(mock_plugin):
   ]
 
 
+def test_on_model_error_callback_with_plugin(mock_plugin):
+  """Tests that the model error is handled by the plugin."""
+  mock_model = testing_utils.MockModel.create(error=mock_error, responses=[])
+  mock_plugin.enable_on_model_error_callback = True
+  agent = Agent(
+      name='root_agent',
+      model=mock_model,
+  )
+
+  runner = testing_utils.InMemoryRunner(agent, plugins=[mock_plugin])
+
+  assert testing_utils.simplify_events(runner.run('test')) == [
+      ('root_agent', mock_plugin.on_model_error_text),
+  ]
+
+
+def test_on_model_error_callback_fallback_to_runner(mock_plugin):
+  """Tests that the model error is not handled and falls back to raise from runner."""
+  mock_model = testing_utils.MockModel.create(error=mock_error, responses=[])
+  mock_plugin.enable_on_model_error_callback = False
+  agent = Agent(
+      name='root_agent',
+      model=mock_model,
+  )
+
+  try:
+    testing_utils.InMemoryRunner(agent, plugins=[mock_plugin])
+  except Exception as e:
+    assert e == mock_error
+
+
 if __name__ == '__main__':
   pytest.main([__file__])
```
