Skip to content

Commit fa025d7

Browse files
jcpagadora737copybara-github
authored andcommitted
feat: Add a new option eval_storage_uri in adk web & adk eval to specify GCS bucket to store eval data
PiperOrigin-RevId: 774947795
1 parent 120cbab commit fa025d7

File tree

5 files changed

+139
-15
lines changed

5 files changed

+139
-15
lines changed

src/google/adk/cli/cli_tools_click.py

Lines changed: 60 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,15 @@
3131
from . import cli_create
3232
from . import cli_deploy
3333
from .. import version
34+
from ..evaluation.gcs_eval_set_results_manager import GcsEvalSetResultsManager
35+
from ..evaluation.gcs_eval_sets_manager import GcsEvalSetsManager
3436
from ..evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager
3537
from ..sessions.in_memory_session_service import InMemorySessionService
3638
from .cli import run_cli
3739
from .cli_eval import MISSING_EVAL_DEPENDENCIES_MESSAGE
3840
from .fast_api import get_fast_api_app
3941
from .utils import envs
42+
from .utils import evals
4043
from .utils import logs
4144

4245
LOG_LEVELS = click.Choice(
@@ -282,11 +285,21 @@ def cli_run(
282285
default=False,
283286
help="Optional. Whether to print detailed results on console or not.",
284287
)
288+
@click.option(
289+
"--eval_storage_uri",
290+
type=str,
291+
help=(
292+
"Optional. The evals storage URI to store agent evals,"
293+
" supported URIs: gs://<bucket name>."
294+
),
295+
default=None,
296+
)
285297
def cli_eval(
286298
agent_module_file_path: str,
287-
eval_set_file_path: tuple[str],
299+
eval_set_file_path: list[str],
288300
config_file_path: str,
289301
print_detailed_results: bool,
302+
eval_storage_uri: Optional[str] = None,
290303
):
291304
"""Evaluates an agent given the eval sets.
292305
@@ -338,12 +351,33 @@ def cli_eval(
338351
root_agent = get_root_agent(agent_module_file_path)
339352
reset_func = try_get_reset_func(agent_module_file_path)
340353

354+
gcs_eval_sets_manager = None
355+
eval_set_results_manager = None
356+
if eval_storage_uri:
357+
gcs_eval_managers = evals.create_gcs_eval_managers_from_uri(
358+
eval_storage_uri
359+
)
360+
gcs_eval_sets_manager = gcs_eval_managers.eval_sets_manager
361+
eval_set_results_manager = gcs_eval_managers.eval_set_results_manager
362+
else:
363+
eval_set_results_manager = LocalEvalSetResultsManager(
364+
agents_dir=os.path.dirname(agent_module_file_path)
365+
)
341366
eval_set_file_path_to_evals = parse_and_get_evals_to_run(eval_set_file_path)
342367
eval_set_id_to_eval_cases = {}
343368

344369
# Read the eval_set files and get the cases.
345370
for eval_set_file_path, eval_case_ids in eval_set_file_path_to_evals.items():
346-
eval_set = load_eval_set_from_file(eval_set_file_path, eval_set_file_path)
371+
if gcs_eval_sets_manager:
372+
eval_set = gcs_eval_sets_manager._load_eval_set_from_blob(
373+
eval_set_file_path
374+
)
375+
if not eval_set:
376+
raise click.ClickException(
377+
f"Eval set {eval_set_file_path} not found in GCS."
378+
)
379+
else:
380+
eval_set = load_eval_set_from_file(eval_set_file_path, eval_set_file_path)
347381
eval_cases = eval_set.eval_cases
348382

349383
if eval_case_ids:
@@ -378,16 +412,13 @@ async def _collect_eval_results() -> list[EvalCaseResult]:
378412
raise click.ClickException(MISSING_EVAL_DEPENDENCIES_MESSAGE)
379413

380414
# Write eval set results.
381-
local_eval_set_results_manager = LocalEvalSetResultsManager(
382-
agents_dir=os.path.dirname(agent_module_file_path)
383-
)
384415
eval_set_id_to_eval_results = collections.defaultdict(list)
385416
for eval_case_result in eval_results:
386417
eval_set_id = eval_case_result.eval_set_id
387418
eval_set_id_to_eval_results[eval_set_id].append(eval_case_result)
388419

389420
for eval_set_id, eval_case_results in eval_set_id_to_eval_results.items():
390-
local_eval_set_results_manager.save_eval_set_result(
421+
eval_set_results_manager.save_eval_set_result(
391422
app_name=os.path.basename(agent_module_file_path),
392423
eval_set_id=eval_set_id,
393424
eval_case_results=eval_case_results,
@@ -444,6 +475,15 @@ def decorator(func):
444475
),
445476
default=None,
446477
)
478+
@click.option(
479+
"--eval_storage_uri",
480+
type=str,
481+
help=(
482+
"Optional. The evals storage URI to store agent evals,"
483+
" supported URIs: gs://<bucket name>."
484+
),
485+
default=None,
486+
)
447487
@click.option(
448488
"--memory_service_uri",
449489
type=str,
@@ -564,6 +604,7 @@ def wrapper(*args, **kwargs):
564604
)
565605
def cli_web(
566606
agents_dir: str,
607+
eval_storage_uri: Optional[str] = None,
567608
log_level: str = "INFO",
568609
allow_origins: Optional[list[str]] = None,
569610
host: str = "127.0.0.1",
@@ -616,6 +657,7 @@ async def _lifespan(app: FastAPI):
616657
session_service_uri=session_service_uri,
617658
artifact_service_uri=artifact_service_uri,
618659
memory_service_uri=memory_service_uri,
660+
eval_storage_uri=eval_storage_uri,
619661
allow_origins=allow_origins,
620662
web=True,
621663
trace_to_cloud=trace_to_cloud,
@@ -654,6 +696,7 @@ async def _lifespan(app: FastAPI):
654696
)
655697
def cli_api_server(
656698
agents_dir: str,
699+
eval_storage_uri: Optional[str] = None,
657700
log_level: str = "INFO",
658701
allow_origins: Optional[list[str]] = None,
659702
host: str = "127.0.0.1",
@@ -685,6 +728,7 @@ def cli_api_server(
685728
session_service_uri=session_service_uri,
686729
artifact_service_uri=artifact_service_uri,
687730
memory_service_uri=memory_service_uri,
731+
eval_storage_uri=eval_storage_uri,
688732
allow_origins=allow_origins,
689733
web=False,
690734
trace_to_cloud=trace_to_cloud,
@@ -771,6 +815,15 @@ def cli_api_server(
771815
" version in the dev environment)"
772816
),
773817
)
818+
@click.option(
819+
"--eval_storage_uri",
820+
type=str,
821+
help=(
822+
"Optional. The evals storage URI to store agent evals,"
823+
" supported URIs: gs://<bucket name>."
824+
),
825+
default=None,
826+
)
774827
@adk_services_options()
775828
@deprecated_adk_services_options()
776829
@click.argument(
@@ -797,6 +850,7 @@ def cli_deploy_cloud_run(
797850
session_service_uri: Optional[str] = None,
798851
artifact_service_uri: Optional[str] = None,
799852
memory_service_uri: Optional[str] = None,
853+
eval_storage_uri: Optional[str] = None,
800854
session_db_url: Optional[str] = None, # Deprecated
801855
artifact_storage_uri: Optional[str] = None, # Deprecated
802856
):

src/google/adk/cli/fast_api.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@
6565
from ..evaluation.eval_metrics import EvalMetricResult
6666
from ..evaluation.eval_metrics import EvalMetricResultPerInvocation
6767
from ..evaluation.eval_result import EvalSetResult
68+
from ..evaluation.gcs_eval_set_results_manager import GcsEvalSetResultsManager
69+
from ..evaluation.gcs_eval_sets_manager import GcsEvalSetsManager
6870
from ..evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager
6971
from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager
7072
from ..events.event import Event
@@ -198,6 +200,7 @@ def get_fast_api_app(
198200
session_service_uri: Optional[str] = None,
199201
artifact_service_uri: Optional[str] = None,
200202
memory_service_uri: Optional[str] = None,
203+
eval_storage_uri: Optional[str] = None,
201204
allow_origins: Optional[list[str]] = None,
202205
web: bool,
203206
trace_to_cloud: bool = False,
@@ -256,8 +259,18 @@ async def internal_lifespan(app: FastAPI):
256259

257260
runner_dict = {}
258261

259-
eval_sets_manager = LocalEvalSetsManager(agents_dir=agents_dir)
260-
eval_set_results_manager = LocalEvalSetResultsManager(agents_dir=agents_dir)
262+
# Set up eval managers.
263+
eval_sets_manager = None
264+
eval_set_results_manager = None
265+
if eval_storage_uri:
266+
gcs_eval_managers = evals.create_gcs_eval_managers_from_uri(
267+
eval_storage_uri
268+
)
269+
eval_sets_manager = gcs_eval_managers.eval_sets_manager
270+
eval_set_results_manager = gcs_eval_managers.eval_set_results_manager
271+
else:
272+
eval_sets_manager = LocalEvalSetsManager(agents_dir=agents_dir)
273+
eval_set_results_manager = LocalEvalSetResultsManager(agents_dir=agents_dir)
261274

262275
# Build the Memory service
263276
if memory_service_uri:

src/google/adk/cli/utils/evals.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,36 @@
1414

1515
from __future__ import annotations
1616

17+
import dataclasses
18+
import os
1719
from typing import Any
1820
from typing import Tuple
1921

2022
from google.genai import types as genai_types
23+
from pydantic import alias_generators
24+
from pydantic import BaseModel
25+
from pydantic import ConfigDict
2126
from typing_extensions import deprecated
2227

2328
from ...evaluation.eval_case import IntermediateData
2429
from ...evaluation.eval_case import Invocation
30+
from ...evaluation.gcs_eval_set_results_manager import GcsEvalSetResultsManager
31+
from ...evaluation.gcs_eval_sets_manager import GcsEvalSetsManager
2532
from ...sessions.session import Session
2633

2734

35+
class GcsEvalManagers(BaseModel):
36+
model_config = ConfigDict(
37+
alias_generator=alias_generators.to_camel,
38+
populate_by_name=True,
39+
arbitrary_types_allowed=True,
40+
)
41+
42+
eval_sets_manager: GcsEvalSetsManager
43+
44+
eval_set_results_manager: GcsEvalSetResultsManager
45+
46+
2847
@deprecated('Use convert_session_to_eval_invocations instead.')
2948
def convert_session_to_eval_format(session: Session) -> list[dict[str, Any]]:
3049
"""Converts a session data into eval format.
@@ -176,3 +195,37 @@ def convert_session_to_eval_invocations(session: Session) -> list[Invocation]:
176195
)
177196

