feat: add new callbacks to handle tool and model errors #1981


Open · wants to merge 1 commit into main
53 changes: 51 additions & 2 deletions src/google/adk/flows/llm_flows/base_llm_flow.py
@@ -521,7 +521,13 @@ async def _call_llm_async(
with tracer.start_as_current_span('call_llm'):
if invocation_context.run_config.support_cfc:
invocation_context.live_request_queue = LiveRequestQueue()
responses_generator = self.run_live(invocation_context)
async for llm_response in self._run_and_handle_error(
responses_generator,
invocation_context,
llm_request,
model_response_event,
):
# Runs after_model_callback if it exists.
if altered_llm_response := await self._handle_after_model_callback(
invocation_context, llm_response, model_response_event
@@ -540,10 +546,16 @@ async def _call_llm_async(
# the counter beyond the max set value, then the execution is stopped
# right here, and an exception is thrown.
invocation_context.increment_llm_call_count()
responses_generator = llm.generate_content_async(
llm_request,
stream=invocation_context.run_config.streaming_mode
== StreamingMode.SSE,
)
async for llm_response in self._run_and_handle_error(
responses_generator,
invocation_context,
llm_request,
model_response_event,
):
trace_call_llm(
invocation_context,
@@ -660,6 +672,43 @@ def _finalize_model_response_event(

return model_response_event

async def _run_and_handle_error(
self,
response_generator: AsyncGenerator[LlmResponse, None],
invocation_context: InvocationContext,
llm_request: LlmRequest,
model_response_event: Event,
) -> AsyncGenerator[LlmResponse, None]:
"""Runs the response generator and processes the error with plugins.

Args:
response_generator: The response generator to run.
invocation_context: The invocation context.
llm_request: The LLM request.
model_response_event: The model response event.

Yields:
  LlmResponse objects from the generator, or an error-handling response
  supplied by a plugin when it handles the raised error.
"""
try:
async for response in response_generator:
yield response
except Exception as model_error:
callback_context = CallbackContext(
invocation_context, event_actions=model_response_event.actions
)
error_response = (
await invocation_context.plugin_manager.run_on_model_error_callback(
callback_context=callback_context,
llm_request=llm_request,
error=model_error,
)
)
if error_response is not None:
yield error_response
else:
raise model_error

def __get_llm(self, invocation_context: InvocationContext) -> BaseLlm:
from ...agents.llm_agent import LlmAgent

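The new `_run_and_handle_error` wrapper above is the core of the model-side change: it drains the underlying response stream, intercepts any exception exactly once, and gives plugins a chance to substitute a recovery response before the error escapes. A minimal standalone sketch of the same control flow, using illustrative names (`stream`, `fallback_for`) rather than the ADK API:

from typing import AsyncGenerator, Awaitable, Callable, Optional


async def run_with_recovery(
    stream: AsyncGenerator[str, None],
    fallback_for: Callable[[Exception], Awaitable[Optional[str]]],
) -> AsyncGenerator[str, None]:
  """Yields from `stream`; on failure, yields a fallback or re-raises."""
  try:
    async for item in stream:
      yield item
  except Exception as error:
    fallback = await fallback_for(error)
    if fallback is not None:
      # A handler supplied a response: emit it in place of the failure.
      yield fallback
    else:
      # Nothing handled the error: propagate the original exception.
      raise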
18 changes: 15 additions & 3 deletions src/google/adk/flows/llm_flows/functions.py
@@ -176,9 +176,21 @@ async def handle_function_calls_async(

# Step 3: Otherwise, proceed calling the tool normally.
if function_response is None:
try:
function_response = await __call_tool_async(
tool, args=function_args, tool_context=tool_context
)
except Exception as tool_error:
error_response = await invocation_context.plugin_manager.run_on_tool_error_callback(
tool=tool,
tool_args=function_args,
tool_context=tool_context,
error=tool_error,
)
if error_response is not None:
function_response = error_response
else:
raise tool_error

# Step 4: Check if plugin after_tool_callback overrides the function
# response.
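The change to `handle_function_calls_async` applies the same recover-or-re-raise shape around a single awaited tool call instead of a stream. A hedged sketch of that shape in isolation (the helper names are illustrative, not part of the ADK API):

from typing import Awaitable, Callable, Optional


async def call_with_recovery(
    call: Callable[[], Awaitable[dict]],
    recover: Callable[[Exception], Awaitable[Optional[dict]]],
) -> dict:
  """Awaits `call()`; on error, returns a plugin-supplied dict or re-raises."""
  try:
    return await call()
  except Exception as tool_error:
    recovered = await recover(tool_error)
    if recovered is not None:
      return recovered  # The plugin's dict becomes the tool response.
    raise  # Unhandled: the original tool error propagates.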
51 changes: 51 additions & 0 deletions src/google/adk/plugins/base_plugin.py
@@ -265,6 +265,31 @@ async def after_model_callback(
"""
pass

async def on_model_error_callback(
self,
*,
callback_context: CallbackContext,
llm_request: LlmRequest,
error: Exception,
) -> Optional[LlmResponse]:
"""Callback executed when a model call encounters an error.

This callback provides an opportunity to handle model errors gracefully,
potentially providing alternative responses or recovery mechanisms.

Args:
callback_context: The context for the current agent call.
llm_request: The request that was sent to the model when the error
occurred.
error: The exception that was raised during model execution.

Returns:
An optional LlmResponse. If an LlmResponse is returned, it will be used
instead of propagating the error. Returning `None` allows the original
error to be raised.
"""
pass

async def before_tool_callback(
self,
*,
@@ -315,3 +340,29 @@ async def after_tool_callback(
result.
"""
pass

async def on_tool_error_callback(
self,
*,
tool: BaseTool,
tool_args: dict[str, Any],
tool_context: ToolContext,
error: Exception,
) -> Optional[dict]:
"""Callback executed when a tool call encounters an error.

This callback provides an opportunity to handle tool errors gracefully,
potentially providing alternative responses or recovery mechanisms.

Args:
tool: The tool instance that encountered an error.
tool_args: The arguments that were passed to the tool.
tool_context: The context specific to the tool execution.
error: The exception that was raised during tool execution.

Returns:
An optional dictionary. If a dictionary is returned, it will be used as
the tool response instead of propagating the error. Returning `None`
allows the original error to be raised.
"""
pass
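
Together, the two hooks let a single plugin own error handling for both model and tool failures. Below is a sketch of such a plugin against the signatures defined above; the class name, fallback texts, and exact import paths are assumptions based on the modules this diff touches:

from typing import Any, Optional

from google.adk.agents.callback_context import CallbackContext
from google.adk.models import LlmRequest, LlmResponse
from google.adk.plugins.base_plugin import BasePlugin
from google.adk.tools.base_tool import BaseTool
from google.adk.tools.tool_context import ToolContext
from google.genai import types


class GracefulErrorPlugin(BasePlugin):
  """Converts model and tool failures into user-visible fallback responses."""

  async def on_model_error_callback(
      self,
      *,
      callback_context: CallbackContext,
      llm_request: LlmRequest,
      error: Exception,
  ) -> Optional[LlmResponse]:
    # Returning an LlmResponse suppresses the error; returning None re-raises.
    return LlmResponse(
        content=types.Content(
            role='model',
            parts=[types.Part.from_text(text=f'Model call failed: {error}')],
        )
    )

  async def on_tool_error_callback(
      self,
      *,
      tool: BaseTool,
      tool_args: dict[str, Any],
      tool_context: ToolContext,
      error: Exception,
  ) -> Optional[dict]:
    # Returning a dict substitutes it for the tool's response.
    return {'error': f'{tool.name} failed: {error}'}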
34 changes: 34 additions & 0 deletions src/google/adk/plugins/plugin_manager.py
@@ -48,6 +48,8 @@
"after_tool_callback",
"before_model_callback",
"after_model_callback",
"on_tool_error_callback",
"on_model_error_callback",
]

logger = logging.getLogger("google_adk." + __name__)
@@ -195,6 +197,21 @@ async def run_after_tool_callback(
result=result,
)

async def run_on_model_error_callback(
self,
*,
callback_context: CallbackContext,
llm_request: LlmRequest,
error: Exception,
) -> Optional[LlmResponse]:
"""Runs the `on_model_error_callback` for all plugins."""
return await self._run_callbacks(
"on_model_error_callback",
callback_context=callback_context,
llm_request=llm_request,
error=error,
)

async def run_before_model_callback(
self, *, callback_context: CallbackContext, llm_request: LlmRequest
) -> Optional[LlmResponse]:
@@ -215,6 +232,23 @@ async def run_after_model_callback(
llm_response=llm_response,
)

async def run_on_tool_error_callback(
self,
*,
tool: BaseTool,
tool_args: dict[str, Any],
tool_context: ToolContext,
error: Exception,
) -> Optional[dict]:
"""Runs the `on_tool_error_callback` for all plugins."""
return await self._run_callbacks(
"on_tool_error_callback",
tool=tool,
tool_args=tool_args,
tool_context=tool_context,
error=error,
)

async def _run_callbacks(
self, callback_name: PluginCallbackName, **kwargs: Any
) -> Optional[Any]:
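Both new run_* wrappers delegate to `_run_callbacks`, whose body falls outside this diff. Assuming it follows the first-non-None-wins convention that the before/after callbacks already rely on, its behavior would look roughly like this sketch:

  async def _run_callbacks(
      self, callback_name: PluginCallbackName, **kwargs: Any
  ) -> Optional[Any]:
    # Assumed behavior: plugins run in registration order, and the first
    # non-None return value short-circuits the chain.
    for plugin in self.plugins:
      result = await getattr(plugin, callback_name)(**kwargs)
      if result is not None:
        return result
    return None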
61 changes: 61 additions & 0 deletions tests/unittests/flows/llm_flows/test_plugin_model_callbacks.py
@@ -20,19 +20,33 @@
from google.adk.models import LlmResponse
from google.adk.plugins.base_plugin import BasePlugin
from google.genai import types
from google.genai.errors import ClientError
import pytest

from ... import testing_utils

mock_error = ClientError(
code=429,
response_json={
'error': {
'code': 429,
'message': 'Quota exceeded.',
'status': 'RESOURCE_EXHAUSTED',
}
},
)


class MockPlugin(BasePlugin):
before_model_text = 'before_model_text from MockPlugin'
after_model_text = 'after_model_text from MockPlugin'
on_model_error_text = 'on_model_error_text from MockPlugin'

def __init__(self, name='mock_plugin'):
self.name = name
self.enable_before_model_callback = False
self.enable_after_model_callback = False
self.enable_on_model_error_callback = False
self.before_model_response = LlmResponse(
content=testing_utils.ModelContent(
[types.Part.from_text(text=self.before_model_text)]
@@ -43,6 +57,11 @@ def __init__(self, name='mock_plugin'):
[types.Part.from_text(text=self.after_model_text)]
)
)
self.on_model_error_response = LlmResponse(
content=testing_utils.ModelContent(
[types.Part.from_text(text=self.on_model_error_text)]
)
)

async def before_model_callback(
self, *, callback_context: CallbackContext, llm_request: LlmRequest
@@ -58,6 +77,17 @@ async def after_model_callback(
return None
return self.after_model_response

async def on_model_error_callback(
self,
*,
callback_context: CallbackContext,
llm_request: LlmRequest,
error: Exception,
) -> Optional[LlmResponse]:
if not self.enable_on_model_error_callback:
return None
return self.on_model_error_response


CANONICAL_MODEL_CALLBACK_CONTENT = 'canonical_model_callback_content'

@@ -124,5 +154,36 @@ def test_before_model_callback_fallback_model(mock_plugin):
]


def test_on_model_error_callback_with_plugin(mock_plugin):
"""Tests that the model error is handled by the plugin."""
mock_model = testing_utils.MockModel.create(error=mock_error, responses=[])
mock_plugin.enable_on_model_error_callback = True
agent = Agent(
name='root_agent',
model=mock_model,
)

runner = testing_utils.InMemoryRunner(agent, plugins=[mock_plugin])

assert testing_utils.simplify_events(runner.run('test')) == [
('root_agent', mock_plugin.on_model_error_text),
]


def test_on_model_error_callback_fallback_to_runner(mock_plugin):
  """Tests that an unhandled model error propagates out of the runner."""
  mock_model = testing_utils.MockModel.create(error=mock_error, responses=[])
  mock_plugin.enable_on_model_error_callback = False
  agent = Agent(
      name='root_agent',
      model=mock_model,
  )

  runner = testing_utils.InMemoryRunner(agent, plugins=[mock_plugin])

  with pytest.raises(ClientError) as exc_info:
    runner.run('test')
  assert exc_info.value == mock_error


if __name__ == '__main__':
pytest.main([__file__])