feat: Add implementation of BaseEvalService that runs evals locally #1756

Open · wants to merge 1 commit into main
47 changes: 47 additions & 0 deletions src/google/adk/evaluation/agent_creator.py
@@ -0,0 +1,47 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from abc import ABC
from abc import abstractmethod

from typing_extensions import override

from ..agents import BaseAgent


class AgentCreator(ABC):
"""Creates an Agent for the purposes of Eval."""

@abstractmethod
def get_agent(
self,
) -> BaseAgent:
"""Returns an instance of an Agent to be used for Eval purposes."""


class IdentityAgentCreator(AgentCreator):
"""An implementation of the AgentCreator interface that always returns a copy of the root agent."""

def __init__(self, root_agent: BaseAgent):
self._root_agent = root_agent

@override
def get_agent(
self,
) -> BaseAgent:
"""Returns a deep copy of the root agent."""
# TODO: Use Agent.clone() when the PR is merged.
return self._root_agent.model_copy(deep=True)
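
For orientation, here is a minimal usage sketch of IdentityAgentCreator (not part of the diff). The agent definition is a made-up example; the imports mirror the ones used by the test file added in this PR.

from google.adk.agents import Agent
from google.adk.evaluation.agent_creator import IdentityAgentCreator

# A hypothetical root agent, used only for illustration.
root_agent = Agent(
    model="gemini-2.0-flash",
    name="example_root_agent",
    instruction="Answer questions about the weather.",
)

creator = IdentityAgentCreator(root_agent=root_agent)

# Each call returns an independent deep copy, so one eval run cannot
# mutate state observed by another.
agent_for_run_1 = creator.get_agent()
agent_for_run_2 = creator.get_agent()
assert agent_for_run_1 is not agent_for_run_2
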
37 changes: 35 additions & 2 deletions src/google/adk/evaluation/base_eval_service.py
@@ -16,6 +16,7 @@

from abc import ABC
from abc import abstractmethod
from enum import Enum
from typing import AsyncGenerator
from typing import Optional

@@ -56,6 +57,19 @@ class InferenceConfig(BaseModel):
charges.""",
)

max_inference_parallelism: int = Field(
default=4,
description="""Number of parallel inferences to run during an Eval. Few
factors to consider while changing this value:

1) Your available quota with the model. Models tend to enforce per-minute or
per-second SLAs. Using a larger value could result in the eval quickly consuming
the quota.

2) The tools used by the Agent could also have their own SLAs. Using a larger
value could overwhelm those tools.""",
)


class InferenceRequest(BaseModel):
"""Represent a request to perform inferences for the eval cases in an eval set."""
@@ -88,6 +102,14 @@ class InferenceRequest(BaseModel):
)


class InferenceStatus(Enum):
"""Status of the inference."""

UNKNOWN = 0
SUCCESS = 1
FAILURE = 2


class InferenceResult(BaseModel):
"""Contains inference results for a single eval case."""

@@ -106,14 +128,25 @@ class InferenceResult(BaseModel):
description="""Id of the eval case for which inferences were generated.""",
)

- inferences: list[Invocation] = Field(
- description="""Inferences obtained from the Agent for the eval case."""
- )
+ inferences: Optional[list[Invocation]] = Field(
+ default=None,
+ description="""Inferences obtained from the Agent for the eval case.""",
+ )

session_id: Optional[str] = Field(
description="""Id of the inference session."""
)

status: InferenceStatus = Field(
default=InferenceStatus.UNKNOWN,
description="""Status of the inference.""",
)

error_message: Optional[str] = Field(
default=None,
description="""Error message if the inference failed.""",
)


class EvaluateRequest(BaseModel):
model_config = ConfigDict(
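
To make the new configuration and status fields concrete, a hedged sketch follows (not part of the diff). The identifiers are invented, eval_case_ids is assumed to accept a list of eval case ids, and any InferenceConfig or InferenceRequest fields not visible in this excerpt are assumed to have defaults.

from google.adk.evaluation.base_eval_service import InferenceConfig
from google.adk.evaluation.base_eval_service import InferenceRequest
from google.adk.evaluation.base_eval_service import InferenceResult
from google.adk.evaluation.base_eval_service import InferenceStatus

# Cap an eval run at two concurrent inferences, e.g. to stay within model quota.
config = InferenceConfig(max_inference_parallelism=2)

# Hypothetical identifiers; eval_case_ids narrows the run to specific cases.
request = InferenceRequest(
    app_name="my_app",
    eval_set_id="my_eval_set",
    eval_case_ids=["case_1", "case_2"],
    inference_config=config,
)


def log_failure(result: InferenceResult) -> None:
  # With the new status/error_message fields, a failed inference no longer
  # has to raise; callers can inspect the result instead.
  if result.status == InferenceStatus.FAILURE:
    print(f"{result.eval_case_id} failed: {result.error_message}")
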
2 changes: 1 addition & 1 deletion src/google/adk/evaluation/evaluation_generator.py
@@ -137,7 +137,7 @@ async def _process_query(
async def _generate_inferences_from_root_agent(
invocations: list[Invocation],
root_agent: Agent,
- reset_func: Any,
+ reset_func: Optional[Any] = None,
initial_session: Optional[SessionInput] = None,
session_id: Optional[str] = None,
session_service: Optional[BaseSessionService] = None,
185 changes: 185 additions & 0 deletions src/google/adk/evaluation/local_eval_service.py
@@ -0,0 +1,185 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import asyncio
import logging
from typing import AsyncGenerator
from typing import Callable
from typing import Optional
import uuid

from typing_extensions import override

from ..agents import Agent
from ..artifacts.base_artifact_service import BaseArtifactService
from ..artifacts.in_memory_artifact_service import InMemoryArtifactService
from ..errors.not_found_error import NotFoundError
from ..sessions.base_session_service import BaseSessionService
from ..sessions.in_memory_session_service import InMemorySessionService
from ..utils.feature_decorator import working_in_progress
from .agent_creator import AgentCreator
from .base_eval_service import BaseEvalService
from .base_eval_service import EvaluateRequest
from .base_eval_service import InferenceRequest
from .base_eval_service import InferenceResult
from .base_eval_service import InferenceStatus
from .eval_result import EvalCaseResult
from .eval_set import EvalCase
from .eval_set_results_manager import EvalSetResultsManager
from .eval_sets_manager import EvalSetsManager
from .evaluation_generator import EvaluationGenerator
from .metric_evaluator_registry import DEFAULT_METRIC_EVALUATOR_REGISTRY
from .metric_evaluator_registry import MetricEvaluatorRegistry


logger = logging.getLogger('google_adk.' + __name__)

EVAL_SESSION_ID_PREFIX = '___eval___session___'


def _get_session_id() -> str:
return f'{EVAL_SESSION_ID_PREFIX}{str(uuid.uuid4())}'


@working_in_progress("Incomplete feature, don't use yet")
class LocalEvalService(BaseEvalService):
"""An implementation of BaseEvalService, that runs the evals locally."""

def __init__(
self,
agent_creator: AgentCreator,
eval_sets_manager: EvalSetsManager,
metric_evaluator_registry: MetricEvaluatorRegistry = DEFAULT_METRIC_EVALUATOR_REGISTRY,
session_service: BaseSessionService = InMemorySessionService(),
artifact_service: BaseArtifactService = InMemoryArtifactService(),
eval_set_results_manager: Optional[EvalSetResultsManager] = None,
session_id_supplier: Callable[[], str] = _get_session_id,
):
self._agent_creator = agent_creator
self._eval_sets_manager = eval_sets_manager
self._metric_evaluator_registry = metric_evaluator_registry
self._session_service = session_service
self._artifact_service = artifact_service
self._eval_set_results_manager = eval_set_results_manager
self._session_id_supplier = session_id_supplier

@override
async def perform_inference(
self,
inference_request: InferenceRequest,
) -> AsyncGenerator[InferenceResult, None]:
"""Returns InferenceResult obtained from the Agent as and when they are available.

Args:
inference_request: The request for generating inferences.
"""
# Get the eval set from the storage.
eval_set = self._eval_sets_manager.get_eval_set(
app_name=inference_request.app_name,
eval_set_id=inference_request.eval_set_id,
)

if not eval_set:
raise NotFoundError(
f'Eval set with id {inference_request.eval_set_id} not found for app'
f' {inference_request.app_name}'
)

# Select eval cases for which we need to run inferencing. If the inference
# request specified eval cases, then we use only those.
eval_cases = eval_set.eval_cases
if inference_request.eval_case_ids:
eval_cases = [
eval_case
for eval_case in eval_cases
if eval_case.eval_id in inference_request.eval_case_ids
]

root_agent = self._agent_creator.get_agent()

semaphore = asyncio.Semaphore(
value=inference_request.inference_config.max_inference_parallelism
)

async def run_inference(eval_case):
async with semaphore:
return await self._perform_inference_single_eval_item(
app_name=inference_request.app_name,
eval_set_id=inference_request.eval_set_id,
eval_case=eval_case,
root_agent=root_agent,
)

inference_results = [run_inference(eval_case) for eval_case in eval_cases]
for inference_result in asyncio.as_completed(inference_results):
yield await inference_result

@override
async def evaluate(
self,
evaluate_request: EvaluateRequest,
) -> AsyncGenerator[EvalCaseResult, None]:
"""Returns EvalCaseResult for each item as and when they are available.

Args:
evaluate_request: The request to perform metric evaluations on the
inferences.
"""
raise NotImplementedError()

async def _perform_inference_single_eval_item(
self,
app_name: str,
eval_set_id: str,
eval_case: EvalCase,
root_agent: Agent,
) -> InferenceResult:
initial_session = eval_case.session_input
session_id = self._session_id_supplier()
inference_result = InferenceResult(
app_name=app_name,
eval_set_id=eval_set_id,
eval_case_id=eval_case.eval_id,
session_id=session_id,
)

try:
inferences = (
await EvaluationGenerator._generate_inferences_from_root_agent(
invocations=eval_case.conversation,
root_agent=root_agent,
initial_session=initial_session,
session_id=session_id,
session_service=self._session_service,
artifact_service=self._artifact_service,
)
)

inference_result.inferences = inferences
inference_result.status = InferenceStatus.SUCCESS

return inference_result
except Exception as e:
# We intentionally catch the Exception here, as we don't want failures in
# one inference to affect the others.
logger.error(
'Inference failed for eval case `%s` with error %s',
eval_case.eval_id,
e,
)
inference_result.status = InferenceStatus.FAILURE
inference_result.error_message = str(e)
return inference_result
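
Finally, a hedged end-to-end sketch of the call pattern LocalEvalService appears designed for (not part of the diff). Here `eval_sets_manager` stands in for any concrete EvalSetsManager that already contains the referenced eval set, the identifiers are invented, eval_case_ids is assumed to be omittable to run every case, and the @working_in_progress decorator may prevent real use until the feature is enabled.

import asyncio

from google.adk.agents import Agent
from google.adk.evaluation.agent_creator import IdentityAgentCreator
from google.adk.evaluation.base_eval_service import InferenceConfig
from google.adk.evaluation.base_eval_service import InferenceRequest
from google.adk.evaluation.local_eval_service import LocalEvalService


async def run_eval_inferences(eval_sets_manager, root_agent: Agent) -> None:
  # `eval_sets_manager` is assumed to be any EvalSetsManager implementation
  # that already holds the eval set referenced below.
  service = LocalEvalService(
      agent_creator=IdentityAgentCreator(root_agent=root_agent),
      eval_sets_manager=eval_sets_manager,
  )
  request = InferenceRequest(
      app_name='my_app',  # hypothetical identifiers
      eval_set_id='my_eval_set',
      inference_config=InferenceConfig(max_inference_parallelism=2),
  )
  # Results are yielded as each eval case finishes, not in submission order.
  async for result in service.perform_inference(inference_request=request):
    print(result.eval_case_id, result.status)


# Entry point left to the caller, e.g.:
# asyncio.run(run_eval_inferences(my_eval_sets_manager, my_root_agent))
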
78 changes: 78 additions & 0 deletions tests/unittests/evaluation/test_identity_agent_creator.py
@@ -0,0 +1,78 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from google.adk.agents import Agent
from google.adk.evaluation.agent_creator import IdentityAgentCreator
from google.adk.tools.tool_context import ToolContext
from google.genai import types


def _method_1(arg1: int, tool_context: ToolContext) -> tuple[int, ToolContext]:
return (arg1, tool_context)


async def _method_2(arg1: list[int]) -> list[int]:
return arg1


_TEST_SUB_AGENT = Agent(
model="gemini-2.0-flash",
name="test_sub_agent",
description="test sub-agent description",
instruction="test sub-agent instructions",
tools=[
_method_1,
_method_2,
],
)

_TEST_AGENT_1 = Agent(
model="gemini-2.0-flash",
name="test_agent_1",
description="test agent description",
instruction="test agent instructions",
tools=[
_method_1,
_method_2,
],
sub_agents=[_TEST_SUB_AGENT],
generate_content_config=types.GenerateContentConfig(
safety_settings=[
types.SafetySetting( # avoid false alarm about rolling dice.
category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold=types.HarmBlockThreshold.OFF,
),
]
),
)


def test_identity_agent_creator():
creator = IdentityAgentCreator(root_agent=_TEST_AGENT_1)

agent1 = creator.get_agent()
agent2 = creator.get_agent()

assert isinstance(agent1, Agent)
assert isinstance(agent2, Agent)

assert agent1 is not _TEST_AGENT_1 # Ensure it's a copy
assert agent2 is not _TEST_AGENT_1 # Ensure it's a copy
assert agent1 is not agent2 # Ensure different copies are returned

assert agent1.sub_agents[0] is not _TEST_SUB_AGENT # Ensure it's a copy
assert agent2.sub_agents[0] is not _TEST_SUB_AGENT # Ensure it's a copy
assert (
agent1.sub_agents[0] is not agent2.sub_agents[0]
) # Ensure different copies are returned