178197
return invocations
198+
199+
200+
def create_gcs_eval_managers_from_uri(
201+
eval_storage_uri: str,
202+
) -> GcsEvalManagers:
203+
"""Creates GcsEvalManagers from eval_storage_uri.
204+
205+
Args:
206+
eval_storage_uri: The evals storage URI to use. Supported URIs:
207+
gs://<bucket name>. If a path is provided, the bucket will be extracted.
208+
209+
Returns:
210+
GcsEvalManagers: The GcsEvalManagers object.
211+
212+
Raises:
213+
ValueError: If the eval_storage_uri is not supported.
214+
"""
215+
if eval_storage_uri.startswith('gs://'):
216+
gcs_bucket = eval_storage_uri.split('://')[1]
217+
eval_sets_manager = GcsEvalSetsManager(
218+
bucket_name=gcs_bucket, project=os.environ['GOOGLE_CLOUD_PROJECT']
219+
)
220+
eval_set_results_manager = GcsEvalSetResultsManager(
221+
bucket_name=gcs_bucket, project=os.environ['GOOGLE_CLOUD_PROJECT']
222+
)
223+
return GcsEvalManagers(
224+
eval_sets_manager=eval_sets_manager,
225+
eval_set_results_manager=eval_set_results_manager,
226+
)
227+
else:
228+
raise ValueError(
229+
f'Unsupported evals storage URI: {eval_storage_uri}. Supported URIs:'
230+
' gs://<bucket name>'
231+
)

