diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py
index b87a083ac..feaa597f7 100644
--- a/src/google/adk/flows/llm_flows/base_llm_flow.py
+++ b/src/google/adk/flows/llm_flows/base_llm_flow.py
@@ -17,9 +17,10 @@
 from abc import ABC
 import asyncio
 import datetime
 import inspect
 import logging
 from typing import AsyncGenerator
+from typing import Callable
 from typing import cast
 from typing import Optional
 from typing import TYPE_CHECKING
@@ -36,6 +37,7 @@
 from ...agents.run_config import StreamingMode
 from ...agents.transcription_entry import TranscriptionEntry
 from ...events.event import Event
+from ...models.base_llm import ModelErrorStrategy
 from ...models.base_llm_connection import BaseLlmConnection
 from ...models.llm_request import LlmRequest
 from ...models.llm_response import LlmResponse
@@ -521,7 +523,13 @@ async def _call_llm_async(
     with tracer.start_as_current_span('call_llm'):
       if invocation_context.run_config.support_cfc:
         invocation_context.live_request_queue = LiveRequestQueue()
-        async for llm_response in self.run_live(invocation_context):
+        responses_generator = lambda: self.run_live(invocation_context)
+        async for llm_response in self._run_and_handle_error(
+            responses_generator,
+            invocation_context,
+            llm_request,
+            model_response_event,
+        ):
           # Runs after_model_callback if it exists.
           if altered_llm_response := await self._handle_after_model_callback(
               invocation_context, llm_response, model_response_event
@@ -540,10 +548,16 @@
       # the counter beyond the max set value, then the execution is stopped
       # right here, and exception is thrown.
       invocation_context.increment_llm_call_count()
-      async for llm_response in llm.generate_content_async(
+      responses_generator = lambda: llm.generate_content_async(
           llm_request,
           stream=invocation_context.run_config.streaming_mode
           == StreamingMode.SSE,
+      )
+      async for llm_response in self._run_and_handle_error(
+          responses_generator,
+          invocation_context,
+          llm_request,
+          model_response_event,
       ):
         trace_call_llm(
             invocation_context,
@@ -660,6 +674,54 @@ def _finalize_model_response_event(
     return model_response_event

+  async def _run_and_handle_error(
+      self,
+      response_generator: Callable[..., AsyncGenerator[LlmResponse, None]],
+      invocation_context: InvocationContext,
+      llm_request: LlmRequest,
+      model_response_event: Event,
+  ) -> AsyncGenerator[LlmResponse, None]:
+    """Runs the response generator and processes errors with plugins.
+
+    Args:
+      response_generator: The response generator to run.
+      invocation_context: The invocation context.
+      llm_request: The LLM request.
+      model_response_event: The model response event.
+
+    Yields:
+      LlmResponse objects produced by the generator.
+    """
+    while True:
+      try:
+        responses_generator_instance = response_generator()
+        async for response in responses_generator_instance:
+          yield response
+        break
+      except Exception as model_error:
+        callback_context = CallbackContext(
+            invocation_context, event_actions=model_response_event.actions
+        )
+        outcome = (
+            await invocation_context.plugin_manager.run_on_model_error_callback(
+                callback_context=callback_context,
+                llm_request=llm_request,
+                error=model_error,
+            )
+        )
+        # Retry the LLM call if the plugin outcome is RETRY.
+        if outcome == ModelErrorStrategy.RETRY:
+          continue
+
+        # If the plugin outcome is PASS, we can break the loop.
+        if outcome == ModelErrorStrategy.PASS:
+          break
+        if outcome is not None:
+          yield outcome
+          break
+        else:
+          raise model_error
+
   def __get_llm(self, invocation_context: InvocationContext) -> BaseLlm:
     from ...agents.llm_agent import LlmAgent

diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py
index 379e11ef7..aaa08d91a 100644
--- a/src/google/adk/flows/llm_flows/functions.py
+++ b/src/google/adk/flows/llm_flows/functions.py
@@ -176,9 +176,21 @@ async def handle_function_calls_async(
     # Step 3: Otherwise, proceed calling the tool normally.
     if function_response is None:
-      function_response = await __call_tool_async(
-          tool, args=function_args, tool_context=tool_context
-      )
+      try:
+        function_response = await __call_tool_async(
+            tool, args=function_args, tool_context=tool_context
+        )
+      except Exception as tool_error:
+        error_response = await invocation_context.plugin_manager.run_on_tool_error_callback(
+            tool=tool,
+            tool_args=function_args,
+            tool_context=tool_context,
+            error=tool_error,
+        )
+        if error_response is not None:
+          function_response = error_response
+        else:
+          raise tool_error

     # Step 4: Check if plugin after_tool_callback overrides the function
     # response.

diff --git a/src/google/adk/models/base_llm.py b/src/google/adk/models/base_llm.py
index 159ae221a..a01555dff 100644
--- a/src/google/adk/models/base_llm.py
+++ b/src/google/adk/models/base_llm.py
@@ -28,6 +28,11 @@
 from .llm_response import LlmResponse


+class ModelErrorStrategy:
+  RETRY = 'RETRY'
+  PASS = 'PASS'
+
+
 class BaseLlm(BaseModel):
   """The BaseLLM class.

diff --git a/src/google/adk/plugins/base_plugin.py b/src/google/adk/plugins/base_plugin.py
index 729e3519a..3ba0cdb17 100644
--- a/src/google/adk/plugins/base_plugin.py
+++ b/src/google/adk/plugins/base_plugin.py
@@ -25,10 +25,10 @@
 from ..agents.base_agent import BaseAgent
 from ..agents.callback_context import CallbackContext
 from ..events.event import Event
+from ..models.base_llm import ModelErrorStrategy
 from ..models.llm_request import LlmRequest
 from ..models.llm_response import LlmResponse
 from ..tools.base_tool import BaseTool
-from ..utils.feature_decorator import working_in_progress

 if TYPE_CHECKING:
   from ..agents.invocation_context import InvocationContext
@@ -265,6 +265,34 @@ async def after_model_callback(
     """
     pass

+  async def on_model_error_callback(
+      self,
+      *,
+      callback_context: CallbackContext,
+      llm_request: LlmRequest,
+      error: Exception,
+  ) -> Optional[LlmResponse | ModelErrorStrategy]:
+    """Callback executed when a model call encounters an error.
+
+    This callback provides an opportunity to handle model errors gracefully,
+    potentially providing alternative responses or recovery mechanisms.
+
+    Args:
+      callback_context: The context for the current agent call.
+      llm_request: The request that was sent to the model when the error
+        occurred.
+      error: The exception that was raised during model execution.
+
+    Returns:
+      An optional LlmResponse. If an LlmResponse is returned, it will be used
+      instead of propagating the error.
+      Returning `ModelErrorStrategy.RETRY` will retry the LLM call.
+      Returning `ModelErrorStrategy.PASS` will suppress the error and end the
+      LLM call without yielding a response.
+      Returning `None` allows the original error to be raised.
+    """
+    pass
+
   async def before_tool_callback(
       self,
       *,
@@ -315,3 +343,29 @@ async def after_tool_callback(
       result.
""" pass + + async def on_tool_error_callback( + self, + *, + tool: BaseTool, + tool_args: dict[str, Any], + tool_context: ToolContext, + error: Exception, + ) -> Optional[dict]: + """Callback executed when a tool call encounters an error. + + This callback provides an opportunity to handle tool errors gracefully, + potentially providing alternative responses or recovery mechanisms. + + Args: + tool: The tool instance that encountered an error. + tool_args: The arguments that were passed to the tool. + tool_context: The context specific to the tool execution. + error: The exception that was raised during tool execution. + + Returns: + An optional dictionary. If a dictionary is returned, it will be used as + the tool response instead of propagating the error. Returning `None` + allows the original error to be raised. + """ + pass diff --git a/src/google/adk/plugins/plugin_manager.py b/src/google/adk/plugins/plugin_manager.py index 3680c3515..4c2fb8121 100644 --- a/src/google/adk/plugins/plugin_manager.py +++ b/src/google/adk/plugins/plugin_manager.py @@ -30,6 +30,7 @@ from ..agents.callback_context import CallbackContext from ..agents.invocation_context import InvocationContext from ..events.event import Event + from ..models.base_llm import ModelErrorStrategy from ..models.llm_request import LlmRequest from ..models.llm_response import LlmResponse from ..tools.base_tool import BaseTool @@ -48,6 +49,8 @@ "after_tool_callback", "before_model_callback", "after_model_callback", + "on_tool_error_callback", + "on_model_error_callback", ] logger = logging.getLogger("google_adk." + __name__) @@ -195,6 +198,21 @@ async def run_after_tool_callback( result=result, ) + async def run_on_model_error_callback( + self, + *, + callback_context: CallbackContext, + llm_request: LlmRequest, + error: Exception, + ) -> Optional[LlmResponse | ModelErrorStrategy]: + """Runs the `on_model_error_callback` for all plugins.""" + return await self._run_callbacks( + "on_model_error_callback", + callback_context=callback_context, + llm_request=llm_request, + error=error, + ) + async def run_before_model_callback( self, *, callback_context: CallbackContext, llm_request: LlmRequest ) -> Optional[LlmResponse]: @@ -215,6 +233,23 @@ async def run_after_model_callback( llm_response=llm_response, ) + async def run_on_tool_error_callback( + self, + *, + tool: BaseTool, + tool_args: dict[str, Any], + tool_context: ToolContext, + error: Exception, + ) -> Optional[dict]: + """Runs the `on_tool_error_callback` for all plugins.""" + return await self._run_callbacks( + "on_tool_error_callback", + tool=tool, + tool_args=tool_args, + tool_context=tool_context, + error=error, + ) + async def _run_callbacks( self, callback_name: PluginCallbackName, **kwargs: Any ) -> Optional[Any]: diff --git a/src/google/adk/plugins/reflect_retry_plugin.py b/src/google/adk/plugins/reflect_retry_plugin.py new file mode 100644 index 000000000..31943737a --- /dev/null +++ b/src/google/adk/plugins/reflect_retry_plugin.py @@ -0,0 +1,219 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import json
+from typing import Any
+from typing import Literal
+from typing import Optional
+
+from google.genai import types
+from pydantic import BaseModel
+
+from ..agents.callback_context import CallbackContext
+from ..models.base_llm import ModelErrorStrategy
+from ..models.llm_request import LlmRequest
+from ..models.llm_response import LlmResponse
+from ..tools.base_tool import BaseTool
+from ..tools.tool_context import ToolContext
+from .base_plugin import BasePlugin
+
+
+class ReflectAndRetryPluginResponse(BaseModel):
+  """Response from ReflectAndRetryPlugin."""
+
+  response_type: Literal["ERROR_HANDLED_BY_REFLECT_AND_RETRY_PLUGIN"] = (
+      "ERROR_HANDLED_BY_REFLECT_AND_RETRY_PLUGIN"
+  )
+  error_type: str = ""
+  error_details: str = ""
+  retry_count: int = 0
+  reflection_guidance: str = ""
+
+
+class ReflectAndRetryPlugin(BasePlugin):
+  """A plugin that provides error recovery through reflection and retry logic.
+
+  When tool calls or model calls fail, this plugin generates instructional
+  responses that encourage the model to reflect on the error and try a
+  different approach, rather than simply propagating the error.
+
+  This plugin is particularly useful for handling transient errors, API
+  limitations, or cases where the model might need to adjust its strategy
+  based on encountered obstacles.
+
+  Example:
+    >>> reflect_retry_plugin = ReflectAndRetryPlugin()
+    >>> runner = Runner(
+    ...     agent=my_agent,
+    ...     plugins=[reflect_retry_plugin],
+    ... )
+  """
+
+  def __init__(self, name: str = "reflect_retry_plugin", max_retries: int = 3):
+    """Initialize the reflect and retry plugin.
+
+    Args:
+      name: The name of the plugin instance.
+      max_retries: Maximum number of retries to attempt before giving up.
+ """ + super().__init__(name) + self.max_retries = max_retries + self._retry_counts: dict[str, int] = {} + + async def on_tool_error_callback( + self, + *, + tool: BaseTool, + tool_args: dict[str, Any], + tool_context: ToolContext, + error: Exception, + ) -> Optional[dict]: + """Handle tool execution errors with reflection and retry logic.""" + retry_key = self._get_retry_key( + tool_context.invocation_id, f"tool:{tool.name}" + ) + + if not self._should_retry(retry_key): + return self._get_tool_retry_exceed_msg(tool, error) + + retry_count = self._increment_retry_count(retry_key) + + # Create a reflective response instead of propagating the error + return self._create_tool_reflection_response( + tool, tool_args, error, retry_count + ) + + async def on_model_error_callback( + self, + *, + callback_context: CallbackContext, + llm_request: LlmRequest, + error: Exception, + ) -> Optional[LlmResponse | ModelErrorStrategy]: + """Handle model execution errors with reflection and retry logic.""" + retry_key = self._get_retry_key(callback_context.invocation_id, "model") + + if not self._should_retry(retry_key): + return self._get_model_retry_exceed_msg(error) + + self._increment_retry_count(retry_key) + + return ModelErrorStrategy.RETRY + + def _get_retry_key(self, context_id: str, operation: str) -> str: + """Generate a unique key for tracking retries.""" + return f"{context_id}:{operation}" + + def _should_retry(self, retry_key: str) -> bool: + """Check if we should attempt a retry for this operation.""" + current_count = self._retry_counts.get(retry_key, 0) + return current_count < self.max_retries + + def _increment_retry_count(self, retry_key: str) -> int: + """Increment and return the retry count for an operation.""" + self._retry_counts[retry_key] = self._retry_counts.get(retry_key, 0) + 1 + return self._retry_counts[retry_key] + + def _format_error_details(self, error: Exception) -> str: + """Format error details for inclusion in reflection message.""" + error_type = type(error).__name__ + error_message = str(error) + return f"{error_type}: {error_message}" + + def _create_tool_reflection_response( + self, + tool: BaseTool, + tool_args: dict[str, Any], + error: Exception, + retry_count: int, + ) -> dict[str, Any]: + """Create a reflection response for tool errors.""" + args_summary = json.dumps(tool_args, indent=2, default=str) + error_details = self._format_error_details(error) + + reflection_message = f""" +The tool call to '{tool.name}' failed with the following error: + +Error: {error_details} + +Tool Arguments Used: +{args_summary} + +**Reflection Instructions:** +When realizing the current approach won't work, think about the potential issues and explicitly try a different approach. Consider: + +1. **Parameter Issues**: Are the arguments correctly formatted or within expected ranges? +2. **Alternative Methods**: Is there a different tool or approach that might work better? +3. **Error Context**: What does this specific error tell you about what went wrong? +4. **Incremental Steps**: Can you break down the task into smaller, more manageable steps? + +This is retry attempt {retry_count} of {self.max_retries}. Please analyze the error and adjust your strategy accordingly. + +Instead of repeating the same approach, explicitly state what you learned from this error and how you plan to modify your approach. 
+""" + + return ReflectAndRetryPluginResponse( + error_type=type(error).__name__, + error_details=str(error), + retry_count=retry_count, + reflection_guidance=reflection_message.strip(), + ).model_dump(mode="json") + + def _get_tool_retry_exceed_msg( + self, + tool: BaseTool, + error: Exception, + ) -> dict[str, Any]: + """Create a reflection response for tool errors.""" + reflection_message = f""" +The tool call to '{tool.name}' has failed {self.max_retries} times and has exceeded the maximum retry limit. + +Last Error: {self._format_error_details(error)} + +**Instructions:** +Do not attempt to use this tool ('{tool.name}') again for this task. +You must try a different approach, using a different tool or strategy to accomplish the goal. +""" + return ReflectAndRetryPluginResponse( + error_type=type(error).__name__, + error_details=str(error), + retry_count=self.max_retries, + reflection_guidance=reflection_message.strip(), + ).model_dump(mode="json") + + def _get_model_retry_exceed_msg( + self, + error: Exception, + ) -> LlmResponse: + """Create a reflection response for model errors.""" + error_details = self._format_error_details(error) + reflection_content = f""" +The model request has failed {self.max_retries} times and has exceeded the maximum retry limit. + +Last Error: {error_details} +""" + content = types.Content( + role="assistant", parts=[types.Part(text=reflection_content.strip())] + ) + return LlmResponse( + content=content, + custom_metadata=({ + "reflect_and_retry_plugin": ReflectAndRetryPluginResponse( + error_type=type(error).__name__, + error_details=str(error), + retry_count=self.max_retries, + reflection_guidance=reflection_content.strip(), + ).model_dump(mode="json") + }), + ) diff --git a/tests/unittests/flows/llm_flows/test_plugin_model_callbacks.py b/tests/unittests/flows/llm_flows/test_plugin_model_callbacks.py index b9b2dec35..d009de940 100644 --- a/tests/unittests/flows/llm_flows/test_plugin_model_callbacks.py +++ b/tests/unittests/flows/llm_flows/test_plugin_model_callbacks.py @@ -18,21 +18,36 @@ from google.adk.agents.callback_context import CallbackContext from google.adk.models import LlmRequest from google.adk.models import LlmResponse +from google.adk.models.base_llm import ModelErrorStrategy from google.adk.plugins.base_plugin import BasePlugin from google.genai import types +from google.genai.errors import ClientError import pytest from ... 

+mock_error = ClientError(
+    code=429,
+    response_json={
+        'error': {
+            'code': 429,
+            'message': 'Quota exceeded.',
+            'status': 'RESOURCE_EXHAUSTED',
+        }
+    },
+)
+

 class MockPlugin(BasePlugin):
   before_model_text = 'before_model_text from MockPlugin'
   after_model_text = 'after_model_text from MockPlugin'
+  on_model_error_text = 'on_model_error_text from MockPlugin'

   def __init__(self, name='mock_plugin'):
     self.name = name
     self.enable_before_model_callback = False
     self.enable_after_model_callback = False
+    self.enable_on_model_error_callback = False
     self.before_model_response = LlmResponse(
         content=testing_utils.ModelContent(
             [types.Part.from_text(text=self.before_model_text)]
@@ -43,6 +58,11 @@ def __init__(self, name='mock_plugin'):
             [types.Part.from_text(text=self.after_model_text)]
         )
     )
+    self.on_model_error_response = LlmResponse(
+        content=testing_utils.ModelContent(
+            [types.Part.from_text(text=self.on_model_error_text)]
+        )
+    )

   async def before_model_callback(
       self, *, callback_context: CallbackContext, llm_request: LlmRequest
@@ -58,6 +78,38 @@ async def after_model_callback(
       return None
     return self.after_model_response

+  async def on_model_error_callback(
+      self,
+      *,
+      callback_context: CallbackContext,
+      llm_request: LlmRequest,
+      error: Exception,
+  ) -> Optional[LlmResponse]:
+    if not self.enable_on_model_error_callback:
+      return None
+    return self.on_model_error_response
+
+
+class MockPluginWithRetries(BasePlugin):
+
+  def __init__(self, name='mock_plugin'):
+    self.name = name
+    self.max_retries = 3
+    self.call_count = 0
+
+  async def on_model_error_callback(
+      self,
+      *,
+      callback_context: CallbackContext,
+      llm_request: LlmRequest,
+      error: Exception,
+  ) -> Optional[LlmResponse | ModelErrorStrategy]:
+    if self.call_count < self.max_retries:
+      self.call_count += 1
+      return ModelErrorStrategy.RETRY
+    else:
+      return ModelErrorStrategy.PASS
+

 CANONICAL_MODEL_CALLBACK_CONTENT = 'canonical_model_callback_content'
@@ -75,6 +127,11 @@ def mock_plugin():
   return MockPlugin()


+@pytest.fixture
+def mock_plugin_with_retries():
+  return MockPluginWithRetries()
+
+
 def test_before_model_callback_with_plugin(mock_plugin):
   """Tests that the model response is overridden by before_model_callback from the plugin."""
   responses = ['model_response']
@@ -124,5 +181,56 @@ def test_before_model_callback_fallback_model(mock_plugin):
   ]


+def test_on_model_error_callback_with_plugin(mock_plugin):
+  """Tests that the model error is handled by the plugin."""
+  mock_model = testing_utils.MockModel.create(error=mock_error, responses=[])
+  mock_plugin.enable_on_model_error_callback = True
+  agent = Agent(
+      name='root_agent',
+      model=mock_model,
+  )
+
+  runner = testing_utils.InMemoryRunner(agent, plugins=[mock_plugin])
+
+  assert testing_utils.simplify_events(runner.run('test')) == [
+      ('root_agent', mock_plugin.on_model_error_text),
+  ]
+
+
+def test_on_model_error_callback_with_plugin_retries(mock_plugin_with_retries):
+  """Tests that when the model errors and the plugin returns RETRY, the model invocation is retried."""
+  mock_model = testing_utils.MockModel.create(error=mock_error, responses=[])
+  agent = Agent(
+      name='root_agent',
+      model=mock_model,
+  )
+
+  runner = testing_utils.InMemoryRunner(
+      agent, plugins=[mock_plugin_with_retries]
+  )
+  runner.run('test')
+
+  assert (
+      mock_plugin_with_retries.call_count
+      == mock_plugin_with_retries.max_retries
+  )
+  assert mock_model.response_index == mock_plugin_with_retries.max_retries
+
+
+def test_on_model_error_callback_fallback_to_runner(mock_plugin):
+ """Tests that the model error is not handled and falls back to raise from runner.""" + mock_model = testing_utils.MockModel.create(error=mock_error, responses=[]) + mock_plugin.enable_on_model_error_callback = False + agent = Agent( + name='root_agent', + model=mock_model, + ) + + try: + testing_utils.InMemoryRunner(agent, plugins=[mock_plugin]) + except Exception as e: + assert e == mock_error + + if __name__ == '__main__': pytest.main([__file__]) diff --git a/tests/unittests/flows/llm_flows/test_plugin_tool_callbacks.py b/tests/unittests/flows/llm_flows/test_plugin_tool_callbacks.py index a79e562a5..fac4169b3 100644 --- a/tests/unittests/flows/llm_flows/test_plugin_tool_callbacks.py +++ b/tests/unittests/flows/llm_flows/test_plugin_tool_callbacks.py @@ -24,19 +24,35 @@ from google.adk.tools.function_tool import FunctionTool from google.adk.tools.tool_context import ToolContext from google.genai import types +from google.genai.errors import ClientError import pytest from ... import testing_utils +mock_error = ClientError( + code=429, + response_json={ + "error": { + "code": 429, + "message": "Quota exceeded.", + "status": "RESOURCE_EXHAUSTED", + } + }, +) + class MockPlugin(BasePlugin): before_tool_response = {"MockPlugin": "before_tool_response from MockPlugin"} after_tool_response = {"MockPlugin": "after_tool_response from MockPlugin"} + on_tool_error_response = { + "MockPlugin": "on_tool_error_response from MockPlugin" + } def __init__(self, name="mock_plugin"): self.name = name self.enable_before_tool_callback = False self.enable_after_tool_callback = False + self.enable_on_tool_error_callback = False async def before_tool_callback( self, @@ -61,6 +77,18 @@ async def after_tool_callback( return None return self.after_tool_response + async def on_tool_error_callback( + self, + *, + tool: BaseTool, + tool_args: dict[str, Any], + tool_context: ToolContext, + error: Exception, + ) -> Optional[dict]: + if not self.enable_on_tool_error_callback: + return None + return self.on_tool_error_response + @pytest.fixture def mock_tool(): @@ -70,6 +98,14 @@ def simple_fn(**kwargs) -> Dict[str, Any]: return FunctionTool(simple_fn) +@pytest.fixture +def mock_error_tool(): + def raise_error_fn(**kwargs) -> Dict[str, Any]: + raise mock_error + + return FunctionTool(raise_error_fn) + + @pytest.fixture def mock_plugin(): return MockPlugin() @@ -124,5 +160,30 @@ async def test_async_after_tool_callback(mock_tool, mock_plugin): assert part.function_response.response == mock_plugin.after_tool_response +@pytest.mark.asyncio +async def test_async_on_tool_error_use_plugin_response( + mock_error_tool, mock_plugin +): + mock_plugin.enable_on_tool_error_callback = True + + result_event = await invoke_tool_with_plugin(mock_error_tool, mock_plugin) + + assert result_event is not None + part = result_event.content.parts[0] + assert part.function_response.response == mock_plugin.on_tool_error_response + + +@pytest.mark.asyncio +async def test_async_on_tool_error_fallback_to_runner( + mock_error_tool, mock_plugin +): + mock_plugin.enable_on_tool_error_callback = False + + try: + await invoke_tool_with_plugin(mock_error_tool, mock_plugin) + except Exception as e: + assert e == mock_error + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/unittests/plugins/test_base_plugin.py b/tests/unittests/plugins/test_base_plugin.py index 04b1c3e94..3a2de9430 100644 --- a/tests/unittests/plugins/test_base_plugin.py +++ b/tests/unittests/plugins/test_base_plugin.py @@ -67,12 +67,18 @@ async def 
   async def after_tool_callback(self, **kwargs) -> str:
     return "overridden_after_tool"

+  async def on_tool_error_callback(self, **kwargs) -> str:
+    return "overridden_on_tool_error"
+
   async def before_model_callback(self, **kwargs) -> str:
     return "overridden_before_model"

   async def after_model_callback(self, **kwargs) -> str:
     return "overridden_after_model"

+  async def on_model_error_callback(self, **kwargs) -> str:
+    return "overridden_on_model_error"
+

 def test_base_plugin_initialization():
   """Tests that a plugin is initialized with the correct name."""
@@ -137,6 +143,15 @@ async def test_base_plugin_default_callbacks_return_none():
       )
       is None
   )
+  assert (
+      await plugin.on_tool_error_callback(
+          tool=mock_context,
+          tool_args={},
+          tool_context=mock_context,
+          error=Exception(),
+      )
+      is None
+  )
   assert (
       await plugin.before_model_callback(
           callback_context=mock_context, llm_request=mock_context
@@ -149,6 +164,14 @@ async def test_base_plugin_default_callbacks_return_none():
       )
       is None
   )
+  assert (
+      await plugin.on_model_error_callback(
+          callback_context=mock_context,
+          llm_request=mock_context,
+          error=Exception(),
+      )
+      is None
+  )


 @pytest.mark.asyncio
@@ -170,6 +193,7 @@ async def test_base_plugin_all_callbacks_can_be_overridden():
   mock_llm_request = Mock(spec=LlmRequest)
   mock_llm_response = Mock(spec=LlmResponse)
   mock_event = Mock(spec=Event)
+  mock_error = Mock(spec=Exception)

   # Call each method and assert it returns the unique string from the override.
   # This proves that the subclass's method was executed.
@@ -237,3 +261,20 @@ async def test_base_plugin_all_callbacks_can_be_overridden():
       )
       == "overridden_after_tool"
   )
+  assert (
+      await plugin.on_tool_error_callback(
+          tool=mock_tool,
+          tool_args={},
+          tool_context=mock_tool_context,
+          error=mock_error,
+      )
+      == "overridden_on_tool_error"
+  )
+  assert (
+      await plugin.on_model_error_callback(
+          callback_context=mock_callback_context,
+          llm_request=mock_llm_request,
+          error=mock_error,
+      )
+      == "overridden_on_model_error"
+  )

diff --git a/tests/unittests/plugins/test_plugin_manager.py b/tests/unittests/plugins/test_plugin_manager.py
index 76d32a618..e3edfa83e 100644
--- a/tests/unittests/plugins/test_plugin_manager.py
+++ b/tests/unittests/plugins/test_plugin_manager.py
@@ -77,12 +77,18 @@ async def before_tool_callback(self, **kwargs):
     return await self._handle_callback("before_tool_callback")

   async def after_tool_callback(self, **kwargs):
     return await self._handle_callback("after_tool_callback")

+  async def on_tool_error_callback(self, **kwargs):
+    return await self._handle_callback("on_tool_error_callback")
+
   async def before_model_callback(self, **kwargs):
     return await self._handle_callback("before_model_callback")

   async def after_model_callback(self, **kwargs):
     return await self._handle_callback("after_model_callback")

+  async def on_model_error_callback(self, **kwargs):
+    return await self._handle_callback("on_model_error_callback")
+

 @pytest.fixture
 def service() -> PluginManager:
@@ -227,12 +233,23 @@ async def test_all_callbacks_are_supported(
   await service.run_after_tool_callback(
       tool=mock_context, tool_args={}, tool_context=mock_context, result={}
   )
+  await service.run_on_tool_error_callback(
+      tool=mock_context,
+      tool_args={},
+      tool_context=mock_context,
+      error=mock_context,
+  )
   await service.run_before_model_callback(
       callback_context=mock_context, llm_request=mock_context
   )
   await service.run_after_model_callback(
       callback_context=mock_context, llm_response=mock_context
   )
+  await service.run_on_model_error_callback(
+      callback_context=mock_context,
+      llm_request=mock_context,
+      error=mock_context,
+  )

   # Verify all callbacks were logged
   expected_callbacks = [
@@ -244,7 +261,9 @@ async def test_all_callbacks_are_supported(
       "after_agent_callback",
       "before_tool_callback",
       "after_tool_callback",
+      "on_tool_error_callback",
       "before_model_callback",
       "after_model_callback",
+      "on_model_error_callback",
   ]
   assert set(plugin1.call_log) == set(expected_callbacks)

diff --git a/tests/unittests/testing_utils.py b/tests/unittests/testing_utils.py
index 810a6c448..1fdc24eff 100644
--- a/tests/unittests/testing_utils.py
+++ b/tests/unittests/testing_utils.py
@@ -247,6 +247,7 @@ class MockModel(BaseLlm):

   requests: list[LlmRequest] = []
   responses: list[LlmResponse]
+  error: Union[Exception, None] = None
   response_index: int = -1

   @classmethod
@@ -255,7 +256,10 @@ def create(
       responses: Union[
           list[types.Part], list[LlmResponse], list[str], list[list[types.Part]]
       ],
+      error: Union[Exception, None] = None,
   ):
+    if error and not responses:
+      return cls(responses=[], error=error)
     if not responses:
       return cls(responses=[])
     elif isinstance(responses[0], LlmResponse):
@@ -287,6 +291,8 @@ def generate_content(
   ) -> Generator[LlmResponse, None, None]:
     # Increasement of the index has to happen before the yield.
     self.response_index += 1
+    if self.error:
+      raise self.error
     self.requests.append(llm_request)
     # yield LlmResponse(content=self.responses[self.response_index])
     yield self.responses[self.response_index]
@@ -297,6 +303,8 @@ async def generate_content_async(
   ) -> AsyncGenerator[LlmResponse, None]:
     # Increasement of the index has to happen before the yield.
     self.response_index += 1
+    if self.error:
+      raise self.error
     self.requests.append(llm_request)
     yield self.responses[self.response_index]
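
Usage sketch (illustrative, not part of the patch): a minimal plugin built on the two error hooks this change introduces. The callback signatures and `ModelErrorStrategy` semantics mirror the diff above; the plugin name, retry budget, and fallback payload are hypothetical.

```python
from __future__ import annotations

from typing import Any, Optional

from google.adk.agents.callback_context import CallbackContext
from google.adk.models.base_llm import ModelErrorStrategy
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse
from google.adk.plugins.base_plugin import BasePlugin
from google.adk.tools.base_tool import BaseTool
from google.adk.tools.tool_context import ToolContext


class FallbackPlugin(BasePlugin):
  """Hypothetical plugin: retry transient model errors, stub out failing tools."""

  def __init__(self, name: str = "fallback_plugin", max_retries: int = 2):
    super().__init__(name)
    self.max_retries = max_retries
    self._attempts: dict[str, int] = {}

  async def on_model_error_callback(
      self,
      *,
      callback_context: CallbackContext,
      llm_request: LlmRequest,
      error: Exception,
  ) -> Optional[LlmResponse | ModelErrorStrategy]:
    # Retry each invocation up to max_retries, then let the error propagate.
    key = callback_context.invocation_id
    self._attempts[key] = self._attempts.get(key, 0) + 1
    if self._attempts[key] <= self.max_retries:
      return ModelErrorStrategy.RETRY
    return None  # None re-raises the original exception.

  async def on_tool_error_callback(
      self,
      *,
      tool: BaseTool,
      tool_args: dict[str, Any],
      tool_context: ToolContext,
      error: Exception,
  ) -> Optional[dict]:
    # A returned dict is substituted for the failed tool's response.
    return {"error": type(error).__name__, "details": str(error)}
```

Under `_run_and_handle_error` above, `RETRY` re-invokes the response generator, `PASS` suppresses the error and ends the model call, a returned `LlmResponse` is yielded in place of the failure, and `None` re-raises the original exception.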
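For reference, the substitute tool response that `ReflectAndRetryPlugin` produces is the `model_dump(mode="json")` of `ReflectAndRetryPluginResponse`; a sketch of its shape, with invented error values:

```python
# Illustrative payload only; the field names come from
# ReflectAndRetryPluginResponse in the patch, the values here are made up.
{
    "response_type": "ERROR_HANDLED_BY_REFLECT_AND_RETRY_PLUGIN",
    "error_type": "ClientError",
    "error_details": "429 RESOURCE_EXHAUSTED: Quota exceeded.",
    "retry_count": 1,
    "reflection_guidance": "The tool call to 'my_tool' failed with ...",
}
```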