Skip to content

feat: Use LocalEvalService to run all evals in cli and web #1979

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 69 additions & 8 deletions src/google/adk/cli/cli_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,16 @@
from typing import Optional
import uuid

from typing_extensions import deprecated

from ..agents import Agent
from ..artifacts.base_artifact_service import BaseArtifactService
from ..evaluation.base_eval_service import BaseEvalService
from ..evaluation.base_eval_service import EvaluateConfig
from ..evaluation.base_eval_service import EvaluateRequest
from ..evaluation.base_eval_service import InferenceConfig
from ..evaluation.base_eval_service import InferenceRequest
from ..evaluation.base_eval_service import InferenceResult
from ..evaluation.constants import MISSING_EVAL_DEPENDENCIES_MESSAGE
from ..evaluation.eval_case import EvalCase
from ..evaluation.eval_metrics import EvalMetric
Expand Down Expand Up @@ -107,26 +115,79 @@ def try_get_reset_func(agent_module_file_path: str) -> Any:


def parse_and_get_evals_to_run(
    evals_to_run_info: list[str],
) -> dict[str, list[str]]:
  """Returns a dictionary of eval set info to evals that should be run.

  Args:
    evals_to_run_info: While the structure is quite simple, a list of strings,
      each string is formatted with the following convention:
      `<eval_set_file_path | eval_set_id>:[comma separated eval case ids]`

  Returns:
    A mapping from eval set info (file path or id) to the list of eval case
    ids that should be run. An empty list of eval case ids means that all
    eval cases in that eval set should be run.
  """
  eval_set_to_evals: dict[str, list[str]] = {}
  for input_eval_set in evals_to_run_info:
    evals = []
    if ":" not in input_eval_set:
      # No eval cases specified. This is the case where the user wants to run
      # all eval cases in the eval set.
      eval_set = input_eval_set
    else:
      # There are eval cases that we need to parse. The user wants to run
      # specific eval cases from the eval set.
      parts = input_eval_set.split(":")
      eval_set = parts[0]
      evals = parts[1].split(",")

    # The same eval set may be mentioned more than once; accumulate the
    # requested eval case ids instead of overwriting them.
    eval_set_to_evals.setdefault(eval_set, []).extend(evals)

  return eval_set_to_evals


async def _collect_inferences(
inference_requests: list[InferenceRequest],
eval_service: BaseEvalService,
) -> list[InferenceResult]:
"""Simple utility methods to collect inferences from an eval service.

The method is intentionally kept private to prevent general usage.
"""
inference_results = []
for inference_request in inference_requests:
async for inference_result in eval_service.perform_inference(
inference_request=inference_request
):
inference_results.append(inference_result)
return inference_results


async def _collect_eval_results(
    inference_results: list[InferenceResult],
    eval_service: BaseEvalService,
    eval_metrics: list[EvalMetric],
) -> list[EvalCaseResult]:
  """Simple utility method to collect eval results from an eval service.

  The method is intentionally kept private to prevent general usage.

  Args:
    inference_results: The inference results to be evaluated.
    eval_service: The eval service used to evaluate the results.
    eval_metrics: The metrics (with thresholds) to evaluate against.

  Returns:
    The eval case results produced by the eval service, in the order the
    service yields them.
  """
  evaluate_request = EvaluateRequest(
      inference_results=inference_results,
      evaluate_config=EvaluateConfig(eval_metrics=eval_metrics),
  )
  # evaluate() is an async generator; drain it into a list for the caller.
  return [
      eval_result
      async for eval_result in eval_service.evaluate(
          evaluate_request=evaluate_request
      )
  ]


