feat: Add sample plugin for logging #1982

Open · wants to merge 1 commit into base: main
67 changes: 65 additions & 2 deletions src/google/adk/flows/llm_flows/base_llm_flow.py
@@ -17,9 +17,11 @@
from abc import ABC
import asyncio
import datetime
from enum import Enum
import inspect
import logging
from typing import AsyncGenerator
from typing import Callable
from typing import cast
from typing import Optional
from typing import TYPE_CHECKING
@@ -36,6 +38,7 @@
from ...agents.run_config import StreamingMode
from ...agents.transcription_entry import TranscriptionEntry
from ...events.event import Event
from ...models.base_llm import ModelErrorStrategy
from ...models.base_llm_connection import BaseLlmConnection
from ...models.llm_request import LlmRequest
from ...models.llm_response import LlmResponse
@@ -521,7 +524,13 @@ async def _call_llm_async(
with tracer.start_as_current_span('call_llm'):
if invocation_context.run_config.support_cfc:
invocation_context.live_request_queue = LiveRequestQueue()
async for llm_response in self.run_live(invocation_context):
responses_generator = lambda: self.run_live(invocation_context)
async for llm_response in self._run_and_handle_error(
responses_generator,
invocation_context,
llm_request,
model_response_event,
):
# Runs after_model_callback if it exists.
if altered_llm_response := await self._handle_after_model_callback(
invocation_context, llm_response, model_response_event
@@ -540,10 +549,16 @@ async def _call_llm_async(
# the counter beyond the max set value, then the execution is stopped
# right here, and exception is thrown.
invocation_context.increment_llm_call_count()
async for llm_response in llm.generate_content_async(
responses_generator = lambda: llm.generate_content_async(
llm_request,
stream=invocation_context.run_config.streaming_mode
== StreamingMode.SSE,
)
async for llm_response in self._run_and_handle_error(
responses_generator,
invocation_context,
llm_request,
model_response_event,
):
trace_call_llm(
invocation_context,
@@ -660,6 +675,54 @@ def _finalize_model_response_event(

return model_response_event

async def _run_and_handle_error(
self,
response_generator: Callable[..., AsyncGenerator[LlmResponse, None]],
invocation_context: InvocationContext,
llm_request: LlmRequest,
model_response_event: Event,
) -> AsyncGenerator[LlmResponse, None]:
"""Runs the response generator and processes the error with plugins.
Args:
response_generator: The response generator to run.
invocation_context: The invocation context.
llm_request: The LLM request.
model_response_event: The model response event.
Yields:
A generator of LlmResponse.
"""
while True:
try:
responses_generator_instance = response_generator()
async for response in responses_generator_instance:
yield response
break
except Exception as model_error:
callback_context = CallbackContext(
invocation_context, event_actions=model_response_event.actions
)
outcome = (
await invocation_context.plugin_manager.run_on_model_error_callback(
callback_context=callback_context,
llm_request=llm_request,
error=model_error,
)
)
# Retry the LLM call if the plugin outcome is RETRY.
if outcome == ModelErrorStrategy.RETRY:
continue

# If the plugin outcome is PASS, we can break the loop.
if outcome == ModelErrorStrategy.PASS:
break
if outcome is not None:
yield outcome
break
else:
raise model_error

def __get_llm(self, invocation_context: InvocationContext) -> BaseLlm:
from ...agents.llm_agent import LlmAgent

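Note on the change above: `_call_llm_async` now passes a zero-argument callable (`responses_generator = lambda: ...`) rather than the async generator itself, because a generator that has raised cannot be resumed; on `ModelErrorStrategy.RETRY`, `_run_and_handle_error` calls the callable again to get a fresh generator, so chunks yielded before the failure may be yielded again. A minimal standalone sketch of that pattern, with a bounded attempt count (all names here are illustrative, not part of this PR):

import asyncio
from typing import AsyncGenerator, Callable


async def retry_stream(
    make_stream: Callable[[], AsyncGenerator[str, None]],
    max_attempts: int = 3,
) -> AsyncGenerator[str, None]:
  # Re-creates the stream on failure, like the while-True loop in
  # _run_and_handle_error, but gives up after max_attempts.
  for attempt in range(max_attempts):
    try:
      async for item in make_stream():  # Fresh generator on every attempt.
        yield item
      return  # Stream finished without raising; stop retrying.
    except Exception:
      if attempt == max_attempts - 1:
        raise  # Out of attempts; propagate the original error.


async def flaky() -> AsyncGenerator[str, None]:
  yield 'chunk-1'
  raise RuntimeError('transient failure')


async def main() -> None:
  try:
    async for chunk in retry_stream(lambda: flaky(), max_attempts=2):
      print(chunk)  # Prints 'chunk-1' twice: once per attempt.
  except RuntimeError as err:
    print('gave up:', err)


if __name__ == '__main__':
  asyncio.run(main())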
18 changes: 15 additions & 3 deletions src/google/adk/flows/llm_flows/functions.py
@@ -176,9 +176,21 @@ async def handle_function_calls_async(

# Step 3: Otherwise, proceed calling the tool normally.
if function_response is None:
function_response = await __call_tool_async(
tool, args=function_args, tool_context=tool_context
)
try:
function_response = await __call_tool_async(
tool, args=function_args, tool_context=tool_context
)
except Exception as tool_error:
error_response = await invocation_context.plugin_manager.run_on_tool_error_callback(
tool=tool,
tool_args=function_args,
tool_context=tool_context,
error=tool_error,
)
if error_response is not None:
function_response = error_response
else:
raise tool_error

# Step 4: Check if plugin after_tool_callback overrides the function
# response.
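With this change a plugin can convert a tool failure into an ordinary function response: if `on_tool_error_callback` returns a dict, `handle_function_calls_async` uses it as the tool response; if it returns `None`, the original exception is re-raised. A hedged sketch of such a plugin follows; the class name, the `ToolContext` import path, the `BasePlugin` constructor call, and the `tool.name` attribute are assumptions not shown in this diff.

from typing import Any, Optional

from google.adk.plugins.base_plugin import BasePlugin
from google.adk.tools.base_tool import BaseTool
from google.adk.tools.tool_context import ToolContext  # Import path assumed.


class ToolErrorFallbackPlugin(BasePlugin):
  """Sketch: turns tool exceptions into a structured error response."""

  def __init__(self) -> None:
    super().__init__(name='tool_error_fallback')  # Constructor signature assumed.

  async def on_tool_error_callback(
      self,
      *,
      tool: BaseTool,
      tool_args: dict[str, Any],
      tool_context: ToolContext,
      error: Exception,
  ) -> Optional[dict]:
    # Returning a dict makes handle_function_calls_async use it as the
    # tool response instead of re-raising the error.
    return {
        'status': 'error',
        'tool': tool.name,  # `name` attribute on BaseTool assumed.
        'message': str(error),
    }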
5 changes: 5 additions & 0 deletions src/google/adk/models/base_llm.py
@@ -28,6 +28,11 @@
from .llm_response import LlmResponse


class ModelErrorStrategy:
RETRY = 'RETRY'
PASS = 'PASS'


class BaseLlm(BaseModel):
"""The BaseLLM class.
56 changes: 55 additions & 1 deletion src/google/adk/plugins/base_plugin.py
@@ -25,10 +25,10 @@
from ..agents.base_agent import BaseAgent
from ..agents.callback_context import CallbackContext
from ..events.event import Event
from ..models.base_llm import ModelErrorStrategy
from ..models.llm_request import LlmRequest
from ..models.llm_response import LlmResponse
from ..tools.base_tool import BaseTool
from ..utils.feature_decorator import working_in_progress

if TYPE_CHECKING:
from ..agents.invocation_context import InvocationContext
@@ -265,6 +265,34 @@ async def after_model_callback(
"""
pass

async def on_model_error_callback(
self,
*,
callback_context: CallbackContext,
llm_request: LlmRequest,
error: Exception,
) -> Optional[LlmResponse | ModelErrorStrategy]:
"""Callback executed when a model call encounters an error.
This callback provides an opportunity to handle model errors gracefully,
potentially providing alternative responses or recovery mechanisms.
Args:
callback_context: The context for the current agent call.
llm_request: The request that was sent to the model when the error
occurred.
error: The exception that was raised during model execution.
Returns:
An optional LlmResponse. If an LlmResponse is returned, it will be used
instead of propagating the error.
Returning `ModelErrorStrategy.RETRY` will retry the LLM call.
Returning `ModelErrorStrategy.PASS` will allow the LLM call to
proceed normally and ignore the error.
Returning `None` allows the original error to be raised.
"""
pass

async def before_tool_callback(
self,
*,
@@ -315,3 +343,29 @@ async def after_tool_callback(
result.
"""
pass

async def on_tool_error_callback(
self,
*,
tool: BaseTool,
tool_args: dict[str, Any],
tool_context: ToolContext,
error: Exception,
) -> Optional[dict]:
"""Callback executed when a tool call encounters an error.
This callback provides an opportunity to handle tool errors gracefully,
potentially providing alternative responses or recovery mechanisms.
Args:
tool: The tool instance that encountered an error.
tool_args: The arguments that were passed to the tool.
tool_context: The context specific to the tool execution.
error: The exception that was raised during tool execution.
Returns:
An optional dictionary. If a dictionary is returned, it will be used as
the tool response instead of propagating the error. Returning `None`
allows the original error to be raised.
"""
pass
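Since the PR is titled "Add sample plugin for logging", a hedged sketch of what such a plugin could look like with the new model-error hook is shown below. The class name, constructor call, and one-retry policy are assumptions; the callback signature and `ModelErrorStrategy` come from this diff, and the import paths mirror its relative imports.

import logging
from typing import Optional

from google.adk.agents.callback_context import CallbackContext
from google.adk.models.base_llm import ModelErrorStrategy
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse
from google.adk.plugins.base_plugin import BasePlugin

logger = logging.getLogger(__name__)


class LoggingPlugin(BasePlugin):
  """Sketch: logs model errors and retries a failed call once."""

  def __init__(self) -> None:
    super().__init__(name='logging_plugin')  # Constructor signature assumed.
    self._retried = False  # Simplification: one flag shared across invocations.

  async def on_model_error_callback(
      self,
      *,
      callback_context: CallbackContext,
      llm_request: LlmRequest,
      error: Exception,
  ) -> Optional[LlmResponse | ModelErrorStrategy]:
    logger.warning('Model call failed: %s', error)
    if not self._retried:
      self._retried = True
      return ModelErrorStrategy.RETRY  # base_llm_flow re-runs the model call.
    return None  # Second failure: let the original exception propagate.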