From 90360f455ae30e91ba98f3fdc056f3260ef008ef Mon Sep 17 00:00:00 2001
From: Che Liu
Date: Mon, 14 Jul 2025 14:54:09 -0700
Subject: [PATCH] feat: add new callbacks to handle tool and model errors

This CL adds new callbacks to the plugin system:

- `on_tool_error_callback`
- `on_model_error_callback`

This allows users to create plugins that can handle tool and model errors.

PiperOrigin-RevId: 783052800
---
 .../adk/flows/llm_flows/base_llm_flow.py      | 53 +++++++++++++++-
 src/google/adk/flows/llm_flows/functions.py   | 18 +++++-
 src/google/adk/plugins/base_plugin.py         | 51 ++++++++++++++++
 src/google/adk/plugins/plugin_manager.py      | 34 +++++++++++
 .../llm_flows/test_plugin_model_callbacks.py  | 61 +++++++++++++++++++
 .../llm_flows/test_plugin_tool_callbacks.py   | 61 +++++++++++++++++++
 tests/unittests/plugins/test_base_plugin.py   | 41 +++++++++++++
 .../unittests/plugins/test_plugin_manager.py  | 19 ++++++
 tests/unittests/testing_utils.py              |  6 ++
 9 files changed, 339 insertions(+), 5 deletions(-)

diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py
index b87a083ac..f10a4ed34 100644
--- a/src/google/adk/flows/llm_flows/base_llm_flow.py
+++ b/src/google/adk/flows/llm_flows/base_llm_flow.py
@@ -521,7 +521,13 @@ async def _call_llm_async(
     with tracer.start_as_current_span('call_llm'):
       if invocation_context.run_config.support_cfc:
         invocation_context.live_request_queue = LiveRequestQueue()
-        async for llm_response in self.run_live(invocation_context):
+        responses_generator = self.run_live(invocation_context)
+        async for llm_response in self._run_and_handle_error(
+            responses_generator,
+            invocation_context,
+            llm_request,
+            model_response_event,
+        ):
           # Runs after_model_callback if it exists.
           if altered_llm_response := await self._handle_after_model_callback(
               invocation_context, llm_response, model_response_event
@@ -540,10 +546,16 @@
       # the counter beyond the max set value, then the execution is stopped
       # right here, and exception is thrown.
       invocation_context.increment_llm_call_count()
-      async for llm_response in llm.generate_content_async(
+      responses_generator = llm.generate_content_async(
           llm_request,
           stream=invocation_context.run_config.streaming_mode
           == StreamingMode.SSE,
+      )
+      async for llm_response in self._run_and_handle_error(
+          responses_generator,
+          invocation_context,
+          llm_request,
+          model_response_event,
       ):
         trace_call_llm(
             invocation_context,
@@ -660,6 +672,43 @@ def _finalize_model_response_event(
     return model_response_event
 
+  async def _run_and_handle_error(
+      self,
+      response_generator: AsyncGenerator[LlmResponse, None],
+      invocation_context: InvocationContext,
+      llm_request: LlmRequest,
+      model_response_event: Event,
+  ) -> AsyncGenerator[LlmResponse, None]:
+    """Runs the response generator and processes the error with plugins.
+
+    Args:
+      response_generator: The response generator to run.
+      invocation_context: The invocation context.
+      llm_request: The LLM request.
+      model_response_event: The model response event.
+
+    Yields:
+      A generator of LlmResponse.
+ """ + try: + async for response in response_generator: + yield response + except Exception as model_error: + callback_context = CallbackContext( + invocation_context, event_actions=model_response_event.actions + ) + error_response = ( + await invocation_context.plugin_manager.run_on_model_error_callback( + callback_context=callback_context, + llm_request=llm_request, + error=model_error, + ) + ) + if error_response is not None: + yield error_response + else: + raise model_error + def __get_llm(self, invocation_context: InvocationContext) -> BaseLlm: from ...agents.llm_agent import LlmAgent diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index 379e11ef7..aaa08d91a 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -176,9 +176,21 @@ async def handle_function_calls_async( # Step 3: Otherwise, proceed calling the tool normally. if function_response is None: - function_response = await __call_tool_async( - tool, args=function_args, tool_context=tool_context - ) + try: + function_response = await __call_tool_async( + tool, args=function_args, tool_context=tool_context + ) + except Exception as tool_error: + error_response = await invocation_context.plugin_manager.run_on_tool_error_callback( + tool=tool, + tool_args=function_args, + tool_context=tool_context, + error=tool_error, + ) + if error_response is not None: + function_response = error_response + else: + raise tool_error # Step 4: Check if plugin after_tool_callback overrides the function # response. diff --git a/src/google/adk/plugins/base_plugin.py b/src/google/adk/plugins/base_plugin.py index 729e3519a..08e281dbb 100644 --- a/src/google/adk/plugins/base_plugin.py +++ b/src/google/adk/plugins/base_plugin.py @@ -265,6 +265,31 @@ async def after_model_callback( """ pass + async def on_model_error_callback( + self, + *, + callback_context: CallbackContext, + llm_request: LlmRequest, + error: Exception, + ) -> Optional[LlmResponse]: + """Callback executed when a model call encounters an error. + + This callback provides an opportunity to handle model errors gracefully, + potentially providing alternative responses or recovery mechanisms. + + Args: + callback_context: The context for the current agent call. + llm_request: The request that was sent to the model when the error + occurred. + error: The exception that was raised during model execution. + + Returns: + An optional LlmResponse. If an LlmResponse is returned, it will be used + instead of propagating the error. Returning `None` allows the original + error to be raised. + """ + pass + async def before_tool_callback( self, *, @@ -315,3 +340,29 @@ async def after_tool_callback( result. """ pass + + async def on_tool_error_callback( + self, + *, + tool: BaseTool, + tool_args: dict[str, Any], + tool_context: ToolContext, + error: Exception, + ) -> Optional[dict]: + """Callback executed when a tool call encounters an error. + + This callback provides an opportunity to handle tool errors gracefully, + potentially providing alternative responses or recovery mechanisms. + + Args: + tool: The tool instance that encountered an error. + tool_args: The arguments that were passed to the tool. + tool_context: The context specific to the tool execution. + error: The exception that was raised during tool execution. + + Returns: + An optional dictionary. If a dictionary is returned, it will be used as + the tool response instead of propagating the error. 
Returning `None` + allows the original error to be raised. + """ + pass diff --git a/src/google/adk/plugins/plugin_manager.py b/src/google/adk/plugins/plugin_manager.py index 3680c3515..217dbb8be 100644 --- a/src/google/adk/plugins/plugin_manager.py +++ b/src/google/adk/plugins/plugin_manager.py @@ -48,6 +48,8 @@ "after_tool_callback", "before_model_callback", "after_model_callback", + "on_tool_error_callback", + "on_model_error_callback", ] logger = logging.getLogger("google_adk." + __name__) @@ -195,6 +197,21 @@ async def run_after_tool_callback( result=result, ) + async def run_on_model_error_callback( + self, + *, + callback_context: CallbackContext, + llm_request: LlmRequest, + error: Exception, + ) -> Optional[LlmResponse]: + """Runs the `on_model_error_callback` for all plugins.""" + return await self._run_callbacks( + "on_model_error_callback", + callback_context=callback_context, + llm_request=llm_request, + error=error, + ) + async def run_before_model_callback( self, *, callback_context: CallbackContext, llm_request: LlmRequest ) -> Optional[LlmResponse]: @@ -215,6 +232,23 @@ async def run_after_model_callback( llm_response=llm_response, ) + async def run_on_tool_error_callback( + self, + *, + tool: BaseTool, + tool_args: dict[str, Any], + tool_context: ToolContext, + error: Exception, + ) -> Optional[dict]: + """Runs the `on_tool_error_callback` for all plugins.""" + return await self._run_callbacks( + "on_tool_error_callback", + tool=tool, + tool_args=tool_args, + tool_context=tool_context, + error=error, + ) + async def _run_callbacks( self, callback_name: PluginCallbackName, **kwargs: Any ) -> Optional[Any]: diff --git a/tests/unittests/flows/llm_flows/test_plugin_model_callbacks.py b/tests/unittests/flows/llm_flows/test_plugin_model_callbacks.py index b9b2dec35..62a4c234d 100644 --- a/tests/unittests/flows/llm_flows/test_plugin_model_callbacks.py +++ b/tests/unittests/flows/llm_flows/test_plugin_model_callbacks.py @@ -20,19 +20,33 @@ from google.adk.models import LlmResponse from google.adk.plugins.base_plugin import BasePlugin from google.genai import types +from google.genai.errors import ClientError import pytest from ... 
import testing_utils +mock_error = ClientError( + code=429, + response_json={ + 'error': { + 'code': 429, + 'message': 'Quota exceeded.', + 'status': 'RESOURCE_EXHAUSTED', + } + }, +) + class MockPlugin(BasePlugin): before_model_text = 'before_model_text from MockPlugin' after_model_text = 'after_model_text from MockPlugin' + on_model_error_text = 'on_model_error_text from MockPlugin' def __init__(self, name='mock_plugin'): self.name = name self.enable_before_model_callback = False self.enable_after_model_callback = False + self.enable_on_model_error_callback = False self.before_model_response = LlmResponse( content=testing_utils.ModelContent( [types.Part.from_text(text=self.before_model_text)] @@ -43,6 +57,11 @@ def __init__(self, name='mock_plugin'): [types.Part.from_text(text=self.after_model_text)] ) ) + self.on_model_error_response = LlmResponse( + content=testing_utils.ModelContent( + [types.Part.from_text(text=self.on_model_error_text)] + ) + ) async def before_model_callback( self, *, callback_context: CallbackContext, llm_request: LlmRequest @@ -58,6 +77,17 @@ async def after_model_callback( return None return self.after_model_response + async def on_model_error_callback( + self, + *, + callback_context: CallbackContext, + llm_request: LlmRequest, + error: Exception, + ) -> Optional[LlmResponse]: + if not self.enable_on_model_error_callback: + return None + return self.on_model_error_response + CANONICAL_MODEL_CALLBACK_CONTENT = 'canonical_model_callback_content' @@ -124,5 +154,36 @@ def test_before_model_callback_fallback_model(mock_plugin): ] +def test_on_model_error_callback_with_plugin(mock_plugin): + """Tests that the model error is handled by the plugin.""" + mock_model = testing_utils.MockModel.create(error=mock_error, responses=[]) + mock_plugin.enable_on_model_error_callback = True + agent = Agent( + name='root_agent', + model=mock_model, + ) + + runner = testing_utils.InMemoryRunner(agent, plugins=[mock_plugin]) + + assert testing_utils.simplify_events(runner.run('test')) == [ + ('root_agent', mock_plugin.on_model_error_text), + ] + + +def test_on_model_error_callback_fallback_to_runner(mock_plugin): + """Tests that the model error is not handled and falls back to raise from runner.""" + mock_model = testing_utils.MockModel.create(error=mock_error, responses=[]) + mock_plugin.enable_on_model_error_callback = False + agent = Agent( + name='root_agent', + model=mock_model, + ) + + try: + testing_utils.InMemoryRunner(agent, plugins=[mock_plugin]) + except Exception as e: + assert e == mock_error + + if __name__ == '__main__': pytest.main([__file__]) diff --git a/tests/unittests/flows/llm_flows/test_plugin_tool_callbacks.py b/tests/unittests/flows/llm_flows/test_plugin_tool_callbacks.py index a79e562a5..fac4169b3 100644 --- a/tests/unittests/flows/llm_flows/test_plugin_tool_callbacks.py +++ b/tests/unittests/flows/llm_flows/test_plugin_tool_callbacks.py @@ -24,19 +24,35 @@ from google.adk.tools.function_tool import FunctionTool from google.adk.tools.tool_context import ToolContext from google.genai import types +from google.genai.errors import ClientError import pytest from ... 
import testing_utils +mock_error = ClientError( + code=429, + response_json={ + "error": { + "code": 429, + "message": "Quota exceeded.", + "status": "RESOURCE_EXHAUSTED", + } + }, +) + class MockPlugin(BasePlugin): before_tool_response = {"MockPlugin": "before_tool_response from MockPlugin"} after_tool_response = {"MockPlugin": "after_tool_response from MockPlugin"} + on_tool_error_response = { + "MockPlugin": "on_tool_error_response from MockPlugin" + } def __init__(self, name="mock_plugin"): self.name = name self.enable_before_tool_callback = False self.enable_after_tool_callback = False + self.enable_on_tool_error_callback = False async def before_tool_callback( self, @@ -61,6 +77,18 @@ async def after_tool_callback( return None return self.after_tool_response + async def on_tool_error_callback( + self, + *, + tool: BaseTool, + tool_args: dict[str, Any], + tool_context: ToolContext, + error: Exception, + ) -> Optional[dict]: + if not self.enable_on_tool_error_callback: + return None + return self.on_tool_error_response + @pytest.fixture def mock_tool(): @@ -70,6 +98,14 @@ def simple_fn(**kwargs) -> Dict[str, Any]: return FunctionTool(simple_fn) +@pytest.fixture +def mock_error_tool(): + def raise_error_fn(**kwargs) -> Dict[str, Any]: + raise mock_error + + return FunctionTool(raise_error_fn) + + @pytest.fixture def mock_plugin(): return MockPlugin() @@ -124,5 +160,30 @@ async def test_async_after_tool_callback(mock_tool, mock_plugin): assert part.function_response.response == mock_plugin.after_tool_response +@pytest.mark.asyncio +async def test_async_on_tool_error_use_plugin_response( + mock_error_tool, mock_plugin +): + mock_plugin.enable_on_tool_error_callback = True + + result_event = await invoke_tool_with_plugin(mock_error_tool, mock_plugin) + + assert result_event is not None + part = result_event.content.parts[0] + assert part.function_response.response == mock_plugin.on_tool_error_response + + +@pytest.mark.asyncio +async def test_async_on_tool_error_fallback_to_runner( + mock_error_tool, mock_plugin +): + mock_plugin.enable_on_tool_error_callback = False + + try: + await invoke_tool_with_plugin(mock_error_tool, mock_plugin) + except Exception as e: + assert e == mock_error + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/unittests/plugins/test_base_plugin.py b/tests/unittests/plugins/test_base_plugin.py index 04b1c3e94..3a2de9430 100644 --- a/tests/unittests/plugins/test_base_plugin.py +++ b/tests/unittests/plugins/test_base_plugin.py @@ -67,12 +67,18 @@ async def before_tool_callback(self, **kwargs) -> str: async def after_tool_callback(self, **kwargs) -> str: return "overridden_after_tool" + async def on_tool_error_callback(self, **kwargs) -> str: + return "overridden_on_tool_error" + async def before_model_callback(self, **kwargs) -> str: return "overridden_before_model" async def after_model_callback(self, **kwargs) -> str: return "overridden_after_model" + async def on_model_error_callback(self, **kwargs) -> str: + return "overridden_on_model_error" + def test_base_plugin_initialization(): """Tests that a plugin is initialized with the correct name.""" @@ -137,6 +143,15 @@ async def test_base_plugin_default_callbacks_return_none(): ) is None ) + assert ( + await plugin.on_tool_error_callback( + tool=mock_context, + tool_args={}, + tool_context=mock_context, + error=Exception(), + ) + is None + ) assert ( await plugin.before_model_callback( callback_context=mock_context, llm_request=mock_context @@ -149,6 +164,14 @@ async def 
test_base_plugin_default_callbacks_return_none(): ) is None ) + assert ( + await plugin.on_model_error_callback( + callback_context=mock_context, + llm_request=mock_context, + error=Exception(), + ) + is None + ) @pytest.mark.asyncio @@ -170,6 +193,7 @@ async def test_base_plugin_all_callbacks_can_be_overridden(): mock_llm_request = Mock(spec=LlmRequest) mock_llm_response = Mock(spec=LlmResponse) mock_event = Mock(spec=Event) + mock_error = Mock(spec=Exception) # Call each method and assert it returns the unique string from the override. # This proves that the subclass's method was executed. @@ -237,3 +261,20 @@ async def test_base_plugin_all_callbacks_can_be_overridden(): ) == "overridden_after_tool" ) + assert ( + await plugin.on_tool_error_callback( + tool=mock_tool, + tool_args={}, + tool_context=mock_tool_context, + error=mock_error, + ) + == "overridden_on_tool_error" + ) + assert ( + await plugin.on_model_error_callback( + callback_context=mock_callback_context, + llm_request=mock_llm_request, + error=mock_error, + ) + == "overridden_on_model_error" + ) diff --git a/tests/unittests/plugins/test_plugin_manager.py b/tests/unittests/plugins/test_plugin_manager.py index 76d32a618..e3edfa83e 100644 --- a/tests/unittests/plugins/test_plugin_manager.py +++ b/tests/unittests/plugins/test_plugin_manager.py @@ -77,12 +77,18 @@ async def before_tool_callback(self, **kwargs): async def after_tool_callback(self, **kwargs): return await self._handle_callback("after_tool_callback") + async def on_tool_error_callback(self, **kwargs): + return await self._handle_callback("on_tool_error_callback") + async def before_model_callback(self, **kwargs): return await self._handle_callback("before_model_callback") async def after_model_callback(self, **kwargs): return await self._handle_callback("after_model_callback") + async def on_model_error_callback(self, **kwargs): + return await self._handle_callback("on_model_error_callback") + @pytest.fixture def service() -> PluginManager: @@ -227,12 +233,23 @@ async def test_all_callbacks_are_supported( await service.run_after_tool_callback( tool=mock_context, tool_args={}, tool_context=mock_context, result={} ) + await service.run_on_tool_error_callback( + tool=mock_context, + tool_args={}, + tool_context=mock_context, + error=mock_context, + ) await service.run_before_model_callback( callback_context=mock_context, llm_request=mock_context ) await service.run_after_model_callback( callback_context=mock_context, llm_response=mock_context ) + await service.run_on_model_error_callback( + callback_context=mock_context, + llm_request=mock_context, + error=mock_context, + ) # Verify all callbacks were logged expected_callbacks = [ @@ -244,7 +261,9 @@ async def test_all_callbacks_are_supported( "after_agent_callback", "before_tool_callback", "after_tool_callback", + "on_tool_error_callback", "before_model_callback", "after_model_callback", + "on_model_error_callback", ] assert set(plugin1.call_log) == set(expected_callbacks) diff --git a/tests/unittests/testing_utils.py b/tests/unittests/testing_utils.py index 810a6c448..9ddf92cd1 100644 --- a/tests/unittests/testing_utils.py +++ b/tests/unittests/testing_utils.py @@ -247,6 +247,7 @@ class MockModel(BaseLlm): requests: list[LlmRequest] = [] responses: list[LlmResponse] + error: Union[Exception, None] = None response_index: int = -1 @classmethod @@ -255,7 +256,10 @@ def create( responses: Union[ list[types.Part], list[LlmResponse], list[str], list[list[types.Part]] ], + error: Union[Exception, None] = None, ): + if error 
and not responses: + return cls(responses=[], error=error) if not responses: return cls(responses=[]) elif isinstance(responses[0], LlmResponse): @@ -285,6 +289,8 @@ def supported_models() -> list[str]: def generate_content( self, llm_request: LlmRequest, stream: bool = False ) -> Generator[LlmResponse, None, None]: + if self.error: + raise self.error # Increasement of the index has to happen before the yield. self.response_index += 1 self.requests.append(llm_request)
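
For readers who want to try the new hooks outside the unit tests, below is a minimal sketch of a plugin that overrides both error callbacks. It is illustrative only: the `ErrorHandlingPlugin` class name and the fallback messages are made up for this example, while the callback signatures and the `LlmResponse` return type follow the definitions added to `base_plugin.py` in this patch, and the import paths are assumed to match the current ADK layout.

```python
from typing import Any, Optional

from google.adk.agents.callback_context import CallbackContext
from google.adk.models import LlmRequest, LlmResponse
from google.adk.plugins.base_plugin import BasePlugin
from google.adk.tools.base_tool import BaseTool
from google.adk.tools.tool_context import ToolContext
from google.genai import types


class ErrorHandlingPlugin(BasePlugin):
  """Example plugin that turns model/tool errors into fallback responses."""

  def __init__(self, name: str = 'error_handling_plugin'):
    super().__init__(name=name)

  async def on_model_error_callback(
      self,
      *,
      callback_context: CallbackContext,
      llm_request: LlmRequest,
      error: Exception,
  ) -> Optional[LlmResponse]:
    # Returning an LlmResponse suppresses the error; returning None lets it
    # propagate to the runner as before.
    return LlmResponse(
        content=types.Content(
            role='model',
            parts=[
                types.Part.from_text(
                    text=(
                        'The model call failed'
                        f' ({type(error).__name__}); please try again later.'
                    )
                )
            ],
        )
    )

  async def on_tool_error_callback(
      self,
      *,
      tool: BaseTool,
      tool_args: dict[str, Any],
      tool_context: ToolContext,
      error: Exception,
  ) -> Optional[dict]:
    # Returning a dict is used as the tool response; returning None re-raises
    # the original exception.
    return {'error': f'Tool {tool.name} failed: {error}'}
```

As in the new unit tests, such a plugin would be registered through the runner's `plugins` argument (e.g. `InMemoryRunner(agent, plugins=[ErrorHandlingPlugin()])`); any callback that returns `None` leaves the original exception to propagate unchanged.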