src/google/adk/evaluation/gcs_eval_sets_manager.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,13 @@ def _validate_id(self, id_name: str, id_value: str):
7272
f"Invalid {id_name}. {id_name} should have the `{pattern}` format",
7373
)
7474

75+
def _load_eval_set_from_blob(self, blob_name: str) -> Optional[EvalSet]:
76+
blob = self.bucket.blob(blob_name)
77+
if not blob.exists():
78+
return None
79+
eval_set_data = blob.download_as_text()
80+
return EvalSet.model_validate_json(eval_set_data)
81+
7582
def _write_eval_set_to_blob(self, blob_name: str, eval_set: EvalSet):
7683
"""Writes an EvalSet to GCS."""
7784
blob = self.bucket.blob(blob_name)
@@ -88,11 +95,7 @@ def _save_eval_set(self, app_name: str, eval_set_id: str, eval_set: EvalSet):
8895
def get_eval_set(self, app_name: str, eval_set_id: str) -> Optional[EvalSet]:
8996
"""Returns an EvalSet identified by an app_name and eval_set_id."""
9097
eval_set_blob_name = self._get_eval_set_blob_name(app_name, eval_set_id)
91-
blob = self.bucket.blob(eval_set_blob_name)
92-
if not blob.exists():
93-
return None
94-
eval_set_data = blob.download_as_text()
95-
return EvalSet.model_validate_json(eval_set_data)
98+
return self._load_eval_set_from_blob(eval_set_blob_name)
9699

97100
@override
98101
def create_eval_set(self, app_name: str, eval_set_id: str):

tests/unittests/cli/test_fast_api.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
level=logging.INFO,
4141
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
4242
)
43-
logger = logging.getLogger(__name__)
43+
logger = logging.getLogger("google_adk." + __name__)
4444

4545

4646
# Here we create a dummy agent module that get_fast_api_app expects
@@ -138,6 +138,7 @@ async def mock_run_evals_for_fast_api(*args, **kwargs):
138138
final_eval_status=1, # Matches expected (assuming 1 is PASSED)
139139
user_id="test_user", # Placeholder, adapt if needed
140140
session_id="test_session_for_eval_case", # Placeholder
141+
eval_set_file="test_eval_set_file", # Placeholder
141142
overall_eval_metric_results=[{ # Matches expected
142143
"metricName": "tool_trajectory_avg_score",
143144
"threshold": 0.5,
@@ -372,7 +373,7 @@ def add_eval_case(self, app_name, eval_set_id, eval_case):
372373

373374
@pytest.fixture
374375
def mock_eval_set_results_manager():
375-
"""Create a mock eval set results manager."""
376+
"""Create a mock local eval set results manager."""
376377

377378
# Storage for eval set results.
378379
eval_set_results = {}

0 commit comments

Comments
 (0)