From 3d4017874ce11657ce2338a1e13809407ddbb3a0 Mon Sep 17 00:00:00 2001 From: Ankur Sharma Date: Wed, 9 Jul 2025 13:32:29 -0700 Subject: [PATCH] feat: Use LocalEvalService to run all evals in cli and web We update both adk web run eval endpoint and adk eval cli to use the LocalService. The old method is marked as deprecated and will be removed in later PRs. PiperOrigin-RevId: 781188524 --- src/google/adk/cli/cli_eval.py | 77 ++++++- src/google/adk/cli/cli_tools_click.py | 186 +++++++++------ src/google/adk/cli/fast_api.py | 84 +++---- src/google/adk/evaluation/eval_metrics.py | 2 + .../adk/evaluation/local_eval_service.py | 1 - .../evaluation/metric_evaluator_registry.py | 18 +- tests/unittests/cli/test_fast_api.py | 45 +--- .../cli/utils/test_cli_tools_click.py | 214 ++++++++---------- 8 files changed, 338 insertions(+), 289 deletions(-) diff --git a/src/google/adk/cli/cli_eval.py b/src/google/adk/cli/cli_eval.py index d122c2150..58a2e7084 100644 --- a/src/google/adk/cli/cli_eval.py +++ b/src/google/adk/cli/cli_eval.py @@ -24,8 +24,16 @@ from typing import Optional import uuid +from typing_extensions import deprecated + from ..agents import Agent from ..artifacts.base_artifact_service import BaseArtifactService +from ..evaluation.base_eval_service import BaseEvalService +from ..evaluation.base_eval_service import EvaluateConfig +from ..evaluation.base_eval_service import EvaluateRequest +from ..evaluation.base_eval_service import InferenceConfig +from ..evaluation.base_eval_service import InferenceRequest +from ..evaluation.base_eval_service import InferenceResult from ..evaluation.constants import MISSING_EVAL_DEPENDENCIES_MESSAGE from ..evaluation.eval_case import EvalCase from ..evaluation.eval_metrics import EvalMetric @@ -107,26 +115,79 @@ def try_get_reset_func(agent_module_file_path: str) -> Any: def parse_and_get_evals_to_run( - eval_set_file_path: tuple[str], + evals_to_run_info: list[str], ) -> dict[str, list[str]]: - """Returns a dictionary of eval sets to evals that should be run.""" + """Returns a dictionary of eval set info to evals that should be run. + + Args: + evals_to_run_info: While the structure is quite simple, a list of string, + each string actually is formatted with the following convention: + :[comma separated eval case ids] + """ eval_set_to_evals = {} - for input_eval_set in eval_set_file_path: + for input_eval_set in evals_to_run_info: evals = [] if ":" not in input_eval_set: - eval_set_file = input_eval_set + # We don't have any eval cases specified. This would be the case where the + # the user wants to run all eval cases in the eval set. + eval_set = input_eval_set else: - eval_set_file = input_eval_set.split(":")[0] + # There are eval cases that we need to parse. The user wants to run + # specific eval cases from the eval set. + eval_set = input_eval_set.split(":")[0] evals = input_eval_set.split(":")[1].split(",") - if eval_set_file not in eval_set_to_evals: - eval_set_to_evals[eval_set_file] = [] + if eval_set not in eval_set_to_evals: + eval_set_to_evals[eval_set] = [] - eval_set_to_evals[eval_set_file].extend(evals) + eval_set_to_evals[eval_set].extend(evals) return eval_set_to_evals +async def _collect_inferences( + inference_requests: list[InferenceRequest], + eval_service: BaseEvalService, +) -> list[InferenceResult]: + """Simple utility methods to collect inferences from an eval service. + + The method is intentionally kept private to prevent general usage. + """ + inference_results = [] + for inference_request in inference_requests: + async for inference_result in eval_service.perform_inference( + inference_request=inference_request + ): + inference_results.append(inference_result) + return inference_results + + +async def _collect_eval_results( + inference_results: list[InferenceResult], + eval_service: BaseEvalService, + eval_metrics: list[EvalMetric], +) -> list[EvalCaseResult]: + """Simple utility methods to collect eval results from an eval service. + + The method is intentionally kept private to prevent general usage. + """ + eval_results = [] + evaluate_request = EvaluateRequest( + inference_results=inference_results, + evaluate_config=EvaluateConfig(eval_metrics=eval_metrics), + ) + async for eval_result in eval_service.evaluate( + evaluate_request=evaluate_request + ): + eval_results.append(eval_result) + + return eval_results + + +@deprecated( + "This method is deprecated and will be removed in fututre release. Please" + " use LocalEvalService to define your custom evals." +) async def run_evals( eval_cases_by_eval_set_id: dict[str, list[EvalCase]], root_agent: Agent, diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py index 95e617781..1d40abdb5 100644 --- a/src/google/adk/cli/cli_tools_click.py +++ b/src/google/adk/cli/cli_tools_click.py @@ -277,6 +277,11 @@ def cli_run( ), ) @click.argument("eval_set_file_path", nargs=-1) +@click.option( + "--eval_set_id", + multiple=True, + help="Optional. The eval set id to use for evals.", +) @click.option("--config_file_path", help="Optional. The path to config file.") @click.option( "--print_detailed_results", @@ -297,6 +302,7 @@ def cli_run( def cli_eval( agent_module_file_path: str, eval_set_file_path: list[str], + eval_set_id: list[str], config_file_path: str, print_detailed_results: bool, eval_storage_uri: Optional[str] = None, @@ -314,12 +320,40 @@ def cli_eval( separated list of eval names and then add that as a suffix to the eval set file name, demarcated by a `:`. - For example, + For example, we have `sample_eval_set_file.json` file that has following the + eval cases: + sample_eval_set_file.json: + |....... eval_1 + |....... eval_2 + |....... eval_3 + |....... eval_4 + |....... eval_5 sample_eval_set_file.json:eval_1,eval_2,eval_3 This will only run eval_1, eval_2 and eval_3 from sample_eval_set_file.json. + EVAL_SET_ID: You can specify one or more eval set ids. + + For each eval set, all evals will be run by default. + + If you want to run only specific evals from a eval set, first create a comma + separated list of eval names and then add that as a suffix to the eval set + file name, demarcated by a `:`. + + For example, we have `sample_eval_set_id` that has following the eval cases: + sample_eval_set_id: + |....... eval_1 + |....... eval_2 + |....... eval_3 + |....... eval_4 + |....... eval_5 + + If we did: + sample_eval_set_id:eval_1,eval_2,eval_3 + + This will only run eval_1, eval_2 and eval_3 from sample_eval_set_id. + CONFIG_FILE_PATH: The path to config file. PRINT_DETAILED_RESULTS: Prints detailed results on the console. @@ -327,18 +361,28 @@ def cli_eval( envs.load_dotenv_for_agent(agent_module_file_path, ".") try: + from ..evaluation.base_eval_service import InferenceConfig + from ..evaluation.base_eval_service import InferenceRequest + from ..evaluation.eval_metrics import EvalMetric + from ..evaluation.eval_result import EvalCaseResult + from ..evaluation.evaluator import EvalStatus + from ..evaluation.in_memory_eval_sets_manager import InMemoryEvalSetsManager + from ..evaluation.local_eval_service import LocalEvalService from ..evaluation.local_eval_sets_manager import load_eval_set_from_file - from .cli_eval import EvalCaseResult - from .cli_eval import EvalMetric - from .cli_eval import EvalStatus + from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager + from .cli_eval import _collect_eval_results + from .cli_eval import _collect_inferences from .cli_eval import get_evaluation_criteria_or_default from .cli_eval import get_root_agent from .cli_eval import parse_and_get_evals_to_run - from .cli_eval import run_evals - from .cli_eval import try_get_reset_func except ModuleNotFoundError: raise click.ClickException(MISSING_EVAL_DEPENDENCIES_MESSAGE) + if eval_set_file_path and eval_set_id: + raise click.ClickException( + "Only one of eval_set_file_path or eval_set_id can be specified." + ) + evaluation_criteria = get_evaluation_criteria_or_default(config_file_path) eval_metrics = [] for metric_name, threshold in evaluation_criteria.items(): @@ -349,80 +393,87 @@ def cli_eval( print(f"Using evaluation criteria: {evaluation_criteria}") root_agent = get_root_agent(agent_module_file_path) - reset_func = try_get_reset_func(agent_module_file_path) - - gcs_eval_sets_manager = None + app_name = os.path.basename(agent_module_file_path) + agents_dir = os.path.dirname(agent_module_file_path) + eval_sets_manager = None eval_set_results_manager = None + if eval_storage_uri: gcs_eval_managers = evals.create_gcs_eval_managers_from_uri( eval_storage_uri ) - gcs_eval_sets_manager = gcs_eval_managers.eval_sets_manager + eval_sets_manager = gcs_eval_managers.eval_sets_manager eval_set_results_manager = gcs_eval_managers.eval_set_results_manager else: - eval_set_results_manager = LocalEvalSetResultsManager( - agents_dir=os.path.dirname(agent_module_file_path) - ) - eval_set_file_path_to_evals = parse_and_get_evals_to_run(eval_set_file_path) - eval_set_id_to_eval_cases = {} - - # Read the eval_set files and get the cases. - for eval_set_file_path, eval_case_ids in eval_set_file_path_to_evals.items(): - if gcs_eval_sets_manager: - eval_set = gcs_eval_sets_manager._load_eval_set_from_blob( - eval_set_file_path + eval_set_results_manager = LocalEvalSetResultsManager(agents_dir=agents_dir) + + inference_requests = [] + if eval_set_file_path: + eval_sets_manager = InMemoryEvalSetsManager() + eval_set_file_path_to_evals = parse_and_get_evals_to_run(eval_set_file_path) + + # Read the eval_set files and get the cases. + for ( + eval_set_file_path, + eval_case_ids, + ) in eval_set_file_path_to_evals.items(): + eval_set = load_eval_set_from_file(eval_set_file_path, eval_set_file_path) + + eval_sets_manager.create_eval_set( + app_name=app_name, eval_set_id=eval_set.eval_set_id ) - if not eval_set: - raise click.ClickException( - f"Eval set {eval_set_file_path} not found in GCS." + for eval_case in eval_set.eval_cases: + eval_sets_manager.add_eval_case( + app_name=app_name, + eval_set_id=eval_set.eval_set_id, + eval_case=eval_case, ) - else: - eval_set = load_eval_set_from_file(eval_set_file_path, eval_set_file_path) - eval_cases = eval_set.eval_cases - - if eval_case_ids: - # There are eval_ids that we should select. - eval_cases = [ - e for e in eval_set.eval_cases if e.eval_id in eval_case_ids - ] - - eval_set_id_to_eval_cases[eval_set.eval_set_id] = eval_cases - - async def _collect_eval_results() -> list[EvalCaseResult]: - session_service = InMemorySessionService() - eval_case_results = [] - async for eval_case_result in run_evals( - eval_set_id_to_eval_cases, - root_agent, - reset_func, - eval_metrics, - session_service=session_service, - ): - eval_case_result.session_details = await session_service.get_session( - app_name=os.path.basename(agent_module_file_path), - user_id=eval_case_result.user_id, - session_id=eval_case_result.session_id, + inference_requests.append( + InferenceRequest( + app_name=app_name, + eval_set_id=eval_set.eval_set_id, + eval_case_ids=eval_case_ids, + inference_config=InferenceConfig(), + ) + ) + elif eval_set_id: + eval_sets_manager = ( + eval_sets_manager + if eval_storage_uri + else LocalEvalSetsManager(agents_dir=agents_dir) + ) + eval_set_id_to_eval_cases = parse_and_get_evals_to_run(eval_set_id) + for eval_set_id_key, eval_case_ids in eval_set_id_to_eval_cases.items(): + inference_requests.append( + InferenceRequest( + app_name=app_name, + eval_set_id=eval_set_id_key, + eval_case_ids=eval_case_ids, + inference_config=InferenceConfig(), + ) ) - eval_case_results.append(eval_case_result) - return eval_case_results try: - eval_results = asyncio.run(_collect_eval_results()) - except ModuleNotFoundError: - raise click.ClickException(MISSING_EVAL_DEPENDENCIES_MESSAGE) + eval_service = LocalEvalService( + root_agent=root_agent, + eval_sets_manager=eval_sets_manager, + eval_set_results_manager=eval_set_results_manager, + ) - # Write eval set results. - eval_set_id_to_eval_results = collections.defaultdict(list) - for eval_case_result in eval_results: - eval_set_id = eval_case_result.eval_set_id - eval_set_id_to_eval_results[eval_set_id].append(eval_case_result) - - for eval_set_id, eval_case_results in eval_set_id_to_eval_results.items(): - eval_set_results_manager.save_eval_set_result( - app_name=os.path.basename(agent_module_file_path), - eval_set_id=eval_set_id, - eval_case_results=eval_case_results, + inference_results = asyncio.run( + _collect_inferences( + inference_requests=inference_requests, eval_service=eval_service + ) ) + eval_results = asyncio.run( + _collect_eval_results( + inference_results=inference_results, + eval_service=eval_service, + eval_metrics=eval_metrics, + ) + ) + except ModuleNotFoundError: + raise click.ClickException(MISSING_EVAL_DEPENDENCIES_MESSAGE) print("*********************************************************************") eval_run_summary = {} @@ -1021,7 +1072,8 @@ def cli_deploy_agent_engine( Example: adk deploy agent_engine --project=[project] --region=[region] - --staging_bucket=[staging_bucket] --display_name=[app_name] path/to/my_agent + --staging_bucket=[staging_bucket] --display_name=[app_name] + path/to/my_agent """ try: cli_deploy.to_agent_engine( diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index 8d6ca2e46..78a2c325f 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -61,6 +61,9 @@ from ..artifacts.in_memory_artifact_service import InMemoryArtifactService from ..auth.credential_service.in_memory_credential_service import InMemoryCredentialService from ..errors.not_found_error import NotFoundError +from ..evaluation.base_eval_service import InferenceConfig +from ..evaluation.base_eval_service import InferenceRequest +from ..evaluation.constants import MISSING_EVAL_DEPENDENCIES_MESSAGE from ..evaluation.eval_case import EvalCase from ..evaluation.eval_case import SessionInput from ..evaluation.eval_metrics import EvalMetric @@ -197,6 +200,7 @@ class RunEvalResult(common.BaseModel): final_eval_status: EvalStatus eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]] = Field( deprecated=True, + default=[], description=( "This field is deprecated, use overall_eval_metric_results instead." ), @@ -643,7 +647,9 @@ async def run_eval( app_name: str, eval_set_id: str, req: RunEvalRequest ) -> list[RunEvalResult]: """Runs an eval given the details in the eval request.""" - from .cli_eval import run_evals + from ..evaluation.local_eval_service import LocalEvalService + from .cli_eval import _collect_eval_results + from .cli_eval import _collect_inferences # Create a mapping from eval set file to all the evals that needed to be # run. @@ -654,52 +660,52 @@ async def run_eval( status_code=400, detail=f"Eval set `{eval_set_id}` not found." ) - if req.eval_ids: - eval_cases = [e for e in eval_set.eval_cases if e.eval_id in req.eval_ids] - eval_set_to_evals = {eval_set_id: eval_cases} - else: - logger.info("Eval ids to run list is empty. We will run all eval cases.") - eval_set_to_evals = {eval_set_id: eval_set.eval_cases} - root_agent = agent_loader.load_agent(app_name) - run_eval_results = [] + eval_case_results = [] try: - async for eval_case_result in run_evals( - eval_set_to_evals, - root_agent, - getattr(root_agent, "reset_data", None), - req.eval_metrics, + eval_service = LocalEvalService( + root_agent=root_agent, + eval_sets_manager=eval_sets_manager, + eval_set_results_manager=eval_set_results_manager, session_service=session_service, artifact_service=artifact_service, - ): - run_eval_results.append( - RunEvalResult( - app_name=app_name, - eval_set_file=eval_case_result.eval_set_file, - eval_set_id=eval_set_id, - eval_id=eval_case_result.eval_id, - final_eval_status=eval_case_result.final_eval_status, - eval_metric_results=eval_case_result.eval_metric_results, - overall_eval_metric_results=eval_case_result.overall_eval_metric_results, - eval_metric_result_per_invocation=eval_case_result.eval_metric_result_per_invocation, - user_id=eval_case_result.user_id, - session_id=eval_case_result.session_id, - ) - ) - eval_case_result.session_details = await session_service.get_session( - app_name=app_name, - user_id=eval_case_result.user_id, - session_id=eval_case_result.session_id, - ) - eval_case_results.append(eval_case_result) + ) + inference_request = InferenceRequest( + app_name=app_name, + eval_set_id=eval_set.eval_set_id, + eval_case_ids=req.eval_ids, + inference_config=InferenceConfig(), + ) + inference_results = await _collect_inferences( + inference_requests=[inference_request], eval_service=eval_service + ) + + eval_case_results = await _collect_eval_results( + inference_results=inference_results, + eval_service=eval_service, + eval_metrics=req.eval_metrics, + ) except ModuleNotFoundError as e: logger.exception("%s", e) - raise HTTPException(status_code=400, detail=str(e)) from e + raise HTTPException( + status_code=400, detail=MISSING_EVAL_DEPENDENCIES_MESSAGE + ) from e - eval_set_results_manager.save_eval_set_result( - app_name, eval_set_id, eval_case_results - ) + run_eval_results = [] + for eval_case_result in eval_case_results: + run_eval_results.append( + RunEvalResult( + eval_set_file=eval_case_result.eval_set_file, + eval_set_id=eval_set_id, + eval_id=eval_case_result.eval_id, + final_eval_status=eval_case_result.final_eval_status, + overall_eval_metric_results=eval_case_result.overall_eval_metric_results, + eval_metric_result_per_invocation=eval_case_result.eval_metric_result_per_invocation, + user_id=eval_case_result.user_id, + session_id=eval_case_result.session_id, + ) + ) return run_eval_results diff --git a/src/google/adk/evaluation/eval_metrics.py b/src/google/adk/evaluation/eval_metrics.py index 8cf235427..bc7691773 100644 --- a/src/google/adk/evaluation/eval_metrics.py +++ b/src/google/adk/evaluation/eval_metrics.py @@ -36,6 +36,8 @@ class PrebuiltMetrics(Enum): RESPONSE_MATCH_SCORE = "response_match_score" + SAFETY_V1 = "safety_v1" + MetricName: TypeAlias = Union[str, PrebuiltMetrics] diff --git a/src/google/adk/evaluation/local_eval_service.py b/src/google/adk/evaluation/local_eval_service.py index 240b56c38..745249a91 100644 --- a/src/google/adk/evaluation/local_eval_service.py +++ b/src/google/adk/evaluation/local_eval_service.py @@ -60,7 +60,6 @@ def _get_session_id() -> str: return f'{EVAL_SESSION_ID_PREFIX}{str(uuid.uuid4())}' -@working_in_progress("Incomplete feature, don't use yet") class LocalEvalService(BaseEvalService): """An implementation of BaseEvalService, that runs the evals locally.""" diff --git a/src/google/adk/evaluation/metric_evaluator_registry.py b/src/google/adk/evaluation/metric_evaluator_registry.py index 99a700896..7351fa71d 100644 --- a/src/google/adk/evaluation/metric_evaluator_registry.py +++ b/src/google/adk/evaluation/metric_evaluator_registry.py @@ -22,6 +22,7 @@ from .eval_metrics import PrebuiltMetrics from .evaluator import Evaluator from .response_evaluator import ResponseEvaluator +from .safety_evaluator import SafetyEvaluatorV1 from .trajectory_evaluator import TrajectoryEvaluator logger = logging.getLogger("google_adk." + __name__) @@ -71,16 +72,21 @@ def _get_default_metric_evaluator_registry() -> MetricEvaluatorRegistry: metric_evaluator_registry = MetricEvaluatorRegistry() metric_evaluator_registry.register_evaluator( - metric_name=PrebuiltMetrics.TOOL_TRAJECTORY_AVG_SCORE, - evaluator=type(TrajectoryEvaluator), + metric_name=PrebuiltMetrics.TOOL_TRAJECTORY_AVG_SCORE.value, + evaluator=TrajectoryEvaluator, ) metric_evaluator_registry.register_evaluator( - metric_name=PrebuiltMetrics.RESPONSE_EVALUATION_SCORE, - evaluator=type(ResponseEvaluator), + metric_name=PrebuiltMetrics.RESPONSE_EVALUATION_SCORE.value, + evaluator=ResponseEvaluator, ) metric_evaluator_registry.register_evaluator( - metric_name=PrebuiltMetrics.RESPONSE_MATCH_SCORE, - evaluator=type(ResponseEvaluator), + metric_name=PrebuiltMetrics.RESPONSE_MATCH_SCORE.value, + evaluator=ResponseEvaluator, + ) + + metric_evaluator_registry.register_evaluator( + metric_name=PrebuiltMetrics.SAFETY_V1.value, + evaluator=SafetyEvaluatorV1, ) return metric_evaluator_registry diff --git a/tests/unittests/cli/test_fast_api.py b/tests/unittests/cli/test_fast_api.py index d4f9382e3..8475b7e06 100755 --- a/tests/unittests/cli/test_fast_api.py +++ b/tests/unittests/cli/test_fast_api.py @@ -32,6 +32,7 @@ from google.adk.evaluation.eval_case import Invocation from google.adk.evaluation.eval_result import EvalSetResult from google.adk.evaluation.eval_set import EvalSet +from google.adk.evaluation.in_memory_eval_sets_manager import InMemoryEvalSetsManager from google.adk.events import Event from google.adk.runners import Runner from google.adk.sessions.base_session_service import ListSessionsResponse @@ -330,49 +331,7 @@ def mock_memory_service(): @pytest.fixture def mock_eval_sets_manager(): """Create a mock eval sets manager.""" - - # Storage for eval sets. - eval_sets = {} - - class MockEvalSetsManager: - """Mock eval sets manager.""" - - def create_eval_set(self, app_name, eval_set_id): - """Create an eval set.""" - if app_name not in eval_sets: - eval_sets[app_name] = {} - - if eval_set_id in eval_sets[app_name]: - raise ValueError(f"Eval set {eval_set_id} already exists.") - - eval_sets[app_name][eval_set_id] = EvalSet( - eval_set_id=eval_set_id, eval_cases=[] - ) - return eval_set_id - - def get_eval_set(self, app_name, eval_set_id): - """Get an eval set.""" - if app_name not in eval_sets: - raise ValueError(f"App {app_name} not found.") - if eval_set_id not in eval_sets[app_name]: - raise ValueError(f"Eval set {eval_set_id} not found in app {app_name}.") - return eval_sets[app_name][eval_set_id] - - def list_eval_sets(self, app_name): - """List eval sets.""" - if app_name not in eval_sets: - raise ValueError(f"App {app_name} not found.") - return list(eval_sets[app_name].keys()) - - def add_eval_case(self, app_name, eval_set_id, eval_case): - """Add an eval case to an eval set.""" - if app_name not in eval_sets: - raise ValueError(f"App {app_name} not found.") - if eval_set_id not in eval_sets[app_name]: - raise ValueError(f"Eval set {eval_set_id} not found in app {app_name}.") - eval_sets[app_name][eval_set_id].eval_cases.append(eval_case) - - return MockEvalSetsManager() + return InMemoryEvalSetsManager() @pytest.fixture diff --git a/tests/unittests/cli/utils/test_cli_tools_click.py b/tests/unittests/cli/utils/test_cli_tools_click.py index da45442a4..c0d26c2ce 100644 --- a/tests/unittests/cli/utils/test_cli_tools_click.py +++ b/tests/unittests/cli/utils/test_cli_tools_click.py @@ -23,18 +23,46 @@ from typing import Any from typing import Dict from typing import List -from typing import Optional from typing import Tuple +from unittest import mock import click from click.testing import CliRunner +from google.adk.agents.base_agent import BaseAgent from google.adk.cli import cli_tools_click -from google.adk.evaluation import local_eval_set_results_manager -from google.adk.sessions import Session +from google.adk.evaluation.eval_case import EvalCase +from google.adk.evaluation.eval_set import EvalSet +from google.adk.evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager +from google.adk.evaluation.local_eval_sets_manager import LocalEvalSetsManager from pydantic import BaseModel import pytest +class DummyAgent(BaseAgent): + + def __init__(self, name): + super().__init__(name=name) + self.sub_agents = [] + + +root_agent = DummyAgent(name="dummy_agent") + + +@pytest.fixture +def mock_load_eval_set_from_file(): + with mock.patch( + "google.adk.evaluation.local_eval_sets_manager.load_eval_set_from_file" + ) as mock_func: + yield mock_func + + +@pytest.fixture +def mock_get_root_agent(): + with mock.patch("google.adk.cli.cli_eval.get_root_agent") as mock_func: + mock_func.return_value = root_agent + yield mock_func + + # Helpers class _Recorder(BaseModel): """Callable that records every invocation.""" @@ -237,137 +265,73 @@ def test_cli_api_server_invokes_uvicorn( assert _patch_uvicorn.calls, "uvicorn.Server.run must be called" -def test_cli_eval_success_path( - tmp_path: Path, monkeypatch: pytest.MonkeyPatch -) -> None: - """Test the success path of `adk eval` by fully executing it with a stub module, up to summary generation.""" - import asyncio - import sys - import types - - # stub cli_eval module - stub = types.ModuleType("google.adk.cli.cli_eval") - eval_sets_manager_stub = types.ModuleType( - "google.adk.evaluation.local_eval_sets_manager" - ) - - class _EvalMetric: - - def __init__(self, metric_name: str, threshold: float) -> None: - ... - - class _EvalCaseResult(BaseModel): - eval_set_id: str - eval_id: str - final_eval_status: Any - user_id: str - session_id: str - session_details: Optional[Session] = None - eval_metric_results: list = {} - overall_eval_metric_results: list = {} - eval_metric_result_per_invocation: list = {} - - class EvalCase(BaseModel): - eval_id: str +def test_cli_eval_with_eval_set_file_path( + mock_load_eval_set_from_file, + mock_get_root_agent, + tmp_path, +): + agent_path = tmp_path / "my_agent" + agent_path.mkdir() + (agent_path / "__init__.py").touch() - class EvalSet(BaseModel): - eval_set_id: str - eval_cases: list[EvalCase] + eval_set_file = tmp_path / "my_evals.json" + eval_set_file.write_text("{}") - def mock_save_eval_set_result(cls, *args, **kwargs): - return None - - monkeypatch.setattr( - local_eval_set_results_manager.LocalEvalSetResultsManager, - "save_eval_set_result", - mock_save_eval_set_result, + mock_load_eval_set_from_file.return_value = EvalSet( + eval_set_id="my_evals", + eval_cases=[EvalCase(eval_id="case1", conversation=[])], ) - # minimal enum-like namespace - _EvalStatus = types.SimpleNamespace(PASSED="PASSED", FAILED="FAILED") - - # helper funcs - stub.EvalMetric = _EvalMetric - stub.EvalCaseResult = _EvalCaseResult - stub.EvalStatus = _EvalStatus - stub.MISSING_EVAL_DEPENDENCIES_MESSAGE = "stub msg" - - stub.get_evaluation_criteria_or_default = lambda _p: {"foo": 1.0} - stub.get_root_agent = lambda _p: object() - stub.try_get_reset_func = lambda _p: None - stub.parse_and_get_evals_to_run = lambda _paths: {"set1.json": ["e1", "e2"]} - eval_sets_manager_stub.load_eval_set_from_file = lambda x, y: EvalSet( - eval_set_id="test_eval_set_id", - eval_cases=[EvalCase(eval_id="e1"), EvalCase(eval_id="e2")], + result = CliRunner().invoke( + cli_tools_click.cli_eval, + [str(agent_path), str(eval_set_file)], ) - # Create an async generator function for run_evals - async def mock_run_evals(*_a, **_k): - yield _EvalCaseResult( - eval_set_id="set1.json", - eval_id="e1", - final_eval_status=_EvalStatus.PASSED, - user_id="user", - session_id="session1", - overall_eval_metric_results=[{ - "metricName": "some_metric", - "threshold": 0.0, - "score": 1.0, - "evalStatus": _EvalStatus.PASSED, - }], - ) - yield _EvalCaseResult( - eval_set_id="set1.json", - eval_id="e2", - final_eval_status=_EvalStatus.FAILED, - user_id="user", - session_id="session2", - overall_eval_metric_results=[{ - "metricName": "some_metric", - "threshold": 0.0, - "score": 0.0, - "evalStatus": _EvalStatus.FAILED, - }], - ) - - stub.run_evals = mock_run_evals - - # Replace asyncio.run with a function that properly handles coroutines - def mock_asyncio_run(coro): - # Create a new event loop - loop = asyncio.new_event_loop() - try: - return loop.run_until_complete(coro) - finally: - loop.close() - - monkeypatch.setattr(cli_tools_click.asyncio, "run", mock_asyncio_run) - - # inject stub - monkeypatch.setitem(sys.modules, "google.adk.cli.cli_eval", stub) - monkeypatch.setitem( - sys.modules, - "google.adk.evaluation.local_eval_sets_manager", - eval_sets_manager_stub, + assert result.exit_code == 0 + # Assert that we wrote eval set results + eval_set_results_manager = LocalEvalSetResultsManager( + agents_dir=str(tmp_path) ) - - # create dummy agent directory - agent_dir = tmp_path / "agent5" - agent_dir.mkdir() - (agent_dir / "__init__.py").touch() - - # inject monkeypatch - monkeypatch.setattr( - cli_tools_click.envs, "load_dotenv_for_agent", lambda *a, **k: None + eval_set_results = eval_set_results_manager.list_eval_set_results( + app_name="my_agent" + ) + assert len(eval_set_results) == 1 + + +def test_cli_eval_with_eval_set_id( + mock_get_root_agent, + tmp_path, +): + app_name = "test_app" + eval_set_id = "test_eval_set_id" + agent_path = tmp_path / app_name + agent_path.mkdir() + (agent_path / "__init__.py").touch() + + eval_sets_manager = LocalEvalSetsManager(agents_dir=str(tmp_path)) + eval_sets_manager.create_eval_set(app_name=app_name, eval_set_id=eval_set_id) + eval_sets_manager.add_eval_case( + app_name=app_name, + eval_set_id=eval_set_id, + eval_case=EvalCase(eval_id="case1", conversation=[]), + ) + eval_sets_manager.add_eval_case( + app_name=app_name, + eval_set_id=eval_set_id, + eval_case=EvalCase(eval_id="case2", conversation=[]), ) - runner = CliRunner() - result = runner.invoke( - cli_tools_click.main, - ["eval", str(agent_dir), str(tmp_path / "dummy_eval.json")], + result = CliRunner().invoke( + cli_tools_click.cli_eval, + [str(agent_path), "--eval_set_id", "test_eval_set_id:case1,case2"], ) assert result.exit_code == 0 - assert "Eval Run Summary" in result.output - assert "Tests passed: 1" in result.output - assert "Tests failed: 1" in result.output + # Assert that we wrote eval set results + eval_set_results_manager = LocalEvalSetResultsManager( + agents_dir=str(tmp_path) + ) + eval_set_results = eval_set_results_manager.list_eval_set_results( + app_name=app_name + ) + assert len(eval_set_results) == 2