
Commit 8d53484

ankursharmas authored and zshehov committed

chore: Update ResponseEvaluator to use newer version of Eval SDK

Also,
- removed functionality that was marked deprecated from the ResponseEvaluator class.
- Added unit test cases

PiperOrigin-RevId: 778568884
1 parent 6ed7eea commit 8d53484

File tree

3 files changed: +193 -359 lines changed


pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -85,7 +85,7 @@ a2a = [
 
 eval = [
   # go/keep-sorted start
-  "google-cloud-aiplatform[evaluation]>=1.87.0",
+  "google-cloud-aiplatform[evaluation]>=1.100.0",
   "pandas>=2.2.3",
   "tabulate>=0.9.0",
   "rouge-score>=0.1.2",

src/google/adk/evaluation/response_evaluator.py

Lines changed: 28 additions & 152 deletions

@@ -14,18 +14,15 @@
 
 from __future__ import annotations
 
-from typing import Any
+import os
 from typing import Optional
 
 from google.genai import types as genai_types
 import pandas as pd
-from tabulate import tabulate
-from typing_extensions import deprecated
 from typing_extensions import override
-from vertexai.preview.evaluation import EvalTask
-from vertexai.preview.evaluation import MetricPromptTemplateExamples
+from vertexai import Client as VertexAiClient
+from vertexai import types as vertexai_types
 
-from .eval_case import IntermediateData
 from .eval_case import Invocation
 from .eval_metrics import EvalMetric
 from .evaluator import EvalStatus
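
The import swap above is the core of the SDK migration: the vertexai.preview.evaluation surface (EvalTask, MetricPromptTemplateExamples) is dropped in favor of the client exported from the top-level vertexai package. A minimal sketch of the new surface, based only on the symbols visible in this diff:

# Eval SDK surface the updated ResponseEvaluator relies on (per this diff).
from vertexai import Client as VertexAiClient
from vertexai import types as vertexai_types

# Replaces MetricPromptTemplateExamples.Pointwise.COHERENCE (see the next hunk).
coherence_metric = vertexai_types.PrebuiltMetric.COHERENCE
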
@@ -57,7 +54,7 @@ def __init__(
     metric_name = eval_metric.metric_name
 
     if "response_evaluation_score" == metric_name:
-      self._metric_name = MetricPromptTemplateExamples.Pointwise.COHERENCE
+      self._metric_name = vertexai_types.PrebuiltMetric.COHERENCE
     elif "response_match_score" == metric_name:
       self._metric_name = "response_match_score"
     else:
@@ -87,17 +84,11 @@ def evaluate_invocations(
       prompt = self._get_text(expected.user_content)
       reference = self._get_text(expected.final_response)
       response = self._get_text(actual.final_response)
-      actual_tool_use = self._get_tool_use_trajectory(actual.intermediate_data)
-      reference_trajectory = self._get_tool_use_trajectory(
-          expected.intermediate_data
-      )
 
       eval_case = {
           "prompt": prompt,
           "reference": reference,
           "response": response,
-          "actual_tool_user": actual_tool_use,
-          "reference_trajectory": reference_trajectory,
       }
 
       eval_case_result = ResponseEvaluator._perform_eval(
@@ -112,11 +103,15 @@ def evaluate_invocations(
              eval_status=self._get_eval_status(score),
          )
      )
-      total_score += score
-      num_invocations += 1
+
+      if score:
+        total_score += score
+        num_invocations += 1
 
     if per_invocation_results:
-      overall_score = total_score / num_invocations
+      overall_score = (
+          total_score / num_invocations if num_invocations > 0 else None
+      )
       return EvaluationResult(
           overall_score=overall_score,
           overall_eval_status=self._get_eval_status(overall_score),
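
For context, a standalone sketch (not part of the commit) of how the new aggregation behaves when the backend returns no score for some invocations. Because the guard is a truthiness check, a literal 0.0 score is skipped as well, and an all-None run yields an overall score of None, which the updated _get_eval_status maps to EvalStatus.NOT_EVALUATED.

from typing import Optional


def aggregate(scores: list[Optional[float]]) -> Optional[float]:
  """Mirrors the updated accumulation: skip falsy scores, avoid divide-by-zero."""
  total_score = 0.0
  num_invocations = 0
  for score in scores:
    if score:  # None (and 0.0) are skipped, as in the updated evaluate_invocations
      total_score += score
      num_invocations += 1
  return total_score / num_invocations if num_invocations > 0 else None


print(aggregate([0.8, None, 0.6]))  # 0.7
print(aggregate([None, None]))      # None
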
@@ -131,150 +126,31 @@ def _get_text(self, content: Optional[genai_types.Content]) -> str:
 
     return ""
 
-  def _get_tool_use_trajectory(
-      self, intermediate_data: Optional[IntermediateData]
-  ) -> list[dict[str, Any]]:
-    tool_use_trajectory = []
-    if not intermediate_data:
-      return tool_use_trajectory
-
-    for function_call in intermediate_data.tool_uses:
-      tool_use_trajectory.append({
-          "tool_name": function_call.name,
-          "tool_input": function_call.args or {},
-      })
-
-    return tool_use_trajectory
-
-  def _get_score(self, eval_result) -> float:
-    return eval_result.summary_metrics[f"{self._metric_name}/mean"].item()
-
-  def _get_eval_status(self, score: float):
-    return EvalStatus.PASSED if score >= self._threshold else EvalStatus.FAILED
-
-  @staticmethod
-  @deprecated(
-      "This method has been deprecated and will be removed soon. Please use"
-      " evaluate_invocations instead."
-  )
-  def evaluate(
-      raw_eval_dataset: list[list[dict[str, Any]]],
-      evaluation_criteria: list[str],
-      *,
-      print_detailed_results: bool = False,
-  ):
-    r"""Returns the value of requested evaluation metrics.
-
-    Args:
-      raw_eval_dataset: The dataset that will be evaluated.
-      evaluation_criteria: The evaluation criteria to be used. This method
-        support two criteria, `response_evaluation_score` and
-        `response_match_score`.
-      print_detailed_results: Prints detailed results on the console. This is
-        usually helpful during debugging.
-
-    A note on evaluation_criteria:
-      `response_match_score`: This metric compares the agents final natural
-        language response with the expected final response, stored in the
-        "reference" field in test/eval files. We use Rouge metric to compare the
-        two responses.
-
-        Value Range: [0, 1]. A score closer to 0 means poor similarity between
-          response and reference. A score closer to 1 means strong similarity
-          between response and reference.
-
-      `response_evaluation_score`: Uses LLM to evalaute coherence of the
-        response, including tool use. This is pointwise metric.
-
-        Value range: [0, 5], where 0 means that the agent's response is not
-        coherent, while 5 means it is . High values are good.
-    A note on raw_eval_dataset:
-      The dataset should be a list session, where each session is represented
-        as a list of interaction that need evaluation. Each evaluation is
-        represented as a dictionary that is expected to have values for the
-        following keys:
+  def _get_score(self, eval_result) -> Optional[float]:
+    if eval_result and eval_result.summary_metrics:
+      return eval_result.summary_metrics[0].mean_score
 
-          1) query
-          2) response
-          3) acutal_tool_use
-          4) expected_tool_use
-          5) reference
+    return None
 
-        Here is a sample eval_dataset value with one entry:
-        [
-          [
-            {
-              "query": "roll a die for me",
-              "response": "I rolled a 16 sided die and got 13.\n",
-              "expected_tool_use": [
-                {
-                  "tool_name": "roll_die",
-                  "tool_input": {
-                    "sides": 16
-                  }
-                }
-              ],
-              "acutal_tool_use": [
-                {
-                  "tool_name": "roll_die",
-                  "tool_input": {
-                    "sides": 16
-                  }
-                }
-              ],
-              "reference": "I rolled a 16 sided die and got 13.\n"
-            }
-          ]
-        ]
-    """
-    if not raw_eval_dataset:
-      raise ValueError("The evaluation dataset is empty.")
-
-    metrics = ResponseEvaluator._get_metrics(
-        raw_eval_dataset, evaluation_criteria
-    )
-    flattened_queries = [
-        item for sublist in raw_eval_dataset for item in sublist
-    ]
-    eval_dataset = pd.DataFrame(flattened_queries).rename(
-        columns={"query": "prompt", "expected_tool_use": "reference_trajectory"}
-    )
-
-    eval_result = ResponseEvaluator._perform_eval(
-        dataset=eval_dataset, metrics=metrics
-    )
-
-    if print_detailed_results:
-      ResponseEvaluator._print_results(eval_result)
-    return eval_result.summary_metrics
+  def _get_eval_status(self, score: Optional[float]):
+    if score:
+      return (
+          EvalStatus.PASSED if score >= self._threshold else EvalStatus.FAILED
+      )
 
-  @staticmethod
-  def _get_metrics(raw_eval_dataset, criteria):
-    metrics = []
-    if (
-        "response_evaluation_score" in criteria
-        and "query" in raw_eval_dataset[0][0]
-        and "expected_tool_use" in raw_eval_dataset[0][0]
-    ):
-      metrics.append(MetricPromptTemplateExamples.Pointwise.COHERENCE)
-    if (
-        "response_match_score" in criteria
-        and "reference" in raw_eval_dataset[0][0]
-    ):
-      metrics.append("rouge_1")
-    return metrics
+    return EvalStatus.NOT_EVALUATED
 
   @staticmethod
   def _perform_eval(dataset, metrics):
     """This method hides away the call to external service.
 
     Primarily helps with unit testing.
     """
-    eval_task = EvalTask(dataset=dataset, metrics=metrics)
+    project_id = str(os.environ.get("GOOGLE_CLOUD_PROJECT"))
+    location = os.environ.get("GOOGLE_CLOUD_REGION")
+    client = VertexAiClient(project=project_id, location=location)
 
-    return eval_task.evaluate()
-
-  @staticmethod
-  def _print_results(eval_result):
-    print("Evaluation Summary Metrics:", eval_result.summary_metrics)
-    print(tabulate(eval_result.metrics_table, headers="keys", tablefmt="grid"))
+    return client.evals.evaluate(
+        dataset=vertexai_types.EvaluationDataset(eval_dataset_df=dataset),
+        metrics=metrics,
+    )
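
Putting the pieces together, a hypothetical driver for the updated helper, for illustration only: it calls the private static method directly, the project and region values are placeholders, and passing metrics as a list mirrors the old EvalTask path rather than anything this diff confirms.

# Hypothetical end-to-end exercise of the new _perform_eval code path.
import os

import pandas as pd
from vertexai import types as vertexai_types

from google.adk.evaluation.response_evaluator import ResponseEvaluator

os.environ.setdefault("GOOGLE_CLOUD_PROJECT", "my-project")  # placeholder project
os.environ.setdefault("GOOGLE_CLOUD_REGION", "us-central1")  # placeholder region

# Columns match the slimmed-down eval_case built in evaluate_invocations.
dataset = pd.DataFrame([{
    "prompt": "roll a die for me",
    "reference": "I rolled a 16 sided die and got 13.\n",
    "response": "I rolled a 16 sided die and got 13.\n",
}])

result = ResponseEvaluator._perform_eval(
    dataset=dataset,
    metrics=[vertexai_types.PrebuiltMetric.COHERENCE],  # assumed list form
)
print(result.summary_metrics[0].mean_score)  # accessor used by the new _get_score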
