diff --git a/src/google/adk/cli/cli_eval.py b/src/google/adk/cli/cli_eval.py
index d122c2150..58a2e7084 100644
--- a/src/google/adk/cli/cli_eval.py
+++ b/src/google/adk/cli/cli_eval.py
@@ -24,8 +24,16 @@
 from typing import Optional
 import uuid

+from typing_extensions import deprecated
+
 from ..agents import Agent
 from ..artifacts.base_artifact_service import BaseArtifactService
+from ..evaluation.base_eval_service import BaseEvalService
+from ..evaluation.base_eval_service import EvaluateConfig
+from ..evaluation.base_eval_service import EvaluateRequest
+from ..evaluation.base_eval_service import InferenceConfig
+from ..evaluation.base_eval_service import InferenceRequest
+from ..evaluation.base_eval_service import InferenceResult
 from ..evaluation.constants import MISSING_EVAL_DEPENDENCIES_MESSAGE
 from ..evaluation.eval_case import EvalCase
 from ..evaluation.eval_metrics import EvalMetric
@@ -107,26 +115,79 @@ def try_get_reset_func(agent_module_file_path: str) -> Any:


 def parse_and_get_evals_to_run(
-    eval_set_file_path: tuple[str],
+    evals_to_run_info: list[str],
 ) -> dict[str, list[str]]:
-  """Returns a dictionary of eval sets to evals that should be run."""
+  """Returns a dictionary of eval set info to evals that should be run.
+
+  Args:
+    evals_to_run_info: While the structure is quite simple, a list of strings,
+      each string is actually formatted with the following convention:
+      <eval set file path or eval set id>:[comma separated eval case ids]
+  """
   eval_set_to_evals = {}
-  for input_eval_set in eval_set_file_path:
+  for input_eval_set in evals_to_run_info:
     evals = []
     if ":" not in input_eval_set:
-      eval_set_file = input_eval_set
+      # We don't have any eval cases specified. This is the case where the
+      # user wants to run all eval cases in the eval set.
+      eval_set = input_eval_set
     else:
-      eval_set_file = input_eval_set.split(":")[0]
+      # There are eval cases that we need to parse. The user wants to run
+      # specific eval cases from the eval set.
+      eval_set = input_eval_set.split(":")[0]
       evals = input_eval_set.split(":")[1].split(",")

-    if eval_set_file not in eval_set_to_evals:
-      eval_set_to_evals[eval_set_file] = []
+    if eval_set not in eval_set_to_evals:
+      eval_set_to_evals[eval_set] = []

-    eval_set_to_evals[eval_set_file].extend(evals)
+    eval_set_to_evals[eval_set].extend(evals)

   return eval_set_to_evals


+async def _collect_inferences(
+    inference_requests: list[InferenceRequest],
+    eval_service: BaseEvalService,
+) -> list[InferenceResult]:
+  """Simple utility method to collect inferences from an eval service.
+
+  The method is intentionally kept private to prevent general usage.
+  """
+  inference_results = []
+  for inference_request in inference_requests:
+    async for inference_result in eval_service.perform_inference(
+        inference_request=inference_request
+    ):
+      inference_results.append(inference_result)
+  return inference_results
+
+
+async def _collect_eval_results(
+    inference_results: list[InferenceResult],
+    eval_service: BaseEvalService,
+    eval_metrics: list[EvalMetric],
+) -> list[EvalCaseResult]:
+  """Simple utility method to collect eval results from an eval service.
+
+  The method is intentionally kept private to prevent general usage.
+ """ + eval_results = [] + evaluate_request = EvaluateRequest( + inference_results=inference_results, + evaluate_config=EvaluateConfig(eval_metrics=eval_metrics), + ) + async for eval_result in eval_service.evaluate( + evaluate_request=evaluate_request + ): + eval_results.append(eval_result) + + return eval_results + + +@deprecated( + "This method is deprecated and will be removed in fututre release. Please" + " use LocalEvalService to define your custom evals." +) async def run_evals( eval_cases_by_eval_set_id: dict[str, list[EvalCase]], root_agent: Agent, diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py index 95e617781..1d40abdb5 100644 --- a/src/google/adk/cli/cli_tools_click.py +++ b/src/google/adk/cli/cli_tools_click.py @@ -277,6 +277,11 @@ def cli_run( ), ) @click.argument("eval_set_file_path", nargs=-1) +@click.option( + "--eval_set_id", + multiple=True, + help="Optional. The eval set id to use for evals.", +) @click.option("--config_file_path", help="Optional. The path to config file.") @click.option( "--print_detailed_results", @@ -297,6 +302,7 @@ def cli_run( def cli_eval( agent_module_file_path: str, eval_set_file_path: list[str], + eval_set_id: list[str], config_file_path: str, print_detailed_results: bool, eval_storage_uri: Optional[str] = None, @@ -314,12 +320,40 @@ def cli_eval( separated list of eval names and then add that as a suffix to the eval set file name, demarcated by a `:`. - For example, + For example, we have `sample_eval_set_file.json` file that has following the + eval cases: + sample_eval_set_file.json: + |....... eval_1 + |....... eval_2 + |....... eval_3 + |....... eval_4 + |....... eval_5 sample_eval_set_file.json:eval_1,eval_2,eval_3 This will only run eval_1, eval_2 and eval_3 from sample_eval_set_file.json. + EVAL_SET_ID: You can specify one or more eval set ids. + + For each eval set, all evals will be run by default. + + If you want to run only specific evals from a eval set, first create a comma + separated list of eval names and then add that as a suffix to the eval set + file name, demarcated by a `:`. + + For example, we have `sample_eval_set_id` that has following the eval cases: + sample_eval_set_id: + |....... eval_1 + |....... eval_2 + |....... eval_3 + |....... eval_4 + |....... eval_5 + + If we did: + sample_eval_set_id:eval_1,eval_2,eval_3 + + This will only run eval_1, eval_2 and eval_3 from sample_eval_set_id. + CONFIG_FILE_PATH: The path to config file. PRINT_DETAILED_RESULTS: Prints detailed results on the console. 
@@ -327,18 +361,28 @@ def cli_eval(
   envs.load_dotenv_for_agent(agent_module_file_path, ".")

   try:
+    from ..evaluation.base_eval_service import InferenceConfig
+    from ..evaluation.base_eval_service import InferenceRequest
+    from ..evaluation.eval_metrics import EvalMetric
+    from ..evaluation.eval_result import EvalCaseResult
+    from ..evaluation.evaluator import EvalStatus
+    from ..evaluation.in_memory_eval_sets_manager import InMemoryEvalSetsManager
+    from ..evaluation.local_eval_service import LocalEvalService
     from ..evaluation.local_eval_sets_manager import load_eval_set_from_file
-    from .cli_eval import EvalCaseResult
-    from .cli_eval import EvalMetric
-    from .cli_eval import EvalStatus
+    from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager
+    from .cli_eval import _collect_eval_results
+    from .cli_eval import _collect_inferences
     from .cli_eval import get_evaluation_criteria_or_default
     from .cli_eval import get_root_agent
     from .cli_eval import parse_and_get_evals_to_run
-    from .cli_eval import run_evals
-    from .cli_eval import try_get_reset_func
   except ModuleNotFoundError:
     raise click.ClickException(MISSING_EVAL_DEPENDENCIES_MESSAGE)

+  if eval_set_file_path and eval_set_id:
+    raise click.ClickException(
+        "Only one of eval_set_file_path or eval_set_id can be specified."
+    )
+
   evaluation_criteria = get_evaluation_criteria_or_default(config_file_path)
   eval_metrics = []
   for metric_name, threshold in evaluation_criteria.items():
@@ -349,80 +393,87 @@ def cli_eval(
   print(f"Using evaluation criteria: {evaluation_criteria}")

   root_agent = get_root_agent(agent_module_file_path)
-  reset_func = try_get_reset_func(agent_module_file_path)
-
-  gcs_eval_sets_manager = None
+  app_name = os.path.basename(agent_module_file_path)
+  agents_dir = os.path.dirname(agent_module_file_path)
+  eval_sets_manager = None
   eval_set_results_manager = None
+
   if eval_storage_uri:
     gcs_eval_managers = evals.create_gcs_eval_managers_from_uri(
         eval_storage_uri
     )
-    gcs_eval_sets_manager = gcs_eval_managers.eval_sets_manager
+    eval_sets_manager = gcs_eval_managers.eval_sets_manager
     eval_set_results_manager = gcs_eval_managers.eval_set_results_manager
   else:
-    eval_set_results_manager = LocalEvalSetResultsManager(
-        agents_dir=os.path.dirname(agent_module_file_path)
-    )
-  eval_set_file_path_to_evals = parse_and_get_evals_to_run(eval_set_file_path)
-  eval_set_id_to_eval_cases = {}
-
-  # Read the eval_set files and get the cases.
-  for eval_set_file_path, eval_case_ids in eval_set_file_path_to_evals.items():
-    if gcs_eval_sets_manager:
-      eval_set = gcs_eval_sets_manager._load_eval_set_from_blob(
-          eval_set_file_path
+    eval_set_results_manager = LocalEvalSetResultsManager(agents_dir=agents_dir)
+
+  inference_requests = []
+  if eval_set_file_path:
+    eval_sets_manager = InMemoryEvalSetsManager()
+    eval_set_file_path_to_evals = parse_and_get_evals_to_run(eval_set_file_path)
+
+    # Read the eval_set files and get the cases.
+    for (
+        eval_set_file_path,
+        eval_case_ids,
+    ) in eval_set_file_path_to_evals.items():
+      eval_set = load_eval_set_from_file(eval_set_file_path, eval_set_file_path)
+
+      eval_sets_manager.create_eval_set(
+          app_name=app_name, eval_set_id=eval_set.eval_set_id
       )
-      if not eval_set:
-        raise click.ClickException(
-            f"Eval set {eval_set_file_path} not found in GCS."
+      for eval_case in eval_set.eval_cases:
+        eval_sets_manager.add_eval_case(
+            app_name=app_name,
+            eval_set_id=eval_set.eval_set_id,
+            eval_case=eval_case,
         )
-    else:
-      eval_set = load_eval_set_from_file(eval_set_file_path, eval_set_file_path)
-    eval_cases = eval_set.eval_cases
-
-    if eval_case_ids:
-      # There are eval_ids that we should select.
-      eval_cases = [
-          e for e in eval_set.eval_cases if e.eval_id in eval_case_ids
-      ]
-
-    eval_set_id_to_eval_cases[eval_set.eval_set_id] = eval_cases
-
-  async def _collect_eval_results() -> list[EvalCaseResult]:
-    session_service = InMemorySessionService()
-    eval_case_results = []
-    async for eval_case_result in run_evals(
-        eval_set_id_to_eval_cases,
-        root_agent,
-        reset_func,
-        eval_metrics,
-        session_service=session_service,
-    ):
-      eval_case_result.session_details = await session_service.get_session(
-          app_name=os.path.basename(agent_module_file_path),
-          user_id=eval_case_result.user_id,
-          session_id=eval_case_result.session_id,
+      inference_requests.append(
+          InferenceRequest(
+              app_name=app_name,
+              eval_set_id=eval_set.eval_set_id,
+              eval_case_ids=eval_case_ids,
+              inference_config=InferenceConfig(),
+          )
+      )
+  elif eval_set_id:
+    eval_sets_manager = (
+        eval_sets_manager
+        if eval_storage_uri
+        else LocalEvalSetsManager(agents_dir=agents_dir)
+    )
+    eval_set_id_to_eval_cases = parse_and_get_evals_to_run(eval_set_id)
+    for eval_set_id_key, eval_case_ids in eval_set_id_to_eval_cases.items():
+      inference_requests.append(
+          InferenceRequest(
+              app_name=app_name,
+              eval_set_id=eval_set_id_key,
+              eval_case_ids=eval_case_ids,
+              inference_config=InferenceConfig(),
+          )
       )
-      eval_case_results.append(eval_case_result)
-    return eval_case_results

   try:
-    eval_results = asyncio.run(_collect_eval_results())
-  except ModuleNotFoundError:
-    raise click.ClickException(MISSING_EVAL_DEPENDENCIES_MESSAGE)
+    eval_service = LocalEvalService(
+        root_agent=root_agent,
+        eval_sets_manager=eval_sets_manager,
+        eval_set_results_manager=eval_set_results_manager,
+    )
-  # Write eval set results.
-  eval_set_id_to_eval_results = collections.defaultdict(list)
-  for eval_case_result in eval_results:
-    eval_set_id = eval_case_result.eval_set_id
-    eval_set_id_to_eval_results[eval_set_id].append(eval_case_result)
-
-  for eval_set_id, eval_case_results in eval_set_id_to_eval_results.items():
-    eval_set_results_manager.save_eval_set_result(
-        app_name=os.path.basename(agent_module_file_path),
-        eval_set_id=eval_set_id,
-        eval_case_results=eval_case_results,
+    inference_results = asyncio.run(
+        _collect_inferences(
+            inference_requests=inference_requests, eval_service=eval_service
+        )
     )
+    eval_results = asyncio.run(
+        _collect_eval_results(
+            inference_results=inference_results,
+            eval_service=eval_service,
+            eval_metrics=eval_metrics,
+        )
+    )
+  except ModuleNotFoundError:
+    raise click.ClickException(MISSING_EVAL_DEPENDENCIES_MESSAGE)

   print("*********************************************************************")
   eval_run_summary = {}
@@ -1021,7 +1072,8 @@ def cli_deploy_agent_engine(
   Example:

     adk deploy agent_engine --project=[project] --region=[region]
-    --staging_bucket=[staging_bucket] --display_name=[app_name] path/to/my_agent
+    --staging_bucket=[staging_bucket] --display_name=[app_name]
+    path/to/my_agent
   """
   try:
     cli_deploy.to_agent_engine(
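As a rough usage illustration of the two invocation modes the `adk eval` help text above describes (the agent path, eval set file name, and eval set id are hypothetical; the optional eval dependencies are assumed to be installed):

# Hypothetical invocations of the updated `adk eval` command via click's test runner.
from click.testing import CliRunner

from google.adk.cli import cli_tools_click

runner = CliRunner()

# Mode 1: eval set file path, optionally suffixed with specific eval case ids.
runner.invoke(
    cli_tools_click.cli_eval,
    ["path/to/my_agent", "my_evals.json:eval_1,eval_2"],
)

# Mode 2: eval set id managed by the local (or GCS) eval sets manager.
runner.invoke(
    cli_tools_click.cli_eval,
    ["path/to/my_agent", "--eval_set_id", "my_eval_set_id:eval_1,eval_2"],
)

The same commands can be run from a shell as `adk eval ...`; passing both an eval set file path and --eval_set_id at once is rejected by the check added earlier in this patch.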
diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py
index 8d6ca2e46..78a2c325f 100644
--- a/src/google/adk/cli/fast_api.py
+++ b/src/google/adk/cli/fast_api.py
@@ -61,6 +61,9 @@
 from ..artifacts.in_memory_artifact_service import InMemoryArtifactService
 from ..auth.credential_service.in_memory_credential_service import InMemoryCredentialService
 from ..errors.not_found_error import NotFoundError
+from ..evaluation.base_eval_service import InferenceConfig
+from ..evaluation.base_eval_service import InferenceRequest
+from ..evaluation.constants import MISSING_EVAL_DEPENDENCIES_MESSAGE
 from ..evaluation.eval_case import EvalCase
 from ..evaluation.eval_case import SessionInput
 from ..evaluation.eval_metrics import EvalMetric
@@ -197,6 +200,7 @@ class RunEvalResult(common.BaseModel):
   final_eval_status: EvalStatus
   eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]] = Field(
       deprecated=True,
+      default=[],
       description=(
           "This field is deprecated, use overall_eval_metric_results instead."
       ),
@@ -643,7 +647,9 @@ async def run_eval(
     app_name: str, eval_set_id: str, req: RunEvalRequest
 ) -> list[RunEvalResult]:
   """Runs an eval given the details in the eval request."""
-  from .cli_eval import run_evals
+  from ..evaluation.local_eval_service import LocalEvalService
+  from .cli_eval import _collect_eval_results
+  from .cli_eval import _collect_inferences

   # Create a mapping from eval set file to all the evals that needed to be
   # run.
@@ -654,52 +660,52 @@ async def run_eval(
         status_code=400, detail=f"Eval set `{eval_set_id}` not found."
     )

-  if req.eval_ids:
-    eval_cases = [e for e in eval_set.eval_cases if e.eval_id in req.eval_ids]
-    eval_set_to_evals = {eval_set_id: eval_cases}
-  else:
-    logger.info("Eval ids to run list is empty. We will run all eval cases.")
-    eval_set_to_evals = {eval_set_id: eval_set.eval_cases}
-
   root_agent = agent_loader.load_agent(app_name)

-  run_eval_results = []
+  eval_case_results = []
   try:
-    async for eval_case_result in run_evals(
-        eval_set_to_evals,
-        root_agent,
-        getattr(root_agent, "reset_data", None),
-        req.eval_metrics,
+    eval_service = LocalEvalService(
+        root_agent=root_agent,
+        eval_sets_manager=eval_sets_manager,
+        eval_set_results_manager=eval_set_results_manager,
         session_service=session_service,
         artifact_service=artifact_service,
-    ):
-      run_eval_results.append(
-          RunEvalResult(
-              app_name=app_name,
-              eval_set_file=eval_case_result.eval_set_file,
-              eval_set_id=eval_set_id,
-              eval_id=eval_case_result.eval_id,
-              final_eval_status=eval_case_result.final_eval_status,
-              eval_metric_results=eval_case_result.eval_metric_results,
-              overall_eval_metric_results=eval_case_result.overall_eval_metric_results,
-              eval_metric_result_per_invocation=eval_case_result.eval_metric_result_per_invocation,
-              user_id=eval_case_result.user_id,
-              session_id=eval_case_result.session_id,
-          )
-      )
-      eval_case_result.session_details = await session_service.get_session(
-          app_name=app_name,
-          user_id=eval_case_result.user_id,
-          session_id=eval_case_result.session_id,
-      )
-      eval_case_results.append(eval_case_result)
+    )
+    inference_request = InferenceRequest(
+        app_name=app_name,
+        eval_set_id=eval_set.eval_set_id,
+        eval_case_ids=req.eval_ids,
+        inference_config=InferenceConfig(),
+    )
+    inference_results = await _collect_inferences(
+        inference_requests=[inference_request], eval_service=eval_service
+    )
+
+    eval_case_results = await _collect_eval_results(
+        inference_results=inference_results,
+        eval_service=eval_service,
+        eval_metrics=req.eval_metrics,
+    )
   except ModuleNotFoundError as e:
     logger.exception("%s", e)
-    raise HTTPException(status_code=400, detail=str(e)) from e
+    raise HTTPException(
+        status_code=400, detail=MISSING_EVAL_DEPENDENCIES_MESSAGE
+    ) from e

-  eval_set_results_manager.save_eval_set_result(
-      app_name, eval_set_id, eval_case_results
-  )
+  run_eval_results = []
+  for eval_case_result in eval_case_results:
+    run_eval_results.append(
+        RunEvalResult(
+            eval_set_file=eval_case_result.eval_set_file,
+            eval_set_id=eval_set_id,
+            eval_id=eval_case_result.eval_id,
+            final_eval_status=eval_case_result.final_eval_status,
+            overall_eval_metric_results=eval_case_result.overall_eval_metric_results,
+            eval_metric_result_per_invocation=eval_case_result.eval_metric_result_per_invocation,
+            user_id=eval_case_result.user_id,
+            session_id=eval_case_result.session_id,
+        )
+    )

   return run_eval_results
diff --git a/src/google/adk/evaluation/eval_metrics.py b/src/google/adk/evaluation/eval_metrics.py
index 8cf235427..bc7691773 100644
--- a/src/google/adk/evaluation/eval_metrics.py
+++ b/src/google/adk/evaluation/eval_metrics.py
@@ -36,6 +36,8 @@ class PrebuiltMetrics(Enum):

   RESPONSE_MATCH_SCORE = "response_match_score"

+  SAFETY_V1 = "safety_v1"
+

 MetricName: TypeAlias = Union[str, PrebuiltMetrics]
diff --git a/src/google/adk/evaluation/local_eval_service.py b/src/google/adk/evaluation/local_eval_service.py
index 240b56c38..745249a91 100644
--- a/src/google/adk/evaluation/local_eval_service.py
+++ b/src/google/adk/evaluation/local_eval_service.py
@@ -60,7 +60,6 @@ def _get_session_id() -> str:
   return f'{EVAL_SESSION_ID_PREFIX}{str(uuid.uuid4())}'


-@working_in_progress("Incomplete feature, don't use yet")
 class LocalEvalService(BaseEvalService):
   """An implementation of BaseEvalService, that runs the evals locally."""
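With the working-in-progress guard removed above, LocalEvalService is the supported entry point that the deprecated run_evals helper points to. Below is a minimal, non-authoritative sketch of driving it directly, mirroring how cli_tools_click.py and fast_api.py wire it up in this patch; the agent object, directory, app name, eval set id, and metric threshold are placeholders.

# Sketch only: placeholder names throughout; mirrors the request/response types
# used elsewhere in this patch.
from google.adk.evaluation.base_eval_service import EvaluateConfig
from google.adk.evaluation.base_eval_service import EvaluateRequest
from google.adk.evaluation.base_eval_service import InferenceConfig
from google.adk.evaluation.base_eval_service import InferenceRequest
from google.adk.evaluation.eval_metrics import EvalMetric
from google.adk.evaluation.local_eval_service import LocalEvalService
from google.adk.evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager
from google.adk.evaluation.local_eval_sets_manager import LocalEvalSetsManager


async def run_my_evals(root_agent, agents_dir, app_name, eval_set_id):
  eval_service = LocalEvalService(
      root_agent=root_agent,
      eval_sets_manager=LocalEvalSetsManager(agents_dir=agents_dir),
      eval_set_results_manager=LocalEvalSetResultsManager(agents_dir=agents_dir),
  )

  # Inference pass: an empty eval_case_ids list selects all cases in the set,
  # matching the CLI behavior above.
  inference_request = InferenceRequest(
      app_name=app_name,
      eval_set_id=eval_set_id,
      eval_case_ids=[],
      inference_config=InferenceConfig(),
  )
  inference_results = [
      result
      async for result in eval_service.perform_inference(
          inference_request=inference_request
      )
  ]

  # Evaluation pass: score the collected inferences against the chosen metrics.
  evaluate_request = EvaluateRequest(
      inference_results=inference_results,
      evaluate_config=EvaluateConfig(
          eval_metrics=[
              # Assumed EvalMetric fields, consistent with their use in this patch.
              EvalMetric(metric_name="response_match_score", threshold=0.8)
          ]
      ),
  )
  return [
      eval_result
      async for eval_result in eval_service.evaluate(
          evaluate_request=evaluate_request
      )
  ]

A caller would drive this with asyncio.run(run_my_evals(...)); result persistence is expected to go through the eval_set_results_manager handed to the service, which is consistent with the explicit save_eval_set_result calls being removed elsewhere in this patch.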
locally.""" diff --git a/src/google/adk/evaluation/metric_evaluator_registry.py b/src/google/adk/evaluation/metric_evaluator_registry.py index 99a700896..7351fa71d 100644 --- a/src/google/adk/evaluation/metric_evaluator_registry.py +++ b/src/google/adk/evaluation/metric_evaluator_registry.py @@ -22,6 +22,7 @@ from .eval_metrics import PrebuiltMetrics from .evaluator import Evaluator from .response_evaluator import ResponseEvaluator +from .safety_evaluator import SafetyEvaluatorV1 from .trajectory_evaluator import TrajectoryEvaluator logger = logging.getLogger("google_adk." + __name__) @@ -71,16 +72,21 @@ def _get_default_metric_evaluator_registry() -> MetricEvaluatorRegistry: metric_evaluator_registry = MetricEvaluatorRegistry() metric_evaluator_registry.register_evaluator( - metric_name=PrebuiltMetrics.TOOL_TRAJECTORY_AVG_SCORE, - evaluator=type(TrajectoryEvaluator), + metric_name=PrebuiltMetrics.TOOL_TRAJECTORY_AVG_SCORE.value, + evaluator=TrajectoryEvaluator, ) metric_evaluator_registry.register_evaluator( - metric_name=PrebuiltMetrics.RESPONSE_EVALUATION_SCORE, - evaluator=type(ResponseEvaluator), + metric_name=PrebuiltMetrics.RESPONSE_EVALUATION_SCORE.value, + evaluator=ResponseEvaluator, ) metric_evaluator_registry.register_evaluator( - metric_name=PrebuiltMetrics.RESPONSE_MATCH_SCORE, - evaluator=type(ResponseEvaluator), + metric_name=PrebuiltMetrics.RESPONSE_MATCH_SCORE.value, + evaluator=ResponseEvaluator, + ) + + metric_evaluator_registry.register_evaluator( + metric_name=PrebuiltMetrics.SAFETY_V1.value, + evaluator=SafetyEvaluatorV1, ) return metric_evaluator_registry diff --git a/tests/unittests/cli/test_fast_api.py b/tests/unittests/cli/test_fast_api.py index d4f9382e3..8475b7e06 100755 --- a/tests/unittests/cli/test_fast_api.py +++ b/tests/unittests/cli/test_fast_api.py @@ -32,6 +32,7 @@ from google.adk.evaluation.eval_case import Invocation from google.adk.evaluation.eval_result import EvalSetResult from google.adk.evaluation.eval_set import EvalSet +from google.adk.evaluation.in_memory_eval_sets_manager import InMemoryEvalSetsManager from google.adk.events import Event from google.adk.runners import Runner from google.adk.sessions.base_session_service import ListSessionsResponse @@ -330,49 +331,7 @@ def mock_memory_service(): @pytest.fixture def mock_eval_sets_manager(): """Create a mock eval sets manager.""" - - # Storage for eval sets. 
diff --git a/tests/unittests/cli/test_fast_api.py b/tests/unittests/cli/test_fast_api.py
index d4f9382e3..8475b7e06 100755
--- a/tests/unittests/cli/test_fast_api.py
+++ b/tests/unittests/cli/test_fast_api.py
@@ -32,6 +32,7 @@
 from google.adk.evaluation.eval_case import Invocation
 from google.adk.evaluation.eval_result import EvalSetResult
 from google.adk.evaluation.eval_set import EvalSet
+from google.adk.evaluation.in_memory_eval_sets_manager import InMemoryEvalSetsManager
 from google.adk.events import Event
 from google.adk.runners import Runner
 from google.adk.sessions.base_session_service import ListSessionsResponse
@@ -330,49 +331,7 @@ def mock_memory_service():
 @pytest.fixture
 def mock_eval_sets_manager():
   """Create a mock eval sets manager."""
-
-  # Storage for eval sets.
-  eval_sets = {}
-
-  class MockEvalSetsManager:
-    """Mock eval sets manager."""
-
-    def create_eval_set(self, app_name, eval_set_id):
-      """Create an eval set."""
-      if app_name not in eval_sets:
-        eval_sets[app_name] = {}
-
-      if eval_set_id in eval_sets[app_name]:
-        raise ValueError(f"Eval set {eval_set_id} already exists.")
-
-      eval_sets[app_name][eval_set_id] = EvalSet(
-          eval_set_id=eval_set_id, eval_cases=[]
-      )
-      return eval_set_id
-
-    def get_eval_set(self, app_name, eval_set_id):
-      """Get an eval set."""
-      if app_name not in eval_sets:
-        raise ValueError(f"App {app_name} not found.")
-      if eval_set_id not in eval_sets[app_name]:
-        raise ValueError(f"Eval set {eval_set_id} not found in app {app_name}.")
-      return eval_sets[app_name][eval_set_id]
-
-    def list_eval_sets(self, app_name):
-      """List eval sets."""
-      if app_name not in eval_sets:
-        raise ValueError(f"App {app_name} not found.")
-      return list(eval_sets[app_name].keys())
-
-    def add_eval_case(self, app_name, eval_set_id, eval_case):
-      """Add an eval case to an eval set."""
-      if app_name not in eval_sets:
-        raise ValueError(f"App {app_name} not found.")
-      if eval_set_id not in eval_sets[app_name]:
-        raise ValueError(f"Eval set {eval_set_id} not found in app {app_name}.")
-      eval_sets[app_name][eval_set_id].eval_cases.append(eval_case)
-
-  return MockEvalSetsManager()
+  return InMemoryEvalSetsManager()


 @pytest.fixture
diff --git a/tests/unittests/cli/utils/test_cli_tools_click.py b/tests/unittests/cli/utils/test_cli_tools_click.py
index da45442a4..c0d26c2ce 100644
--- a/tests/unittests/cli/utils/test_cli_tools_click.py
+++ b/tests/unittests/cli/utils/test_cli_tools_click.py
@@ -23,18 +23,46 @@
 from typing import Any
 from typing import Dict
 from typing import List
-from typing import Optional
 from typing import Tuple
+from unittest import mock

 import click
 from click.testing import CliRunner
+from google.adk.agents.base_agent import BaseAgent
 from google.adk.cli import cli_tools_click
-from google.adk.evaluation import local_eval_set_results_manager
-from google.adk.sessions import Session
+from google.adk.evaluation.eval_case import EvalCase
+from google.adk.evaluation.eval_set import EvalSet
+from google.adk.evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager
+from google.adk.evaluation.local_eval_sets_manager import LocalEvalSetsManager
 from pydantic import BaseModel
 import pytest


+class DummyAgent(BaseAgent):
+
+  def __init__(self, name):
+    super().__init__(name=name)
+    self.sub_agents = []
+
+
+root_agent = DummyAgent(name="dummy_agent")
+
+
+@pytest.fixture
+def mock_load_eval_set_from_file():
+  with mock.patch(
+      "google.adk.evaluation.local_eval_sets_manager.load_eval_set_from_file"
+  ) as mock_func:
+    yield mock_func
+
+
+@pytest.fixture
+def mock_get_root_agent():
+  with mock.patch("google.adk.cli.cli_eval.get_root_agent") as mock_func:
+    mock_func.return_value = root_agent
+    yield mock_func
+
+
 # Helpers
 class _Recorder(BaseModel):
   """Callable that records every invocation."""
@@ -237,137 +265,73 @@ def test_cli_api_server_invokes_uvicorn(
   assert _patch_uvicorn.calls, "uvicorn.Server.run must be called"


-def test_cli_eval_success_path(
-    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
-) -> None:
-  """Test the success path of `adk eval` by fully executing it with a stub module, up to summary generation."""
-  import asyncio
-  import sys
-  import types
-
-  # stub cli_eval module
-  stub = types.ModuleType("google.adk.cli.cli_eval")
-  eval_sets_manager_stub = types.ModuleType(
"google.adk.evaluation.local_eval_sets_manager" - ) - - class _EvalMetric: - - def __init__(self, metric_name: str, threshold: float) -> None: - ... - - class _EvalCaseResult(BaseModel): - eval_set_id: str - eval_id: str - final_eval_status: Any - user_id: str - session_id: str - session_details: Optional[Session] = None - eval_metric_results: list = {} - overall_eval_metric_results: list = {} - eval_metric_result_per_invocation: list = {} - - class EvalCase(BaseModel): - eval_id: str +def test_cli_eval_with_eval_set_file_path( + mock_load_eval_set_from_file, + mock_get_root_agent, + tmp_path, +): + agent_path = tmp_path / "my_agent" + agent_path.mkdir() + (agent_path / "__init__.py").touch() - class EvalSet(BaseModel): - eval_set_id: str - eval_cases: list[EvalCase] + eval_set_file = tmp_path / "my_evals.json" + eval_set_file.write_text("{}") - def mock_save_eval_set_result(cls, *args, **kwargs): - return None - - monkeypatch.setattr( - local_eval_set_results_manager.LocalEvalSetResultsManager, - "save_eval_set_result", - mock_save_eval_set_result, + mock_load_eval_set_from_file.return_value = EvalSet( + eval_set_id="my_evals", + eval_cases=[EvalCase(eval_id="case1", conversation=[])], ) - # minimal enum-like namespace - _EvalStatus = types.SimpleNamespace(PASSED="PASSED", FAILED="FAILED") - - # helper funcs - stub.EvalMetric = _EvalMetric - stub.EvalCaseResult = _EvalCaseResult - stub.EvalStatus = _EvalStatus - stub.MISSING_EVAL_DEPENDENCIES_MESSAGE = "stub msg" - - stub.get_evaluation_criteria_or_default = lambda _p: {"foo": 1.0} - stub.get_root_agent = lambda _p: object() - stub.try_get_reset_func = lambda _p: None - stub.parse_and_get_evals_to_run = lambda _paths: {"set1.json": ["e1", "e2"]} - eval_sets_manager_stub.load_eval_set_from_file = lambda x, y: EvalSet( - eval_set_id="test_eval_set_id", - eval_cases=[EvalCase(eval_id="e1"), EvalCase(eval_id="e2")], + result = CliRunner().invoke( + cli_tools_click.cli_eval, + [str(agent_path), str(eval_set_file)], ) - # Create an async generator function for run_evals - async def mock_run_evals(*_a, **_k): - yield _EvalCaseResult( - eval_set_id="set1.json", - eval_id="e1", - final_eval_status=_EvalStatus.PASSED, - user_id="user", - session_id="session1", - overall_eval_metric_results=[{ - "metricName": "some_metric", - "threshold": 0.0, - "score": 1.0, - "evalStatus": _EvalStatus.PASSED, - }], - ) - yield _EvalCaseResult( - eval_set_id="set1.json", - eval_id="e2", - final_eval_status=_EvalStatus.FAILED, - user_id="user", - session_id="session2", - overall_eval_metric_results=[{ - "metricName": "some_metric", - "threshold": 0.0, - "score": 0.0, - "evalStatus": _EvalStatus.FAILED, - }], - ) - - stub.run_evals = mock_run_evals - - # Replace asyncio.run with a function that properly handles coroutines - def mock_asyncio_run(coro): - # Create a new event loop - loop = asyncio.new_event_loop() - try: - return loop.run_until_complete(coro) - finally: - loop.close() - - monkeypatch.setattr(cli_tools_click.asyncio, "run", mock_asyncio_run) - - # inject stub - monkeypatch.setitem(sys.modules, "google.adk.cli.cli_eval", stub) - monkeypatch.setitem( - sys.modules, - "google.adk.evaluation.local_eval_sets_manager", - eval_sets_manager_stub, + assert result.exit_code == 0 + # Assert that we wrote eval set results + eval_set_results_manager = LocalEvalSetResultsManager( + agents_dir=str(tmp_path) ) - - # create dummy agent directory - agent_dir = tmp_path / "agent5" - agent_dir.mkdir() - (agent_dir / "__init__.py").touch() - - # inject monkeypatch - 
-  monkeypatch.setattr(
-      cli_tools_click.envs, "load_dotenv_for_agent", lambda *a, **k: None
+  eval_set_results = eval_set_results_manager.list_eval_set_results(
+      app_name="my_agent"
+  )
+  assert len(eval_set_results) == 1
+
+
+def test_cli_eval_with_eval_set_id(
+    mock_get_root_agent,
+    tmp_path,
+):
+  app_name = "test_app"
+  eval_set_id = "test_eval_set_id"
+  agent_path = tmp_path / app_name
+  agent_path.mkdir()
+  (agent_path / "__init__.py").touch()
+
+  eval_sets_manager = LocalEvalSetsManager(agents_dir=str(tmp_path))
+  eval_sets_manager.create_eval_set(app_name=app_name, eval_set_id=eval_set_id)
+  eval_sets_manager.add_eval_case(
+      app_name=app_name,
+      eval_set_id=eval_set_id,
+      eval_case=EvalCase(eval_id="case1", conversation=[]),
+  )
+  eval_sets_manager.add_eval_case(
+      app_name=app_name,
+      eval_set_id=eval_set_id,
+      eval_case=EvalCase(eval_id="case2", conversation=[]),
   )

-  runner = CliRunner()
-  result = runner.invoke(
-      cli_tools_click.main,
-      ["eval", str(agent_dir), str(tmp_path / "dummy_eval.json")],
+  result = CliRunner().invoke(
+      cli_tools_click.cli_eval,
+      [str(agent_path), "--eval_set_id", "test_eval_set_id:case1,case2"],
   )

   assert result.exit_code == 0
-  assert "Eval Run Summary" in result.output
-  assert "Tests passed: 1" in result.output
-  assert "Tests failed: 1" in result.output
+  # Assert that we wrote eval set results
+  eval_set_results_manager = LocalEvalSetResultsManager(
+      agents_dir=str(tmp_path)
+  )
+  eval_set_results = eval_set_results_manager.list_eval_set_results(
+      app_name=app_name
+  )
+  assert len(eval_set_results) == 2