@deprecated(
"This method is deprecated and will be removed in fututre release. Please"
" use LocalEvalService to define your custom evals."
)
async def run_evals(
eval_cases_by_eval_set_id: dict[str, list[EvalCase]],
root_agent: Agent,
Expand Down
186 changes: 119 additions & 67 deletions src/google/adk/cli/cli_tools_click.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,11 @@ def cli_run(
),
)
@click.argument("eval_set_file_path", nargs=-1)
@click.option(
"--eval_set_id",
multiple=True,
help="Optional. The eval set id to use for evals.",
)
@click.option("--config_file_path", help="Optional. The path to config file.")
@click.option(
"--print_detailed_results",
Expand All @@ -297,6 +302,7 @@ def cli_run(
def cli_eval(
agent_module_file_path: str,
eval_set_file_path: list[str],
eval_set_id: list[str],
config_file_path: str,
print_detailed_results: bool,
eval_storage_uri: Optional[str] = None,
Expand All @@ -314,31 +320,69 @@ def cli_eval(
separated list of eval names and then add that as a suffix to the eval set
file name, demarcated by a `:`.

For example,
For example, we have `sample_eval_set_file.json` file that has following the
eval cases:
sample_eval_set_file.json:
|....... eval_1
|....... eval_2
|....... eval_3
|....... eval_4
|....... eval_5

sample_eval_set_file.json:eval_1,eval_2,eval_3

This will only run eval_1, eval_2 and eval_3 from sample_eval_set_file.json.

EVAL_SET_ID: You can specify one or more eval set ids.

For each eval set, all evals will be run by default.

If you want to run only specific evals from a eval set, first create a comma
separated list of eval names and then add that as a suffix to the eval set
file name, demarcated by a `:`.

For example, we have `sample_eval_set_id` that has following the eval cases:
sample_eval_set_id:
|....... eval_1
|....... eval_2
|....... eval_3
|....... eval_4
|....... eval_5

If we did:
sample_eval_set_id:eval_1,eval_2,eval_3

This will only run eval_1, eval_2 and eval_3 from sample_eval_set_id.

CONFIG_FILE_PATH: The path to config file.

PRINT_DETAILED_RESULTS: Prints detailed results on the console.
"""
envs.load_dotenv_for_agent(agent_module_file_path, ".")

try:
from ..evaluation.base_eval_service import InferenceConfig
from ..evaluation.base_eval_service import InferenceRequest
from ..evaluation.eval_metrics import EvalMetric
from ..evaluation.eval_result import EvalCaseResult
from ..evaluation.evaluator import EvalStatus
from ..evaluation.in_memory_eval_sets_manager import InMemoryEvalSetsManager
from ..evaluation.local_eval_service import LocalEvalService
from ..evaluation.local_eval_sets_manager import load_eval_set_from_file
from .cli_eval import EvalCaseResult
from .cli_eval import EvalMetric
from .cli_eval import EvalStatus
from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager
from .cli_eval import _collect_eval_results
from .cli_eval import _collect_inferences
from .cli_eval import get_evaluation_criteria_or_default
from .cli_eval import get_root_agent
from .cli_eval import parse_and_get_evals_to_run
from .cli_eval import run_evals
from .cli_eval import try_get_reset_func
except ModuleNotFoundError:
raise click.ClickException(MISSING_EVAL_DEPENDENCIES_MESSAGE)

if eval_set_file_path and eval_set_id:
raise click.ClickException(
"Only one of eval_set_file_path or eval_set_id can be specified."
)

evaluation_criteria = get_evaluation_criteria_or_default(config_file_path)
eval_metrics = []
for metric_name, threshold in evaluation_criteria.items():
Expand All @@ -349,80 +393,87 @@ def cli_eval(
print(f"Using evaluation criteria: {evaluation_criteria}")

root_agent = get_root_agent(agent_module_file_path)
reset_func = try_get_reset_func(agent_module_file_path)

gcs_eval_sets_manager = None
app_name = os.path.basename(agent_module_file_path)
agents_dir = os.path.dirname(agent_module_file_path)
eval_sets_manager = None
eval_set_results_manager = None

if eval_storage_uri:
gcs_eval_managers = evals.create_gcs_eval_managers_from_uri(
eval_storage_uri
)
gcs_eval_sets_manager = gcs_eval_managers.eval_sets_manager
eval_sets_manager = gcs_eval_managers.eval_sets_manager
eval_set_results_manager = gcs_eval_managers.eval_set_results_manager
else:
eval_set_results_manager = LocalEvalSetResultsManager(
agents_dir=os.path.dirname(agent_module_file_path)
)
eval_set_file_path_to_evals = parse_and_get_evals_to_run(eval_set_file_path)
eval_set_id_to_eval_cases = {}

# Read the eval_set files and get the cases.
for eval_set_file_path, eval_case_ids in eval_set_file_path_to_evals.items():
if gcs_eval_sets_manager:
eval_set = gcs_eval_sets_manager._load_eval_set_from_blob(
eval_set_file_path
eval_set_results_manager = LocalEvalSetResultsManager(agents_dir=agents_dir)

inference_requests = []
if eval_set_file_path:
eval_sets_manager = InMemoryEvalSetsManager()
eval_set_file_path_to_evals = parse_and_get_evals_to_run(eval_set_file_path)

# Read the eval_set files and get the cases.
for (
eval_set_file_path,
eval_case_ids,
) in eval_set_file_path_to_evals.items():
eval_set = load_eval_set_from_file(eval_set_file_path, eval_set_file_path)

eval_sets_manager.create_eval_set(
app_name=app_name, eval_set_id=eval_set.eval_set_id
)
if not eval_set:
raise click.ClickException(
f"Eval set {eval_set_file_path} not found in GCS."
for eval_case in eval_set.eval_cases:
eval_sets_manager.add_eval_case(
app_name=app_name,
eval_set_id=eval_set.eval_set_id,
eval_case=eval_case,
)
else:
eval_set = load_eval_set_from_file(eval_set_file_path, eval_set_file_path)
eval_cases = eval_set.eval_cases

if eval_case_ids:
# There are eval_ids that we should select.
eval_cases = [
e for e in eval_set.eval_cases if e.eval_id in eval_case_ids
]

eval_set_id_to_eval_cases[eval_set.eval_set_id] = eval_cases

async def _collect_eval_results() -> list[EvalCaseResult]:
session_service = InMemorySessionService()
eval_case_results = []
async for eval_case_result in run_evals(
eval_set_id_to_eval_cases,
root_agent,
reset_func,
eval_metrics,
session_service=session_service,
):
eval_case_result.session_details = await session_service.get_session(
app_name=os.path.basename(agent_module_file_path),
user_id=eval_case_result.user_id,
session_id=eval_case_result.session_id,
inference_requests.append(
InferenceRequest(
app_name=app_name,
eval_set_id=eval_set.eval_set_id,
eval_case_ids=eval_case_ids,
inference_config=InferenceConfig(),
)
)
elif eval_set_id:
eval_sets_manager = (
eval_sets_manager
if eval_storage_uri
else LocalEvalSetsManager(agents_dir=agents_dir)
)
eval_set_id_to_eval_cases = parse_and_get_evals_to_run(eval_set_id)
for eval_set_id_key, eval_case_ids in eval_set_id_to_eval_cases.items():
inference_requests.append(
InferenceRequest(
app_name=app_name,
eval_set_id=eval_set_id_key,
eval_case_ids=eval_case_ids,
inference_config=InferenceConfig(),
)
)
eval_case_results.append(eval_case_result)
return eval_case_results

try:
eval_results = asyncio.run(_collect_eval_results())
except ModuleNotFoundError:
raise click.ClickException(MISSING_EVAL_DEPENDENCIES_MESSAGE)
eval_service = LocalEvalService(
root_agent=root_agent,
eval_sets_manager=eval_sets_manager,
eval_set_results_manager=eval_set_results_manager,
)

# Write eval set results.
eval_set_id_to_eval_results = collections.defaultdict(list)
for eval_case_result in eval_results:
eval_set_id = eval_case_result.eval_set_id
eval_set_id_to_eval_results[eval_set_id].append(eval_case_result)

for eval_set_id, eval_case_results in eval_set_id_to_eval_results.items():
eval_set_results_manager.save_eval_set_result(
app_name=os.path.basename(agent_module_file_path),
eval_set_id=eval_set_id,
eval_case_results=eval_case_results,
inference_results = asyncio.run(
_collect_inferences(
inference_requests=inference_requests, eval_service=eval_service
)
)
eval_results = asyncio.run(
_collect_eval_results(
inference_results=inference_results,
eval_service=eval_service,
eval_metrics=eval_metrics,
)
)
except ModuleNotFoundError:
raise click.ClickException(MISSING_EVAL_DEPENDENCIES_MESSAGE)

print("*********************************************************************")
eval_run_summary = {}
Expand Down Expand Up @@ -1021,7 +1072,8 @@ def cli_deploy_agent_engine(
Example:

adk deploy agent_engine --project=[project] --region=[region]
--staging_bucket=[staging_bucket] --display_name=[app_name] path/to/my_agent
--staging_bucket=[staging_bucket] --display_name=[app_name]
path/to/my_agent
"""
try:
cli_deploy.to_agent_engine(
Expand Down
Loading